├── Bash_figure_4.sh ├── Bash_figure_5.sh ├── LICENSE ├── README.md ├── imgs └── noe2e_learning_curve.png └── src ├── deep_dialog ├── __init__.py ├── agents │ ├── __init__.py │ ├── agent.py │ ├── agent_baselines.py │ ├── agent_cmd.py │ └── agent_dqn.py ├── checkpoints │ ├── temp_run1 │ │ └── agt_9_performance_records.json │ └── temp_run2 │ │ └── agt_9_performance_records.json ├── data │ ├── count_uniq_slots.py │ ├── dia_act_nl_pairs.v6.json │ ├── dia_acts.txt │ ├── dicts.v3.json │ ├── dicts.v3.p │ ├── human_huamn_data_framed.json │ ├── human_huamn_data_framed_agent_first_turn.json │ ├── movie_kb.1k.json │ ├── movie_kb.1k.p │ ├── movie_kb.v2.json │ ├── movie_kb.v2.p │ ├── slot_set.txt │ ├── slot_set_small.txt │ ├── user_goals.json │ ├── user_goals_all_turns_template.p │ ├── user_goals_first_turn_template.part.movie.v1.p │ ├── user_goals_first_turn_template.v2.p │ └── user_goals_ids.json ├── dialog_config.py ├── dialog_system │ ├── __init__.py │ ├── dialog_manager.py │ ├── dict_reader.py │ ├── kb_helper.py │ ├── state_tracker.py │ └── utils.py ├── models │ ├── nlg │ │ ├── convert.py │ │ ├── lstm_tanh_relu_[1468202263.38]_2_0.610.p │ │ └── model.nlg │ └── nlu │ │ ├── convert.py │ │ └── lstm_[1468447442.91]_39_80_0.921.p ├── nlg │ ├── __init__.py │ ├── decoder.py │ ├── lstm_decoder_tanh.py │ ├── nlg.py │ └── utils.py ├── nlu │ ├── __init__.py │ ├── bi_lstm.py │ ├── lstm.py │ ├── nlu.py │ ├── seq_seq.py │ └── utils.py ├── qlearning │ ├── __init__.py │ ├── dqn.py │ ├── dqn_torch.py │ └── utils.py └── usersims │ ├── __init__.py │ ├── user_model.py │ ├── usersim.py │ ├── usersim_model.py │ └── usersim_rule.py ├── draw_learning_curve.py └── run.py /Bash_figure_4.sh: -------------------------------------------------------------------------------- 1 | #Below is the script used for figure 4 2 | for ((i=1; i<= 5; i++));do 3 | let "seed=$i*100" 4 | python run.py --agt 9 \ 5 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 6 | 
--experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 7 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 8 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 9 | --warm_start 1 --warm_start_epochs 100 \ 10 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k0_run$i \ 11 | --planning_steps 0 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 1 12 | done 13 | 14 | for ((i=1; i<= 5; i++));do 15 | let "seed=$i*100" 16 | python run.py --agt 9 \ 17 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 18 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 19 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 20 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 21 | --warm_start 1 --warm_start_epochs 100 \ 22 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k2_run$i \ 23 | --planning_steps 1 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 1 24 | done 25 | 26 | for ((i=1; i<= 5; i++));do 27 | let "seed=$i*100" 28 | python run.py --agt 9 \ 29 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 30 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 31 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 32 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 33 | --warm_start 1 --warm_start_epochs 100 \ 34 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k5_run$i \ 35 | --planning_steps 4 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 1 36 | done 37 | 38 | for ((i=1; i<= 5; i++));do 39 | let "seed=$i*100" 40 | python run.py --agt 9 \ 41 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 42 | 
--experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 43 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 44 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 45 | --warm_start 1 --warm_start_epochs 100 \ 46 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k10_run$i \ 47 | --planning_steps 9 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 1 48 | done 49 | 50 | for ((i=1; i<= 5; i++));do 51 | let "seed=$i*100" 52 | python run.py --agt 9 \ 53 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 54 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 55 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 56 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 57 | --warm_start 1 --warm_start_epochs 100 \ 58 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k20_run$i \ 59 | --planning_steps 19 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 1 60 | done -------------------------------------------------------------------------------- /Bash_figure_5.sh: -------------------------------------------------------------------------------- 1 | #Below is the script used for figure 5 2 | 3 | ##DQN 10, upper bound 4 | for ((i=1; i<= 5; i++));do 5 | let "seed=$i*100" 6 | python run.py --agt 9 \ 7 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 8 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 9 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 10 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 11 | --warm_start 1 --warm_start_epochs 100 \ 12 | --write_model_dir ./deep_dialog/checkpoints/DQN_k10_run$i \ 13 | --planning_steps 9 --torch_seed $seed 
--grounded 1 --boosted 1 --train_world_model 1 14 | done 15 | 16 | ##DDQ 10 17 | for ((i=1; i<= 5; i++));do 18 | let "seed=$i*100" 19 | python run.py --agt 9 \ 20 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 21 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 22 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 23 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 24 | --warm_start 1 --warm_start_epochs 100 \ 25 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k10_run$i \ 26 | --planning_steps 9 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 1 27 | done 28 | 29 | ##DDQ 10 rand-init 30 | for ((i=1; i<= 5; i++));do 31 | let "seed=$i*100" 32 | python run.py --agt 9 \ 33 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 34 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 35 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 36 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 37 | --warm_start 1 --warm_start_epochs 100 \ 38 | --write_model_dir ./deep_dialog/checkpoints/DDQ_k10_rand_run$i \ 39 | --planning_steps 9 --torch_seed $seed --grounded 0 --boosted 0 --train_world_model 1 40 | done 41 | 42 | ##DDQ 10 fixed, run 5 or 10 to smooth the results 43 | for ((i=1; i<= 5; i++));do 44 | let "seed=$i*100" 45 | python run.py --agt 9 \ 46 | --usr 1 --max_turn 40 --movie_kb_path ./deep_dialog/data/movie_kb.1k.p --dqn_hidden_size 80 \ 47 | --experience_replay_pool_size 5000 --episodes 500 --simulation_epoch_size 100 \ 48 | --run_mode 3 --act_level 0 --slot_err_prob 0.0 --intent_err_prob 0.00 --batch_size 16 \ 49 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p \ 50 | --warm_start 1 --warm_start_epochs 100 \ 51 | 
--write_model_dir ./deep_dialog/checkpoints/DDQ_k10_fixed_run$i \ 52 | --planning_steps 9 --torch_seed $seed --grounded 0 --boosted 1 --train_world_model 0 53 | done -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 WISELab, MiuLab and Microsoft Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Dyna-Q: Integrating Planning for Task-Completion Dialogue Policy Learning 2 | *An implementation of the 3 | [Deep Dyna-Q: Integrating Planning for Task-Completion Dialogue Policy Learning](https://arxiv.org/abs/1801.06176)* 4 | 5 | This document describes how to run the simulation of DDQ Agent. 6 | 7 | ## Content 8 | * [Data](#data) 9 | * [Parameter](#parameter) 10 | * [Running Dialogue Agents](#running-dialogue-agents) 11 | * [Evaluation](#evaluation) 12 | * [Reference](#reference) 13 | 14 | ## Data 15 | all the data is under this folder: ./src/deep_dialog/data 16 | 17 | * Movie Knowledge Bases
18 | `movie_kb.1k.p` --- 94% success rate (for `user_goals_first_turn_template_subsets.v1.p`)
19 | `movie_kb.v2.p` --- 36% success rate (for `user_goals_first_turn_template_subsets.v1.p`) 20 | 21 | * User Goals
22 | `user_goals_first_turn_template.v2.p` --- user goals extracted from the first user turn
23 | `user_goals_first_turn_template.part.movie.v1.p` --- a subset of user goals [Please use this one, the upper bound success rate on movie_kb.1k.json is 0.9765.] 24 | 25 | * NLG Rule Template
26 | `dia_act_nl_pairs.v6.json` --- predefined NLG rule templates for both the user simulator and the agent.
29 | `dia_acts.txt` 30 | 31 | * Dialog Act Slot
32 | `slot_set.txt` 33 | 34 | ## Parameter 35 | 36 | ### Basic setting 37 | 38 | `--agt`: the agent id
39 | `--usr`: the user (simulator) id
40 | `--max_turn`: maximum number of turns
41 | `--episodes`: how many dialogues to run
42 | `--slot_err_prob`: slot-level error probability
43 | `--slot_err_mode`: which slot error mode to use
44 | `--intent_err_prob`: intent-level error probability 45 | 46 | ### DDQ Agent setting 47 | `--grounded`: plan k steps with the real environment rather than the world model, serving as an upper bound.
48 | `--boosted`: boost the world model with examples generated by the rule-based agent
49 | `--train_world_model`: train the world model on the fly
50 | 51 | 52 | ### Data setting 53 | 54 | `--movie_kb_path`: the movie kb path for agent side
55 | `--goal_file_path`: the user goal file path for user simulator side 56 | 57 | ### Model setting 58 | 59 | `--dqn_hidden_size`: hidden size for RL agent
60 | `--batch_size`: batch size for DDQ training
61 | `--simulation_epoch_size`: how many dialogues to simulate in one epoch
62 | `--warm_start`: use rule policy to fill the experience replay buffer at the beginning
63 | `--warm_start_epochs`: how many dialogues to run in the warm start 64 | 65 | ### Display setting 66 | 67 | `--run_mode`: 0 for display mode (NL); 1 for debug mode (Dia_Act); 2 for debug mode (Dia_Act and NL); >=3 for no display (i.e. training)
68 | `--act_level`: 0 if the user simulator operates at the Dia_Act level; 1 if it operates at the NL level
69 | `--auto_suggest`: 0 for no auto_suggest; 1 for auto_suggest
70 | `--cmd_input_mode`: 0 for NL input; 1 for Dia_Act input. (this parameter is for AgentCmd only) 71 | 72 | ### Others 73 | 74 | `--write_model_dir`: the directory to write the models
75 | `--trained_model_path`: the path of the trained RL agent model; loads the trained model for prediction purposes. 76 | 77 | `--learning_phase`: train/test/all, default is all. You can split the user goal set into train and test sets, or not split it (all). We introduce some randomness in the first sampled user action, so even for the same user goal, the generated dialogues might differ.
78 | 79 | ## Running Dialogue Agents 80 | 81 | Train DDQ Agent with K planning steps: 82 | ```sh 83 | python run.py --agt 9 --usr 1 --max_turn 40 84 | --movie_kb_path ./deep_dialog/data/movie_kb.1k.p 85 | --dqn_hidden_size 80 --experience_replay_pool_size 5000 86 | --episodes 500 87 | --simulation_epoch_size 100 88 | --run_mode 3 89 | --act_level 0 90 | --slot_err_prob 0.0 91 | --intent_err_prob 0.00 92 | --batch_size 16 93 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p 94 | --warm_start 1 --warm_start_epochs 100 95 | --planning_steps K-1 96 | --write_model_dir ./deep_dialog/checkpoints/DDQAgent 97 | --torch_seed 100 98 | --grounded 0 99 | --boosted 1 100 | --train_world_model 1 101 | 102 | ``` 103 | Test RL Agent with N dialogues: 104 | ```sh 105 | python run.py --agt 9 --usr 1 --max_turn 40 106 | --movie_kb_path ./deep_dialog/data/movie_kb.1k.p 107 | --dqn_hidden_size 80 108 | --experience_replay_pool_size 1000 109 | --episodes 300 110 | --simulation_epoch_size 100 111 | --write_model_dir ./deep_dialog/checkpoints/DDQAgent/ 112 | --slot_err_prob 0.00 113 | --intent_err_prob 0.00 114 | --batch_size 16 115 | --goal_file_path ./deep_dialog/data/user_goals_first_turn_template.part.movie.v1.p 116 | --trained_model_path ./deep_dialog/checkpoints/DDQAgent/TRAINED_MODEL 117 | --run_mode 3 118 | ``` 119 | ## Experiments 120 | To run the scripts, move the two bash files under src folder. 121 | 1. Bash_figure_4.sh is the script for figure 4. 122 | 2. Bash_figure_5.sh is the script for figure 5. 123 | 124 | ## Evaluation 125 | To evaluate the performance of agents, three metrics are available: success rate, average reward, average turns. Here we show the learning curve with success rate. 126 | 127 | 1. Plotting Learning Curve 128 | ``` python draw_learning_curve.py --result_file ./deep_dialog/checkpoints/DDQAgent/noe2e/TRAINED_MODEL.json``` 129 | 2. 
Pull out the numbers and draw the curves in Excel 130 | 131 | ## Reference 132 | 133 | Main papers to be cited 134 | ``` 135 | @inproceedings{Peng2018DeepDynaQ, 136 | title={Deep Dyna-Q: Integrating Planning for Task-Completion Dialogue Policy Learning}, 137 | author={Peng, Baolin and Li, Xiujun and Gao, Jianfeng and Liu, Jingjing and Wong, Kam-Fai and Su, Shang-Yu}, 138 | booktitle={ACL}, 139 | year={2018} 140 | } 141 | 142 | @article{li2016user, 143 | title={A User Simulator for Task-Completion Dialogues}, 144 | author={Li, Xiujun and Lipton, Zachary C and Dhingra, Bhuwan and Li, Lihong and Gao, Jianfeng and Chen, Yun-Nung}, 145 | journal={arXiv preprint arXiv:1612.05688}, 146 | year={2016} 147 | } 148 | -------------------------------------------------------------------------------- /imgs/noe2e_learning_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiuLab/DDQ/f65611c2358581bb72be61b5b389b1e3c046b73d/imgs/noe2e_learning_curve.png -------------------------------------------------------------------------------- /src/deep_dialog/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/deep_dialog/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_cmd import * 2 | from .agent_baselines import * 3 | from .agent_dqn import * -------------------------------------------------------------------------------- /src/deep_dialog/agents/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 17, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | from deep_dialog import dialog_config 8 | 9 | class Agent: 10 | """ Prototype for all agent classes, defining the interface they must uphold """ 11 | 12 | def __init__(self, movie_dict=None, 
act_set=None, slot_set=None, params=None): 13 | """ Constructor for the Agent class 14 | 15 | Arguments: 16 | movie_dict -- This is here now but doesn't belong - the agent doesn't know about movies 17 | act_set -- The set of acts. #### Shouldn't this be more abstract? Don't we want our agent to be more broadly usable? 18 | slot_set -- The set of available slots 19 | """ 20 | self.movie_dict = movie_dict 21 | self.act_set = act_set 22 | self.slot_set = slot_set 23 | self.act_cardinality = len(act_set.keys()) 24 | self.slot_cardinality = len(slot_set.keys()) 25 | 26 | self.epsilon = params['epsilon'] 27 | self.agent_run_mode = params['agent_run_mode'] 28 | self.agent_act_level = params['agent_act_level'] 29 | 30 | 31 | def initialize_episode(self): 32 | """ Initialize a new episode. This function is called every time a new episode is run. """ 33 | self.current_action = {} # TODO Changed this variable's name to current_action 34 | self.current_action['diaact'] = None # TODO Does it make sense to call it a state if it has an act? Which act? The Most recent? 35 | self.current_action['inform_slots'] = {} 36 | self.current_action['request_slots'] = {} 37 | self.current_action['turn'] = 0 38 | 39 | def state_to_action(self, state, available_actions): 40 | """ Take the current state and return an action according to the current exploration/exploitation policy 41 | 42 | We define the agents flexibly so that they can either operate on act_slot representations or act_slot_value representations. 43 | We also define the responses flexibly, returning a dictionary with keys [act_slot_response, act_slot_value_response]. This way the command-line agent can continue to operate with values 44 | 45 | Arguments: 46 | state -- A tuple of (history, kb_results) where history is a sequence of previous actions and kb_results contains information on the number of results matching the current constraints. 47 | user_action -- A legacy representation used to run the command line agent. 
We should remove this ASAP but not just yet 48 | available_actions -- A list of the allowable actions in the current state 49 | 50 | Returns: 51 | act_slot_action -- An action consisting of one act and >= 0 slots as well as which slots are informed vs requested. 52 | act_slot_value_action -- An action consisting of acts slots and values in the legacy format. This can be used in the future for training agents that take value into account and interact directly with the database 53 | """ 54 | act_slot_response = None 55 | act_slot_value_response = None 56 | return {"act_slot_response": act_slot_response, "act_slot_value_response": act_slot_value_response} 57 | 58 | 59 | def register_experience_replay_tuple(self, s_t, a_t, reward, s_tplus1, episode_over): 60 | """ Register feedback from the environment, to be stored as future training data 61 | 62 | Arguments: 63 | s_t -- The state in which the last action was taken 64 | a_t -- The previous agent action 65 | reward -- The reward received immediately following the action 66 | s_tplus1 -- The state transition following the latest action 67 | episode_over -- A boolean value representing whether the this is the final action. 
68 | 69 | Returns: 70 | None 71 | """ 72 | pass 73 | 74 | 75 | def set_nlg_model(self, nlg_model): 76 | self.nlg_model = nlg_model 77 | 78 | def set_nlu_model(self, nlu_model): 79 | self.nlu_model = nlu_model 80 | 81 | 82 | def add_nl_to_action(self, agent_action): 83 | """ Add NL to Agent Dia_Act """ 84 | 85 | if agent_action['act_slot_response']: 86 | agent_action['act_slot_response']['nl'] = "" 87 | #TODO 88 | user_nlg_sentence = self.nlg_model.convert_diaact_to_nl(agent_action['act_slot_response'], 'agt') #self.nlg_model.translate_diaact(agent_action['act_slot_response']) # NLG 89 | agent_action['act_slot_response']['nl'] = user_nlg_sentence 90 | elif agent_action['act_slot_value_response']: 91 | agent_action['act_slot_value_response']['nl'] = "" 92 | user_nlg_sentence = self.nlg_model.convert_diaact_to_nl(agent_action['act_slot_value_response'], 'agt') #self.nlg_model.translate_diaact(agent_action['act_slot_value_response']) # NLG 93 | agent_action['act_slot_response']['nl'] = user_nlg_sentence -------------------------------------------------------------------------------- /src/deep_dialog/agents/agent_baselines.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 25, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | import copy, random 8 | from deep_dialog import dialog_config 9 | from agent import Agent 10 | 11 | 12 | class InformAgent(Agent): 13 | """ A simple agent to test the system. This agent should simply inform all the slots and then issue: taskcomplete. 
""" 14 | 15 | def initialize_episode(self): 16 | self.state = {} 17 | self.state['diaact'] = '' 18 | self.state['inform_slots'] = {} 19 | self.state['request_slots'] = {} 20 | self.state['turn'] = -1 21 | self.current_slot_id = 0 22 | 23 | def state_to_action(self, state): 24 | """ Run current policy on state and produce an action """ 25 | 26 | self.state['turn'] += 2 27 | if self.current_slot_id < len(self.slot_set.keys()): 28 | slot = self.slot_set.keys()[self.current_slot_id] 29 | self.current_slot_id += 1 30 | 31 | act_slot_response = {} 32 | act_slot_response['diaact'] = "inform" 33 | act_slot_response['inform_slots'] = {slot: "PLACEHOLDER"} 34 | act_slot_response['request_slots'] = {} 35 | act_slot_response['turn'] = self.state['turn'] 36 | else: 37 | act_slot_response = {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}, 'turn': self.state['turn']} 38 | return {'act_slot_response': act_slot_response, 'act_slot_value_response': None} 39 | 40 | 41 | 42 | class RequestAllAgent(Agent): 43 | """ A simple agent to test the system. This agent should simply request all the slots and then issue: thanks(). 
""" 44 | 45 | def initialize_episode(self): 46 | self.state = {} 47 | self.state['diaact'] = '' 48 | self.state['inform_slots'] = {} 49 | self.state['request_slots'] = {} 50 | self.state['turn'] = -1 51 | self.current_slot_id = 0 52 | 53 | def state_to_action(self, state): 54 | """ Run current policy on state and produce an action """ 55 | 56 | self.state['turn'] += 2 57 | if self.current_slot_id < len(dialog_config.sys_request_slots): 58 | slot = dialog_config.sys_request_slots[self.current_slot_id] 59 | self.current_slot_id += 1 60 | 61 | act_slot_response = {} 62 | act_slot_response['diaact'] = "request" 63 | act_slot_response['inform_slots'] = {} 64 | act_slot_response['request_slots'] = {slot: "PLACEHOLDER"} 65 | act_slot_response['turn'] = self.state['turn'] 66 | else: 67 | act_slot_response = {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}, 'turn': self.state['turn']} 68 | return {'act_slot_response': act_slot_response, 'act_slot_value_response': None} 69 | 70 | 71 | 72 | class RandomAgent(Agent): 73 | """ A simple agent to test the interface. This agent should choose actions randomly. """ 74 | 75 | def initialize_episode(self): 76 | self.state = {} 77 | self.state['diaact'] = '' 78 | self.state['inform_slots'] = {} 79 | self.state['request_slots'] = {} 80 | self.state['turn'] = -1 81 | 82 | 83 | def state_to_action(self, state): 84 | """ Run current policy on state and produce an action """ 85 | 86 | self.state['turn'] += 2 87 | act_slot_response = copy.deepcopy(random.choice(dialog_config.feasible_actions)) 88 | act_slot_response['turn'] = self.state['turn'] 89 | return {'act_slot_response': act_slot_response, 'act_slot_value_response': None} 90 | 91 | 92 | 93 | class EchoAgent(Agent): 94 | """ A simple agent that informs all requested slots, then issues inform(taskcomplete) when the user stops making requests. 
""" 95 | 96 | def initialize_episode(self): 97 | self.state = {} 98 | self.state['diaact'] = '' 99 | self.state['inform_slots'] = {} 100 | self.state['request_slots'] = {} 101 | self.state['turn'] = -1 102 | 103 | 104 | def state_to_action(self, state): 105 | """ Run current policy on state and produce an action """ 106 | user_action = state['user_action'] 107 | 108 | self.state['turn'] += 2 109 | act_slot_response = {} 110 | act_slot_response['inform_slots'] = {} 111 | act_slot_response['request_slots'] = {} 112 | ######################################################################## 113 | # find out if the user is requesting anything 114 | # if so, inform it 115 | ######################################################################## 116 | if user_action['diaact'] == 'request': 117 | requested_slot = user_action['request_slots'].keys()[0] 118 | 119 | act_slot_response['diaact'] = "inform" 120 | act_slot_response['inform_slots'][requested_slot] = "PLACEHOLDER" 121 | else: 122 | act_slot_response['diaact'] = "thanks" 123 | 124 | act_slot_response['turn'] = self.state['turn'] 125 | return {'act_slot_response': act_slot_response, 'act_slot_value_response': None} 126 | 127 | 128 | class RequestBasicsAgent(Agent): 129 | """ A simple agent to test the system. This agent should simply request all the basic slots and then issue: thanks(). 
""" 130 | 131 | def initialize_episode(self): 132 | self.state = {} 133 | self.state['diaact'] = 'UNK' 134 | self.state['inform_slots'] = {} 135 | self.state['request_slots'] = {} 136 | self.state['turn'] = -1 137 | self.current_slot_id = 0 138 | self.request_set = ['moviename', 'starttime', 'city', 'date', 'theater', 'numberofpeople'] 139 | self.phase = 0 140 | 141 | def state_to_action(self, state): 142 | """ Run current policy on state and produce an action """ 143 | 144 | self.state['turn'] += 2 145 | if self.current_slot_id < len(self.request_set): 146 | slot = self.request_set[self.current_slot_id] 147 | self.current_slot_id += 1 148 | 149 | act_slot_response = {} 150 | act_slot_response['diaact'] = "request" 151 | act_slot_response['inform_slots'] = {} 152 | act_slot_response['request_slots'] = {slot: "UNK"} 153 | act_slot_response['turn'] = self.state['turn'] 154 | elif self.phase == 0: 155 | act_slot_response = {'diaact': "inform", 'inform_slots': {'taskcomplete': "PLACEHOLDER"}, 'request_slots': {}, 'turn':self.state['turn']} 156 | self.phase += 1 157 | elif self.phase == 1: 158 | act_slot_response = {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}, 'turn': self.state['turn']} 159 | else: 160 | raise Exception("THIS SHOULD NOT BE POSSIBLE (AGENT CALLED IN UNANTICIPATED WAY)") 161 | return {'act_slot_response': act_slot_response, 'act_slot_value_response': None} 162 | 163 | -------------------------------------------------------------------------------- /src/deep_dialog/agents/agent_cmd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 17, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | 8 | from agent import Agent 9 | 10 | class AgentCmd(Agent): 11 | 12 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, params=None): 13 | """ Constructor for the Agent class """ 14 | 15 | self.movie_dict = movie_dict 16 | self.act_set = act_set 17 | self.slot_set = slot_set 18 | 
self.act_cardinality = len(act_set.keys()) 19 | self.slot_cardinality = len(slot_set.keys()) 20 | 21 | self.agent_run_mode = params['agent_run_mode'] 22 | self.agent_act_level = params['agent_act_level'] 23 | self.agent_input_mode = params['cmd_input_mode'] 24 | 25 | 26 | def state_to_action(self, state): 27 | """ Generate an action by getting input interactively from the command line """ 28 | 29 | user_action = state['user_action'] 30 | # get input from the command line 31 | print "Turn", user_action['turn'] + 1, "sys:", 32 | command = raw_input() 33 | 34 | if self.agent_input_mode == 0: # nl 35 | act_slot_value_response = self.generate_diaact_from_nl(command) 36 | elif self.agent_input_mode == 1: # dia_act 37 | act_slot_value_response = self.parse_str_to_diaact(command) 38 | 39 | return {"act_slot_response": act_slot_value_response, "act_slot_value_response": act_slot_value_response} 40 | 41 | def parse_str_to_diaact(self, string): 42 | """ Parse string into Dia_Act Form """ 43 | 44 | annot = string.strip(' ').strip('\n').strip('\r') 45 | act = annot 46 | 47 | if annot.find('(') > 0 and annot.find(')') > 0: 48 | act = annot[0: annot.find('(')].strip(' ').lower() #Dia act 49 | annot = annot[annot.find('(')+1:-1].strip(' ') #slot-value pairs 50 | else: annot = '' 51 | 52 | act_slot_value_response = {} 53 | act_slot_value_response['diaact'] = 'UNK' 54 | act_slot_value_response['inform_slots'] = {} 55 | act_slot_value_response['request_slots'] = {} 56 | 57 | if act in self.act_set: # dialog_config.all_acts 58 | act_slot_value_response['diaact'] = act 59 | else: 60 | print ("Something wrong for your input dialog act! 
Please check your input ...") 61 | 62 | if len(annot) > 0: # slot-pair values: slot[val] = id 63 | annot_segs = annot.split(';') #slot-value pairs 64 | sent_slot_vals = {} # slot-pair real value 65 | sent_rep_vals = {} # slot-pair id value 66 | 67 | for annot_seg in annot_segs: 68 | annot_seg = annot_seg.strip(' ') 69 | annot_slot = annot_seg 70 | if annot_seg.find('=') > 0: 71 | annot_slot = annot_seg[:annot_seg.find('=')] 72 | annot_val = annot_seg[annot_seg.find('=')+1:] 73 | else: #requested 74 | annot_val = 'UNK' # for request 75 | if annot_slot == 'taskcomplete': annot_val = 'FINISH' 76 | 77 | if annot_slot == 'mc_list': continue 78 | 79 | # slot may have multiple values 80 | sent_slot_vals[annot_slot] = [] 81 | sent_rep_vals[annot_slot] = [] 82 | 83 | if annot_val.startswith('{') and annot_val.endswith('}'): 84 | annot_val = annot_val[1:-1] 85 | 86 | if annot_slot == 'result': 87 | result_annot_seg_arr = annot_val.strip(' ').split('&') 88 | if len(annot_val.strip(' '))> 0: 89 | for result_annot_seg_item in result_annot_seg_arr: 90 | result_annot_seg_arr = result_annot_seg_item.strip(' ').split('=') 91 | result_annot_seg_slot = result_annot_seg_arr[0] 92 | result_annot_seg_slot_val = result_annot_seg_arr[1] 93 | 94 | if result_annot_seg_slot_val == 'UNK': act_slot_value_response['request_slots'][result_annot_seg_slot] = 'UNK' 95 | else: act_slot_value_response['inform_slots'][result_annot_seg_slot] = result_annot_seg_slot_val 96 | else: # result={} 97 | pass 98 | else: # multi-choice or mc_list 99 | annot_val_arr = annot_val.split('#') 100 | act_slot_value_response['inform_slots'][annot_slot] = [] 101 | for annot_val_ele in annot_val_arr: 102 | act_slot_value_response['inform_slots'][annot_slot].append(annot_val_ele) 103 | else: # single choice 104 | if annot_slot in self.slot_set.keys(): 105 | if annot_val == 'UNK': 106 | act_slot_value_response['request_slots'][annot_slot] = 'UNK' 107 | else: 108 | act_slot_value_response['inform_slots'][annot_slot] = 
annot_val 109 | 110 | return act_slot_value_response 111 | 112 | def generate_diaact_from_nl(self, string): 113 | """ Generate Dia_Act Form with NLU """ 114 | 115 | agent_action = {} 116 | agent_action['diaact'] = 'UNK' 117 | agent_action['inform_slots'] = {} 118 | agent_action['request_slots'] = {} 119 | 120 | if len(string) > 0: 121 | agent_action = self.nlu_model.generate_dia_act(string) 122 | 123 | agent_action['nl'] = string 124 | return agent_action 125 | 126 | def add_nl_to_action(self, agent_action): 127 | """ Add NL to Agent Dia_Act """ 128 | 129 | if self.agent_input_mode == 1: 130 | if agent_action['act_slot_response']: 131 | agent_action['act_slot_response']['nl'] = "" 132 | user_nlg_sentence = self.nlg_model.convert_diaact_to_nl(agent_action['act_slot_response'], 'agt') 133 | agent_action['act_slot_response']['nl'] = user_nlg_sentence 134 | elif agent_action['act_slot_value_response']: 135 | agent_action['act_slot_value_response']['nl'] = "" 136 | user_nlg_sentence = self.nlg_model.convert_diaact_to_nl(agent_action['act_slot_value_response'], 'agt') 137 | agent_action['act_slot_response']['nl'] = user_nlg_sentence 138 | -------------------------------------------------------------------------------- /src/deep_dialog/agents/agent_dqn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 30, 2017 3 | 4 | An DQN Agent modified for DDQ Agent 5 | 6 | Some methods are not consistent with super class Agent. 
7 | 8 | @author: Baolin Peng 9 | ''' 10 | 11 | import random, copy, json 12 | import cPickle as pickle 13 | import numpy as np 14 | from collections import namedtuple, deque 15 | 16 | from deep_dialog import dialog_config 17 | 18 | from agent import Agent 19 | from deep_dialog.qlearning import DQN 20 | 21 | import torch 22 | import torch.optim as optim 23 | import torch.nn.functional as F 24 | 25 | DEVICE = torch.device('cpu') 26 | 27 | Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'term')) 28 | 29 | 30 | class AgentDQN(Agent): 31 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, params=None): 32 | self.movie_dict = movie_dict 33 | self.act_set = act_set 34 | self.slot_set = slot_set 35 | self.act_cardinality = len(act_set.keys()) 36 | self.slot_cardinality = len(slot_set.keys()) 37 | 38 | self.feasible_actions = dialog_config.feasible_actions 39 | self.num_actions = len(self.feasible_actions) 40 | 41 | self.epsilon = params['epsilon'] 42 | self.agent_run_mode = params['agent_run_mode'] 43 | self.agent_act_level = params['agent_act_level'] 44 | 45 | self.experience_replay_pool_size = params.get('experience_replay_pool_size', 5000) 46 | self.experience_replay_pool = deque( 47 | maxlen=self.experience_replay_pool_size) # experience replay pool 48 | self.experience_replay_pool_from_model = deque( 49 | maxlen=self.experience_replay_pool_size) # experience replay pool 50 | self.running_expereince_pool = None # hold experience from both user and world model 51 | 52 | self.hidden_size = params.get('dqn_hidden_size', 60) 53 | self.gamma = params.get('gamma', 0.9) 54 | self.predict_mode = params.get('predict_mode', False) 55 | self.warm_start = params.get('warm_start', 0) 56 | 57 | self.max_turn = params['max_turn'] + 5 58 | self.state_dimension = 2 * self.act_cardinality + 7 * self.slot_cardinality + 3 + self.max_turn 59 | 60 | self.dqn = DQN(self.state_dimension, self.hidden_size, self.num_actions).to(DEVICE) 61 | 
self.target_dqn = DQN(self.state_dimension, self.hidden_size, self.num_actions).to(DEVICE) 62 | self.target_dqn.load_state_dict(self.dqn.state_dict()) 63 | self.target_dqn.eval() 64 | 65 | self.optimizer = optim.RMSprop(self.dqn.parameters(), lr=1e-3) 66 | 67 | self.cur_bellman_err = 0 68 | 69 | # Prediction Mode: load trained DQN model 70 | if params['trained_model_path'] != None: 71 | self.load(params['trained_model_path']) 72 | self.predict_mode = True 73 | self.warm_start = 2 74 | 75 | def initialize_episode(self): 76 | """ Initialize a new episode. This function is called every time a new episode is run. """ 77 | 78 | self.current_slot_id = 0 79 | self.phase = 0 80 | self.request_set = ['moviename', 'starttime', 'city', 'date', 'theater', 'numberofpeople'] 81 | 82 | def state_to_action(self, state): 83 | """ DQN: Input state, output action """ 84 | # self.state['turn'] += 2 85 | self.representation = self.prepare_state_representation(state) 86 | self.action = self.run_policy(self.representation) 87 | if self.warm_start == 1: 88 | act_slot_response = copy.deepcopy(self.feasible_actions[self.action]) 89 | else: 90 | act_slot_response = copy.deepcopy(self.feasible_actions[self.action[0]]) 91 | 92 | return {'act_slot_response': act_slot_response, 'act_slot_value_response': None} 93 | 94 | def prepare_state_representation(self, state): 95 | """ Create the representation for each state """ 96 | 97 | user_action = state['user_action'] 98 | current_slots = state['current_slots'] 99 | kb_results_dict = state['kb_results_dict'] 100 | agent_last = state['agent_action'] 101 | 102 | ######################################################################## 103 | # Create one-hot of acts to represent the current user action 104 | ######################################################################## 105 | user_act_rep = np.zeros((1, self.act_cardinality)) 106 | user_act_rep[0, self.act_set[user_action['diaact']]] = 1.0 107 | 108 | 
######################################################################## 109 | # Create bag of inform slots representation to represent the current user action 110 | ######################################################################## 111 | user_inform_slots_rep = np.zeros((1, self.slot_cardinality)) 112 | for slot in user_action['inform_slots'].keys(): 113 | user_inform_slots_rep[0, self.slot_set[slot]] = 1.0 114 | 115 | ######################################################################## 116 | # Create bag of request slots representation to represent the current user action 117 | ######################################################################## 118 | user_request_slots_rep = np.zeros((1, self.slot_cardinality)) 119 | for slot in user_action['request_slots'].keys(): 120 | user_request_slots_rep[0, self.slot_set[slot]] = 1.0 121 | 122 | ######################################################################## 123 | # Creat bag of filled_in slots based on the current_slots 124 | ######################################################################## 125 | current_slots_rep = np.zeros((1, self.slot_cardinality)) 126 | for slot in current_slots['inform_slots']: 127 | current_slots_rep[0, self.slot_set[slot]] = 1.0 128 | 129 | ######################################################################## 130 | # Encode last agent act 131 | ######################################################################## 132 | agent_act_rep = np.zeros((1, self.act_cardinality)) 133 | if agent_last: 134 | agent_act_rep[0, self.act_set[agent_last['diaact']]] = 1.0 135 | 136 | ######################################################################## 137 | # Encode last agent inform slots 138 | ######################################################################## 139 | agent_inform_slots_rep = np.zeros((1, self.slot_cardinality)) 140 | if agent_last: 141 | for slot in agent_last['inform_slots'].keys(): 142 | agent_inform_slots_rep[0, self.slot_set[slot]] = 1.0 143 | 144 | 
######################################################################## 145 | # Encode last agent request slots 146 | ######################################################################## 147 | agent_request_slots_rep = np.zeros((1, self.slot_cardinality)) 148 | if agent_last: 149 | for slot in agent_last['request_slots'].keys(): 150 | agent_request_slots_rep[0, self.slot_set[slot]] = 1.0 151 | 152 | # turn_rep = np.zeros((1,1)) + state['turn'] / 10. 153 | turn_rep = np.zeros((1, 1)) 154 | 155 | ######################################################################## 156 | # One-hot representation of the turn count? 157 | ######################################################################## 158 | turn_onehot_rep = np.zeros((1, self.max_turn)) 159 | turn_onehot_rep[0, state['turn']] = 1.0 160 | 161 | # ######################################################################## 162 | # # Representation of KB results (scaled counts) 163 | # ######################################################################## 164 | # kb_count_rep = np.zeros((1, self.slot_cardinality + 1)) + kb_results_dict['matching_all_constraints'] / 100. 165 | # for slot in kb_results_dict: 166 | # if slot in self.slot_set: 167 | # kb_count_rep[0, self.slot_set[slot]] = kb_results_dict[slot] / 100. 168 | # 169 | # ######################################################################## 170 | # # Representation of KB results (binary) 171 | # ######################################################################## 172 | # kb_binary_rep = np.zeros((1, self.slot_cardinality + 1)) + np.sum( kb_results_dict['matching_all_constraints'] > 0.) 173 | # for slot in kb_results_dict: 174 | # if slot in self.slot_set: 175 | # kb_binary_rep[0, self.slot_set[slot]] = np.sum( kb_results_dict[slot] > 0.) 
176 | 177 | kb_count_rep = np.zeros((1, self.slot_cardinality + 1)) 178 | 179 | ######################################################################## 180 | # Representation of KB results (binary) 181 | ######################################################################## 182 | kb_binary_rep = np.zeros((1, self.slot_cardinality + 1)) 183 | 184 | self.final_representation = np.hstack( 185 | [user_act_rep, user_inform_slots_rep, user_request_slots_rep, agent_act_rep, agent_inform_slots_rep, 186 | agent_request_slots_rep, current_slots_rep, turn_rep, turn_onehot_rep, kb_binary_rep, kb_count_rep]) 187 | return self.final_representation 188 | 189 | def run_policy(self, representation): 190 | """ epsilon-greedy policy """ 191 | 192 | if random.random() < self.epsilon: 193 | return random.randint(0, self.num_actions - 1) 194 | else: 195 | if self.warm_start == 1: 196 | if len(self.experience_replay_pool) > self.experience_replay_pool_size: 197 | self.warm_start = 2 198 | return self.rule_policy() 199 | else: 200 | return self.DQN_policy(representation) 201 | 202 | def rule_policy(self): 203 | """ Rule Policy """ 204 | 205 | act_slot_response = {} 206 | 207 | if self.current_slot_id < len(self.request_set): 208 | slot = self.request_set[self.current_slot_id] 209 | self.current_slot_id += 1 210 | 211 | act_slot_response = {} 212 | act_slot_response['diaact'] = "request" 213 | act_slot_response['inform_slots'] = {} 214 | act_slot_response['request_slots'] = {slot: "UNK"} 215 | elif self.phase == 0: 216 | act_slot_response = {'diaact': "inform", 'inform_slots': {'taskcomplete': "PLACEHOLDER"}, 217 | 'request_slots': {}} 218 | self.phase += 1 219 | elif self.phase == 1: 220 | act_slot_response = {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}} 221 | 222 | return self.action_index(act_slot_response) 223 | 224 | def DQN_policy(self, state_representation): 225 | """ Return action from DQN""" 226 | 227 | with torch.no_grad(): 228 | action = 
self.dqn.predict(torch.FloatTensor(state_representation)) 229 | return action 230 | 231 | def action_index(self, act_slot_response): 232 | """ Return the index of action """ 233 | 234 | for (i, action) in enumerate(self.feasible_actions): 235 | if act_slot_response == action: 236 | return i 237 | print act_slot_response 238 | raise Exception("action index not found") 239 | return None 240 | 241 | def register_experience_replay_tuple(self, s_t, a_t, reward, s_tplus1, episode_over, st_user, from_model=False): 242 | """ Register feedback from either environment or world model, to be stored as future training data """ 243 | 244 | state_t_rep = self.prepare_state_representation(s_t) 245 | action_t = self.action 246 | reward_t = reward 247 | state_tplus1_rep = self.prepare_state_representation(s_tplus1) 248 | st_user = self.prepare_state_representation(s_tplus1) 249 | training_example = (state_t_rep, action_t, reward_t, state_tplus1_rep, episode_over, st_user) 250 | 251 | if self.predict_mode == False: # Training Mode 252 | if self.warm_start == 1: 253 | self.experience_replay_pool.append(training_example) 254 | else: # Prediction Mode 255 | if not from_model: 256 | self.experience_replay_pool.append(training_example) 257 | else: 258 | self.experience_replay_pool_from_model.append(training_example) 259 | 260 | def sample_from_buffer(self, batch_size): 261 | """Sample batch size examples from experience buffer and convert it to torch readable format""" 262 | # type: (int, ) -> Transition 263 | 264 | batch = [random.choice(self.running_expereince_pool) for i in xrange(batch_size)] 265 | np_batch = [] 266 | for x in range(len(Transition._fields)): 267 | v = [] 268 | for i in xrange(batch_size): 269 | v.append(batch[i][x]) 270 | np_batch.append(np.vstack(v)) 271 | 272 | return Transition(*np_batch) 273 | 274 | def train(self, batch_size=1, num_batches=100): 275 | """ Train DQN with experience buffer that comes from both user and world model interaction.""" 276 | 277 | 
self.cur_bellman_err = 0. 278 | self.cur_bellman_err_planning = 0. 279 | self.running_expereince_pool = list(self.experience_replay_pool) + list(self.experience_replay_pool_from_model) 280 | 281 | for iter_batch in range(num_batches): 282 | for iter in range(len(self.running_expereince_pool) / (batch_size)): 283 | self.optimizer.zero_grad() 284 | batch = self.sample_from_buffer(batch_size) 285 | 286 | state_value = self.dqn(torch.FloatTensor(batch.state)).gather(1, torch.tensor(batch.action)) 287 | next_state_value, _ = self.target_dqn(torch.FloatTensor(batch.next_state)).max(1) 288 | next_state_value = next_state_value.unsqueeze(1) 289 | term = np.asarray(batch.term, dtype=np.float32) 290 | expected_value = torch.FloatTensor(batch.reward) + self.gamma * next_state_value * ( 291 | 1 - torch.FloatTensor(term)) 292 | 293 | loss = F.mse_loss(state_value, expected_value) 294 | loss.backward() 295 | self.optimizer.step() 296 | self.cur_bellman_err += loss.item() 297 | 298 | if len(self.experience_replay_pool) != 0: 299 | print ( 300 | "cur bellman err %.4f, experience replay pool %s, model replay pool %s, cur bellman err for planning %.4f" % ( 301 | float(self.cur_bellman_err) / (len(self.experience_replay_pool) / (float(batch_size))), 302 | len(self.experience_replay_pool), len(self.experience_replay_pool_from_model), 303 | self.cur_bellman_err_planning)) 304 | 305 | # def train_one_iter(self, batch_size=1, num_batches=100, planning=False): 306 | # """ Train DQN with experience replay """ 307 | # self.cur_bellman_err = 0 308 | # self.cur_bellman_err_planning = 0 309 | # running_expereince_pool = self.experience_replay_pool + self.experience_replay_pool_from_model 310 | # for iter_batch in range(num_batches): 311 | # batch = [random.choice(self.experience_replay_pool) for i in xrange(batch_size)] 312 | # np_batch = [] 313 | # for x in range(5): 314 | # v = [] 315 | # for i in xrange(len(batch)): 316 | # v.append(batch[i][x]) 317 | # np_batch.append(np.vstack(v)) 318 | # 
319 | # batch_struct = self.dqn.singleBatch(np_batch) 320 | # self.cur_bellman_err += batch_struct['cost']['total_cost'] 321 | # if planning: 322 | # plan_step = 3 323 | # for _ in xrange(plan_step): 324 | # batch_planning = [random.choice(self.experience_replay_pool) for i in 325 | # xrange(batch_size)] 326 | # np_batch_planning = [] 327 | # for x in range(5): 328 | # v = [] 329 | # for i in xrange(len(batch_planning)): 330 | # v.append(batch_planning[i][x]) 331 | # np_batch_planning.append(np.vstack(v)) 332 | # 333 | # s_tp1, r, t = self.user_planning.predict(np_batch_planning[0], np_batch_planning[1]) 334 | # s_tp1[np.where(s_tp1 >= 0.5)] = 1 335 | # s_tp1[np.where(s_tp1 <= 0.5)] = 0 336 | # 337 | # t[np.where(t >= 0.5)] = 1 338 | # 339 | # np_batch_planning[2] = r 340 | # np_batch_planning[3] = s_tp1 341 | # np_batch_planning[4] = t 342 | # 343 | # batch_struct = self.dqn.singleBatch(np_batch_planning) 344 | # self.cur_bellman_err_planning += batch_struct['cost']['total_cost'] 345 | # 346 | # if len(self.experience_replay_pool) != 0: 347 | # print ("cur bellman err %.4f, experience replay pool %s, cur bellman err for planning %.4f" % ( 348 | # float(self.cur_bellman_err) / (len(self.experience_replay_pool) / (float(batch_size))), 349 | # len(self.experience_replay_pool), self.cur_bellman_err_planning)) 350 | 351 | ################################################################################ 352 | # Debug Functions 353 | ################################################################################ 354 | def save_experience_replay_to_file(self, path): 355 | """ Save the experience replay pool to a file """ 356 | 357 | try: 358 | pickle.dump(self.experience_replay_pool, open(path, "wb")) 359 | print 'saved model in %s' % (path,) 360 | except Exception, e: 361 | print 'Error: Writing model fails: %s' % (path,) 362 | print e 363 | 364 | def load_experience_replay_from_file(self, path): 365 | """ Load the experience replay pool from a file""" 366 | 367 | 
self.experience_replay_pool = pickle.load(open(path, 'rb')) 368 | 369 | def load_trained_DQN(self, path): 370 | """ Load the trained DQN from a file """ 371 | 372 | trained_file = pickle.load(open(path, 'rb')) 373 | model = trained_file['model'] 374 | print "Trained DQN Parameters:", json.dumps(trained_file['params'], indent=2) 375 | return model 376 | 377 | def set_user_planning(self, user_planning): 378 | self.user_planning = user_planning 379 | 380 | def save(self, filename): 381 | torch.save(self.dqn.state_dict(), filename) 382 | 383 | def load(self, filename): 384 | self.dqn.load_state_dict(torch.load(filename)) 385 | 386 | def reset_dqn_target(self): 387 | self.target_dqn.load_state_dict(self.dqn.state_dict()) 388 | -------------------------------------------------------------------------------- /src/deep_dialog/checkpoints/temp_run1/agt_9_performance_records.json: -------------------------------------------------------------------------------- 1 | {"ave_turns": {"0": 32.76, "1": 42.0, "2": 12.56, "3": 4.0, "4": 5.52, "5": 14.2, "6": 21.92, "7": 16.28, "8": 22.32, "9": 10.04, "10": 5.8, "11": 6.64, "12": 6.96, "13": 6.24, "14": 9.88, "15": 8.84, "16": 9.52, "17": 9.68, "18": 11.12, "19": 13.56, "20": 30.52, "21": 24.16, "22": 17.44, "23": 36.92, "24": 18.32, "25": 15.04, "26": 28.32, "27": 20.08, "28": 25.0, "29": 32.48, "30": 34.6, "31": 39.0, "32": 34.2, "33": 28.2, "34": 30.0, "35": 39.2, "36": 39.2, "37": 39.2, "38": 32.32, "39": 31.32, "40": 30.8, "41": 34.44, "42": 34.6, "43": 40.28, "44": 36.48, "45": 34.16, "46": 38.68, "47": 38.08, "48": 33.8, "49": 36.8, "50": 36.36, "51": 38.76, "52": 36.44, "53": 37.92, "54": 37.12, "55": 38.28, "56": 40.12, "57": 42.0, "58": 40.44, "59": 40.96, "60": 39.4, "61": 37.08, "62": 37.84, "63": 37.92, "64": 39.92, "65": 37.72, "66": 40.28, "67": 37.12, "68": 39.92, "69": 38.64, "70": 39.84, "71": 40.56, "72": 39.2, "73": 36.92, "74": 37.24, "75": 41.32, "76": 37.8, "77": 38.4, "78": 38.68, "79": 36.32, "80": 37.32, 
"81": 35.84, "82": 35.24, "83": 32.64, "84": 36.96, "85": 35.8, "86": 34.76, "87": 35.68, "88": 36.8, "89": 34.76, "90": 31.6, "91": 35.92, "92": 35.76, "93": 38.12, "94": 40.96, "95": 38.76, "96": 35.96, "97": 37.24, "98": 37.32, "99": 38.88, "100": 37.32, "101": 35.24, "102": 32.88, "103": 33.08, "104": 37.68, "105": 33.12, "106": 34.64, "107": 37.68, "108": 34.4, "109": 33.84, "110": 34.36, "111": 33.36, "112": 36.68, "113": 36.56, "114": 32.24, "115": 32.76, "116": 35.04, "117": 38.0, "118": 36.4, "119": 32.64, "120": 39.88, "121": 33.2, "122": 40.48, "123": 39.0, "124": 33.4, "125": 32.72, "126": 32.32, "127": 30.88, "128": 29.0, "129": 33.36, "130": 28.92, "131": 30.44, "132": 33.16, "133": 30.44, "134": 32.76, "135": 31.36, "136": 29.36, "137": 27.04, "138": 28.72, "139": 30.12, "140": 29.96, "141": 31.72, "142": 28.8, "143": 30.84, "144": 30.6, "145": 32.48, "146": 30.28, "147": 30.2, "148": 30.12, "149": 30.92, "150": 29.88, "151": 31.96, "152": 31.72, "153": 28.28, "154": 34.16, "155": 30.0, "156": 28.08, "157": 28.96, "158": 29.72, "159": 29.4, "160": 32.8, "161": 31.68, "162": 30.08, "163": 32.52, "164": 30.08, "165": 27.56, "166": 29.68, "167": 29.92, "168": 27.8, "169": 31.92, "170": 28.32, "171": 28.16, "172": 29.44, "173": 31.32, "174": 29.48, "175": 30.16, "176": 29.48, "177": 28.16, "178": 27.04, "179": 27.32, "180": 26.6, "181": 25.6, "182": 27.48, "183": 23.44, "184": 24.68, "185": 25.44, "186": 26.64, "187": 25.2, "188": 28.64, "189": 28.64, "190": 25.6, "191": 22.04, "192": 23.84, "193": 26.52, "194": 23.24, "195": 27.48, "196": 26.88, "197": 27.64, "198": 21.88, "199": 23.44, "200": 21.88, "201": 22.48, "202": 23.92, "203": 26.44, "204": 25.92, "205": 24.32, "206": 21.0, "207": 23.24, "208": 24.08, "209": 22.68, "210": 24.48, "211": 25.16, "212": 24.08, "213": 22.16, "214": 24.56, "215": 23.64, "216": 22.84, "217": 22.76, "218": 24.44, "219": 24.04, "220": 25.92, "221": 22.96, "222": 23.28, "223": 25.72, "224": 22.88, "225": 23.16, "226": 
23.84, "227": 23.64, "228": 23.08, "229": 21.76, "230": 23.84, "231": 22.28, "232": 23.92, "233": 23.24, "234": 20.8, "235": 24.92, "236": 23.16, "237": 22.72, "238": 25.52, "239": 25.16, "240": 21.76, "241": 24.08, "242": 23.04, "243": 25.4, "244": 23.16, "245": 19.88, "246": 25.52, "247": 23.6, "248": 22.32, "249": 25.92, "250": 20.12, "251": 23.44, "252": 25.84, "253": 24.44, "254": 24.36, "255": 23.24, "256": 25.44, "257": 21.88, "258": 25.12, "259": 24.0, "260": 22.44, "261": 23.12, "262": 23.88, "263": 20.88, "264": 23.0, "265": 24.96, "266": 24.36, "267": 19.64, "268": 24.6, "269": 21.56, "270": 22.64, "271": 24.32, "272": 23.0, "273": 22.56, "274": 24.4, "275": 21.52, "276": 24.8, "277": 22.52, "278": 21.76, "279": 25.8, "280": 24.96, "281": 22.8, "282": 22.16, "283": 23.92, "284": 22.12, "285": 24.16, "286": 22.12, "287": 21.92, "288": 25.04, "289": 25.52, "290": 23.36, "291": 23.72, "292": 25.36, "293": 22.92, "294": 25.0, "295": 25.72, "296": 26.4, "297": 24.16, "298": 23.8, "299": 21.76, "300": 24.8, "301": 22.8, "302": 23.2, "303": 21.88, "304": 24.16, "305": 24.24, "306": 24.88, "307": 21.28, "308": 21.96, "309": 21.44, "310": 23.88, "311": 23.24, "312": 19.88, "313": 22.76, "314": 21.4, "315": 25.32, "316": 22.72, "317": 25.28, "318": 22.32, "319": 23.76, "320": 23.36, "321": 22.8, "322": 22.08, "323": 19.96, "324": 25.96, "325": 23.32, "326": 22.4, "327": 24.2, "328": 26.64, "329": 22.36, "330": 24.32, "331": 23.04, "332": 25.52, "333": 25.08, "334": 25.28, "335": 22.68, "336": 25.92, "337": 26.08, "338": 22.92, "339": 27.08, "340": 23.96, "341": 28.08, "342": 23.84, "343": 24.28, "344": 26.0, "345": 22.96, "346": 26.12, "347": 26.72, "348": 25.28, "349": 22.32, "350": 27.6, "351": 21.48, "352": 25.84, "353": 22.2, "354": 23.36, "355": 25.88, "356": 25.2, "357": 23.84, "358": 24.28, "359": 27.88, "360": 23.6, "361": 23.32, "362": 24.0, "363": 24.84, "364": 25.44, "365": 26.6, "366": 26.08, "367": 25.56, "368": 23.6, "369": 26.0, "370": 24.44, "371": 
23.32, "372": 24.04, "373": 22.88, "374": 23.56, "375": 24.0, "376": 22.56, "377": 21.48, "378": 25.24, "379": 26.24, "380": 25.16, "381": 23.08, "382": 26.28, "383": 22.52, "384": 21.24, "385": 21.04, "386": 22.32, "387": 22.36, "388": 23.84, "389": 24.4, "390": 23.0, "391": 23.44, "392": 23.4, "393": 24.12, "394": 25.0, "395": 22.32, "396": 20.84, "397": 22.68, "398": 24.28, "399": 20.44, "400": 26.24, "401": 23.48, "402": 26.04, "403": 24.32, "404": 25.36, "405": 24.44, "406": 22.92, "407": 24.0, "408": 23.24, "409": 25.4, "410": 26.44, "411": 23.96, "412": 25.04, "413": 25.04, "414": 26.6, "415": 24.88, "416": 24.0, "417": 23.4, "418": 25.08, "419": 24.04, "420": 22.12, "421": 25.32, "422": 25.24, "423": 21.64, "424": 23.48, "425": 22.68, "426": 23.52, "427": 24.48, "428": 20.6, "429": 21.48, "430": 25.36, "431": 24.4, "432": 23.48, "433": 23.32, "434": 23.84, "435": 21.4, "436": 24.0, "437": 21.72, "438": 22.12, "439": 23.6, "440": 22.6, "441": 22.64, "442": 20.88, "443": 23.76, "444": 20.84, "445": 22.76, "446": 22.56, "447": 28.36, "448": 25.84, "449": 24.0, "450": 26.16, "451": 22.12, "452": 25.4, "453": 25.76, "454": 21.56, "455": 23.08, "456": 23.56, "457": 24.08, "458": 23.92, "459": 22.24, "460": 24.2, "461": 21.2, "462": 23.2, "463": 21.64, "464": 21.68, "465": 23.04, "466": 22.8, "467": 21.88, "468": 22.6, "469": 21.16, "470": 22.68, "471": 21.52, "472": 22.44, "473": 19.84, "474": 22.04, "475": 23.12, "476": 21.6, "477": 23.6, "478": 23.96, "479": 22.6, "480": 23.72, "481": 22.08, "482": 21.28, "483": 19.44, "484": 23.84, "485": 22.36, "486": 22.92, "487": 18.8, "488": 21.24, "489": 23.32, "490": 23.24, "491": 24.6, "492": 21.48, "493": 22.28, "494": 23.72, "495": 22.24, "496": 23.0, "497": 24.36, "498": 26.4, "499": 23.16}, "ave_reward": {"0": -55.38, "1": -60.0, "2": -45.28, "3": -41.0, "4": -41.76, "5": -46.1, "6": -49.96, "7": -47.14, "8": -50.16, "9": -44.02, "10": -41.9, "11": -42.32, "12": -42.48, "13": -42.12, "14": -43.94, "15": -43.42, 
"16": -43.76, "17": -43.84, "18": -44.56, "19": -45.78, "20": -54.26, "21": -51.08, "22": -47.72, "23": -57.46, "24": -48.16, "25": -46.52, "26": -53.16, "27": -49.04, "28": -51.5, "29": -55.24, "30": -56.3, "31": -58.5, "32": -56.1, "33": -53.1, "34": -54.0, "35": -58.6, "36": -58.6, "37": -58.6, "38": -55.16, "39": -54.66, "40": -54.4, "41": -41.82, "42": -44.3, "43": -49.54, "44": -57.24, "45": -56.08, "46": -51.14, "47": -58.04, "48": -5.5, "49": -33.4, "50": -13.98, "51": -41.58, "52": -26.02, "53": -33.96, "54": -31.16, "55": -38.94, "56": -51.86, "57": -60.0, "58": -52.02, "59": -54.68, "60": -27.5, "61": -28.74, "62": -17.12, "63": -33.96, "64": -44.56, "65": -29.06, "66": -54.34, "67": -26.36, "68": -51.76, "69": -39.12, "70": -44.52, "71": -54.48, "72": -46.6, "73": -26.26, "74": -33.62, "75": -54.86, "76": -57.9, "77": -53.4, "78": -51.14, "79": -16.36, "80": -36.06, "81": -52.12, "82": -44.62, "83": -0.12, "84": -33.48, "85": -32.9, "86": -17.98, "87": -32.84, "88": -33.4, "89": -29.98, "90": -28.4, "91": -30.56, "92": -30.48, "93": -43.66, "94": -54.68, "95": -39.18, "96": -30.58, "97": -33.62, "98": -38.46, "99": -44.04, "100": -36.06, "101": -30.22, "102": -9.84, "103": -12.34, "104": -38.64, "105": -12.36, "106": -20.32, "107": -29.04, "108": -13.0, "109": -5.52, "110": -10.58, "111": -10.08, "112": -18.94, "113": -26.08, "114": -7.12, "115": -4.98, "116": -15.72, "117": -12.4, "118": -4.4, "119": -4.92, "120": -30.14, "121": -2.8, "122": -8.84, "123": -34.5, "124": -14.9, "125": -0.16, "126": -2.36, "127": 0.76, "128": 20.9, "129": -14.88, "130": 4.14, "131": 3.38, "132": -12.38, "133": -1.42, "134": -2.58, "135": -4.28, "136": 11.12, "137": 29.08, "138": 21.04, "139": 8.34, "140": 3.62, "141": -11.66, "142": 4.2, "143": -1.62, "144": -1.5, "145": 4.76, "146": 3.46, "147": 3.5, "148": -1.26, "149": 0.74, "150": -1.14, "151": -2.18, "152": 0.34, "153": -0.34, "154": -22.48, "155": 6.0, "156": 6.96, "157": 11.32, "158": 3.74, "159": 1.5, "160": 
-12.2, "161": -14.04, "162": 3.56, "163": 11.94, "164": 15.56, "165": 24.02, "166": 32.56, "167": 30.04, "168": 21.5, "169": 7.44, "170": 14.04, "171": 18.92, "172": 13.48, "173": -1.86, "174": 15.86, "175": 8.32, "176": 3.86, "177": 26.12, "178": 26.68, "179": 16.94, "180": 29.3, "181": 27.4, "182": 19.26, "183": 23.68, "184": 35.06, "185": 29.88, "186": 17.28, "187": 20.4, "188": 11.48, "189": 16.28, "190": 29.8, "191": 29.18, "192": 16.28, "193": 24.54, "194": 26.18, "195": 14.46, "196": 14.76, "197": 14.38, "198": 31.66, "199": 35.68, "200": 31.66, "201": 14.56, "202": 18.64, "203": 12.58, "204": 10.44, "205": 16.04, "206": 39.3, "207": 28.58, "208": 13.76, "209": 38.46, "210": 30.36, "211": 39.62, "212": 25.76, "213": 29.12, "214": 15.92, "215": 30.78, "216": 35.98, "217": 31.22, "218": 15.98, "219": 28.18, "220": 20.04, "221": 19.12, "222": 26.16, "223": 20.14, "224": 28.76, "225": 23.82, "226": 33.08, "227": 28.38, "228": 19.06, "229": 22.12, "230": 33.08, "231": 36.26, "232": 18.64, "233": 9.38, "234": 39.4, "235": 25.34, "236": 31.02, "237": 28.84, "238": 17.84, "239": 27.62, "240": 26.92, "241": 25.76, "242": 21.48, "243": 22.7, "244": 31.02, "245": 47.06, "246": 20.24, "247": 42.8, "248": 38.64, "249": 27.24, "250": 51.74, "251": 26.08, "252": 20.08, "253": 20.78, "254": 13.62, "255": 35.78, "256": 13.08, "257": 34.06, "258": 27.64, "259": 23.4, "260": 28.98, "261": 28.64, "262": 33.06, "263": 41.76, "264": 35.9, "265": 27.72, "266": 25.62, "267": 39.98, "268": 18.3, "269": 24.62, "270": 24.08, "271": 13.64, "272": 31.1, "273": 33.72, "274": 11.2, "275": 24.64, "276": 20.6, "277": 16.94, "278": 36.52, "279": 3.3, "280": 13.32, "281": 31.2, "282": 29.12, "283": 25.84, "284": 33.94, "285": 13.72, "286": 26.74, "287": 31.64, "288": 18.08, "289": 17.84, "290": 33.32, "291": 33.14, "292": 15.52, "293": 40.74, "294": 30.1, "295": 17.74, "296": 22.2, "297": 23.32, "298": 25.9, "299": 31.72, "300": 18.2, "301": 33.6, "302": 21.4, "303": 41.26, "304": 28.12, 
"305": 37.68, "306": 25.36, "307": 51.16, "308": 34.02, "309": 36.68, "310": 35.46, "311": 26.18, "312": 42.26, "313": 36.02, "314": 39.1, "315": 17.94, "316": 38.44, "317": 20.36, "318": 38.64, "319": 25.92, "320": 26.12, "321": 24.0, "322": 31.56, "323": 37.42, "324": 29.62, "325": 38.14, "326": 36.2, "327": 23.3, "328": 12.48, "329": 31.42, "330": 20.84, "331": 23.88, "332": 22.64, "333": 15.66, "334": 5.96, "335": 33.66, "336": 17.64, "337": 17.56, "338": 21.54, "339": 12.26, "340": 30.62, "341": 9.36, "342": 28.28, "343": 23.26, "344": 15.2, "345": 31.12, "346": 17.54, "347": 14.84, "348": 17.96, "349": 36.24, "350": 4.8, "351": 39.06, "352": 15.28, "353": 33.9, "354": 30.92, "355": 27.26, "356": 22.8, "357": 35.48, "358": 30.46, "359": 7.06, "360": 28.4, "361": 33.34, "362": 35.4, "363": 25.38, "364": 13.08, "365": 12.5, "366": 17.56, "367": 13.02, "368": 30.8, "369": 17.6, "370": 30.38, "371": 38.14, "372": 37.78, "373": 38.36, "374": 30.82, "375": 30.6, "376": 33.72, "377": 43.86, "378": 22.78, "379": 31.88, "380": 22.82, "381": 45.46, "382": 17.46, "383": 33.74, "384": 41.58, "385": 44.08, "386": 38.64, "387": 33.82, "388": 30.68, "389": 32.8, "390": 26.3, "391": 4.48, "392": 33.3, "393": 23.34, "394": 30.1, "395": 36.24, "396": 39.38, "397": 31.26, "398": 30.46, "399": 41.98, "400": 34.28, "401": 45.26, "402": 27.18, "403": 25.64, "404": 34.72, "405": 25.58, "406": 28.74, "407": 42.6, "408": 47.78, "409": 27.5, "410": 12.58, "411": 28.22, "412": 18.08, "413": 22.88, "414": 19.7, "415": 32.56, "416": 21.0, "417": 33.3, "418": 34.86, "419": 23.38, "420": 41.14, "421": 25.14, "422": 27.58, "423": 41.38, "424": 23.66, "425": 31.26, "426": 33.24, "427": 27.96, "428": 32.3, "429": 48.66, "430": 20.32, "431": 23.2, "432": 21.26, "433": 26.14, "434": 25.88, "435": 31.9, "436": 23.4, "437": 31.74, "438": 41.14, "439": 16.4, "440": 21.7, "441": 40.88, "442": 44.16, "443": 30.72, "444": 36.98, "445": 38.42, "446": 21.72, "447": 4.42, "448": 10.48, "449": 21.0, 
"450": 10.32, "451": 31.54, "452": 17.9, "453": 5.72, "454": 34.22, "455": 33.46, "456": 30.82, "457": 28.16, "458": 25.84, "459": 29.08, "460": 23.3, "461": 41.6, "462": 33.4, "463": 36.58, "464": 38.96, "465": 28.68, "466": 38.4, "467": 34.06, "468": 26.5, "469": 34.42, "470": 26.46, "471": 34.24, "472": 28.98, "473": 42.28, "474": 31.58, "475": 19.04, "476": 29.4, "477": 30.8, "478": 18.62, "479": 33.7, "480": 21.14, "481": 33.96, "482": 31.96, "483": 35.28, "484": 25.88, "485": 43.42, "486": 33.54, "487": 40.4, "488": 34.38, "489": 28.54, "490": 33.38, "491": 20.7, "492": 34.26, "493": 29.06, "494": 28.34, "495": 31.48, "496": 31.1, "497": 37.62, "498": 7.8, "499": 28.62}, "success_rate": {"0": 0.0, "1": 0.0, "2": 0.0, "3": 0.0, "4": 0.0, "5": 0.0, "6": 0.0, "7": 0.0, "8": 0.0, "9": 0.0, "10": 0.0, "11": 0.0, "12": 0.0, "13": 0.0, "14": 0.0, "15": 0.0, "16": 0.0, "17": 0.0, "18": 0.0, "19": 0.0, "20": 0.0, "21": 0.0, "22": 0.0, "23": 0.0, "24": 0.0, "25": 0.0, "26": 0.0, "27": 0.0, "28": 0.0, "29": 0.0, "30": 0.0, "31": 0.0, "32": 0.0, "33": 0.0, "34": 0.0, "35": 0.0, "36": 0.0, "37": 0.0, "38": 0.0, "39": 0.0, "40": 0.0, "41": 0.12, "42": 0.1, "43": 0.08, "44": 0.0, "45": 0.0, "46": 0.06, "47": 0.0, "48": 0.42, "49": 0.2, "50": 0.36, "51": 0.14, "52": 0.26, "53": 0.2, "54": 0.22, "55": 0.16, "56": 0.06, "57": 0.0, "58": 0.06, "59": 0.04, "60": 0.26, "61": 0.24, "62": 0.34, "63": 0.2, "64": 0.12, "65": 0.24, "66": 0.04, "67": 0.26, "68": 0.06, "69": 0.16, "70": 0.12, "71": 0.04, "72": 0.1, "73": 0.26, "74": 0.2, "75": 0.04, "76": 0.0, "77": 0.04, "78": 0.06, "79": 0.34, "80": 0.18, "81": 0.04, "82": 0.1, "83": 0.46, "84": 0.2, "85": 0.2, "86": 0.32, "87": 0.2, "88": 0.2, "89": 0.22, "90": 0.22, "91": 0.22, "92": 0.22, "93": 0.12, "94": 0.04, "95": 0.16, "96": 0.22, "97": 0.2, "98": 0.16, "99": 0.12, "100": 0.18, "101": 0.22, "102": 0.38, "103": 0.36, "104": 0.16, "105": 0.36, "106": 0.3, "107": 0.24, "108": 0.36, "109": 0.42, "110": 0.38, "111": 0.38, "112": 
0.32, "113": 0.26, "114": 0.4, "115": 0.42, "116": 0.34, "117": 0.38, "118": 0.44, "119": 0.42, "120": 0.24, "121": 0.44, "122": 0.42, "123": 0.2, "124": 0.34, "125": 0.46, "126": 0.44, "127": 0.46, "128": 0.62, "129": 0.34, "130": 0.48, "131": 0.48, "132": 0.36, "133": 0.44, "134": 0.44, "135": 0.42, "136": 0.54, "137": 0.68, "138": 0.62, "139": 0.52, "140": 0.48, "141": 0.36, "142": 0.48, "143": 0.44, "144": 0.44, "145": 0.5, "146": 0.48, "147": 0.48, "148": 0.44, "149": 0.46, "150": 0.44, "151": 0.44, "152": 0.46, "153": 0.44, "154": 0.28, "155": 0.5, "156": 0.5, "157": 0.54, "158": 0.48, "159": 0.46, "160": 0.36, "161": 0.34, "162": 0.48, "163": 0.56, "164": 0.58, "165": 0.64, "166": 0.72, "167": 0.7, "168": 0.62, "169": 0.52, "170": 0.56, "171": 0.6, "172": 0.56, "173": 0.44, "174": 0.58, "175": 0.52, "176": 0.48, "177": 0.66, "178": 0.66, "179": 0.58, "180": 0.68, "181": 0.66, "182": 0.6, "183": 0.62, "184": 0.72, "185": 0.68, "186": 0.58, "187": 0.6, "188": 0.54, "189": 0.58, "190": 0.68, "191": 0.66, "192": 0.56, "193": 0.64, "194": 0.64, "195": 0.56, "196": 0.56, "197": 0.56, "198": 0.68, "199": 0.72, "200": 0.68, "201": 0.54, "202": 0.58, "203": 0.54, "204": 0.52, "205": 0.56, "206": 0.74, "207": 0.66, "208": 0.54, "209": 0.74, "210": 0.68, "211": 0.76, "212": 0.64, "213": 0.66, "214": 0.56, "215": 0.68, "216": 0.72, "217": 0.68, "218": 0.56, "219": 0.66, "220": 0.6, "221": 0.58, "222": 0.64, "223": 0.6, "224": 0.66, "225": 0.62, "226": 0.7, "227": 0.66, "228": 0.58, "229": 0.6, "230": 0.7, "231": 0.72, "232": 0.58, "233": 0.5, "234": 0.74, "235": 0.64, "236": 0.68, "237": 0.66, "238": 0.58, "239": 0.66, "240": 0.64, "241": 0.64, "242": 0.6, "243": 0.62, "244": 0.68, "245": 0.8, "246": 0.6, "247": 0.78, "248": 0.74, "249": 0.66, "250": 0.84, "251": 0.64, "252": 0.6, "253": 0.6, "254": 0.54, "255": 0.72, "256": 0.54, "257": 0.7, "258": 0.66, "259": 0.62, "260": 0.66, "261": 0.66, "262": 0.7, "263": 0.76, "264": 0.72, "265": 0.66, "266": 0.64, "267": 0.74, 
"268": 0.58, "269": 0.62, "270": 0.62, "271": 0.54, "272": 0.68, "273": 0.7, "274": 0.52, "275": 0.62, "276": 0.6, "277": 0.56, "278": 0.72, "279": 0.46, "280": 0.54, "281": 0.68, "282": 0.66, "283": 0.64, "284": 0.7, "285": 0.54, "286": 0.64, "287": 0.68, "288": 0.58, "289": 0.58, "290": 0.7, "291": 0.7, "292": 0.56, "293": 0.76, "294": 0.68, "295": 0.58, "296": 0.62, "297": 0.62, "298": 0.64, "299": 0.68, "300": 0.58, "301": 0.7, "302": 0.6, "303": 0.76, "304": 0.66, "305": 0.74, "306": 0.64, "307": 0.84, "308": 0.7, "309": 0.72, "310": 0.72, "311": 0.64, "312": 0.76, "313": 0.72, "314": 0.74, "315": 0.58, "316": 0.74, "317": 0.6, "318": 0.74, "319": 0.64, "320": 0.64, "321": 0.62, "322": 0.68, "323": 0.72, "324": 0.68, "325": 0.74, "326": 0.72, "327": 0.62, "328": 0.54, "329": 0.68, "330": 0.6, "331": 0.62, "332": 0.62, "333": 0.56, "334": 0.48, "335": 0.7, "336": 0.58, "337": 0.58, "338": 0.6, "339": 0.54, "340": 0.68, "341": 0.52, "342": 0.66, "343": 0.62, "344": 0.56, "345": 0.68, "346": 0.58, "347": 0.56, "348": 0.58, "349": 0.72, "350": 0.48, "351": 0.74, "352": 0.56, "353": 0.7, "354": 0.68, "355": 0.66, "356": 0.62, "357": 0.72, "358": 0.68, "359": 0.5, "360": 0.66, "361": 0.7, "362": 0.72, "363": 0.64, "364": 0.54, "365": 0.54, "366": 0.58, "367": 0.54, "368": 0.68, "369": 0.58, "370": 0.68, "371": 0.74, "372": 0.74, "373": 0.74, "374": 0.68, "375": 0.68, "376": 0.7, "377": 0.78, "378": 0.62, "379": 0.7, "380": 0.62, "381": 0.8, "382": 0.58, "383": 0.7, "384": 0.76, "385": 0.78, "386": 0.74, "387": 0.7, "388": 0.68, "389": 0.7, "390": 0.64, "391": 0.46, "392": 0.7, "393": 0.62, "394": 0.68, "395": 0.72, "396": 0.74, "397": 0.68, "398": 0.68, "399": 0.76, "400": 0.72, "401": 0.8, "402": 0.66, "403": 0.64, "404": 0.72, "405": 0.64, "406": 0.66, "407": 0.78, "408": 0.82, "409": 0.66, "410": 0.54, "411": 0.66, "412": 0.58, "413": 0.62, "414": 0.6, "415": 0.7, "416": 0.6, "417": 0.7, "418": 0.72, "419": 0.62, "420": 0.76, "421": 0.64, "422": 0.66, "423": 
0.76, "424": 0.62, "425": 0.68, "426": 0.7, "427": 0.66, "428": 0.68, "429": 0.82, "430": 0.6, "431": 0.62, "432": 0.6, "433": 0.64, "434": 0.64, "435": 0.68, "436": 0.62, "437": 0.68, "438": 0.76, "439": 0.56, "440": 0.6, "441": 0.76, "442": 0.78, "443": 0.68, "444": 0.72, "445": 0.74, "446": 0.6, "447": 0.48, "448": 0.52, "449": 0.6, "450": 0.52, "451": 0.68, "452": 0.58, "453": 0.48, "454": 0.7, "455": 0.7, "456": 0.68, "457": 0.66, "458": 0.64, "459": 0.66, "460": 0.62, "461": 0.76, "462": 0.7, "463": 0.72, "464": 0.74, "465": 0.66, "466": 0.74, "467": 0.7, "468": 0.64, "469": 0.7, "470": 0.64, "471": 0.7, "472": 0.66, "473": 0.76, "474": 0.68, "475": 0.58, "476": 0.66, "477": 0.68, "478": 0.58, "479": 0.7, "480": 0.6, "481": 0.7, "482": 0.68, "483": 0.7, "484": 0.64, "485": 0.78, "486": 0.7, "487": 0.74, "488": 0.7, "489": 0.66, "490": 0.7, "491": 0.6, "492": 0.7, "493": 0.66, "494": 0.66, "495": 0.68, "496": 0.68, "497": 0.74, "498": 0.5, "499": 0.66}} -------------------------------------------------------------------------------- /src/deep_dialog/checkpoints/temp_run2/agt_9_performance_records.json: -------------------------------------------------------------------------------- 1 | {"ave_turns": {"0": 42.0, "1": 18.28, "2": 9.08, "3": 12.24, "4": 15.64, "5": 8.96, "6": 15.28, "7": 8.4, "8": 3.04, "9": 3.36, "10": 4.36, "11": 4.68, "12": 4.4, "13": 3.96, "14": 5.32, "15": 6.0, "16": 7.08, "17": 7.0, "18": 6.64, "19": 7.2, "20": 15.36, "21": 13.56, "22": 18.4, "23": 21.28, "24": 17.16, "25": 24.2, "26": 24.32, "27": 32.88, "28": 29.08, "29": 40.08, "30": 42.0, "31": 42.0, "32": 42.0, "33": 42.0, "34": 42.0, "35": 39.44, "36": 23.92, "37": 26.16, "38": 27.92, "39": 42.0, "40": 42.0, "41": 42.0, "42": 38.8, "43": 36.88, "44": 37.52, "45": 36.88, "46": 38.8, "47": 39.44, "48": 39.44, "49": 40.72, "50": 40.08, "51": 39.44, "52": 38.8, "53": 40.08, "54": 36.24, "55": 38.8, "56": 42.0, "57": 41.68, "58": 32.96, "59": 36.76, "60": 38.72, "61": 39.8, "62": 40.4, 
"63": 39.72, "64": 34.84, "65": 34.0, "66": 35.56, "67": 34.04, "68": 35.64, "69": 34.8, "70": 31.36, "71": 27.76, "72": 29.72, "73": 30.48, "74": 28.24, "75": 29.48, "76": 30.2, "77": 31.76, "78": 27.24, "79": 30.72, "80": 29.16, "81": 31.72, "82": 31.08, "83": 32.84, "84": 32.96, "85": 31.6, "86": 31.8, "87": 29.8, "88": 33.28, "89": 30.32, "90": 34.24, "91": 31.52, "92": 37.36, "93": 35.04, "94": 30.2, "95": 29.76, "96": 32.64, "97": 28.8, "98": 30.68, "99": 26.48, "100": 29.4, "101": 30.28, "102": 27.04, "103": 28.28, "104": 28.68, "105": 28.64, "106": 25.4, "107": 27.04, "108": 26.48, "109": 29.68, "110": 26.84, "111": 31.28, "112": 27.04, "113": 28.36, "114": 26.52, "115": 25.08, "116": 29.56, "117": 27.92, "118": 26.92, "119": 27.04, "120": 29.12, "121": 26.88, "122": 29.0, "123": 29.36, "124": 27.32, "125": 28.2, "126": 26.24, "127": 30.52, "128": 27.36, "129": 26.72, "130": 27.08, "131": 24.36, "132": 30.24, "133": 25.52, "134": 27.04, "135": 24.76, "136": 30.08, "137": 26.6, "138": 27.6, "139": 23.32, "140": 27.92, "141": 29.16, "142": 28.36, "143": 28.12, "144": 28.36, "145": 27.32, "146": 25.24, "147": 27.0, "148": 24.8, "149": 23.44, "150": 27.24, "151": 28.16, "152": 25.44, "153": 28.36, "154": 22.96, "155": 24.28, "156": 24.72, "157": 25.36, "158": 25.96, "159": 21.6, "160": 25.28, "161": 23.04, "162": 23.12, "163": 25.0, "164": 20.04, "165": 19.96, "166": 22.24, "167": 21.08, "168": 21.6, "169": 19.48, "170": 22.32, "171": 22.28, "172": 22.72, "173": 22.32, "174": 20.24, "175": 23.8, "176": 22.04, "177": 17.2, "178": 17.64, "179": 21.4, "180": 19.84, "181": 18.32, "182": 18.76, "183": 22.48, "184": 22.96, "185": 19.92, "186": 18.04, "187": 17.96, "188": 19.36, "189": 19.28, "190": 18.84, "191": 17.4, "192": 18.56, "193": 15.16, "194": 14.84, "195": 19.44, "196": 15.52, "197": 16.36, "198": 14.92, "199": 15.44, "200": 18.2, "201": 16.64, "202": 16.0, "203": 17.72, "204": 15.96, "205": 17.24, "206": 15.64, "207": 17.44, "208": 16.88, "209": 15.92, 
"210": 15.64, "211": 16.52, "212": 22.6, "213": 18.2, "214": 15.44, "215": 16.24, "216": 14.64, "217": 16.32, "218": 17.0, "219": 15.24, "220": 16.16, "221": 16.68, "222": 16.52, "223": 15.56, "224": 17.92, "225": 15.28, "226": 16.72, "227": 17.36, "228": 15.68, "229": 16.48, "230": 16.04, "231": 15.24, "232": 16.16, "233": 16.72, "234": 15.6, "235": 14.8, "236": 15.6, "237": 15.28, "238": 14.84, "239": 19.76, "240": 16.52, "241": 14.08, "242": 19.08, "243": 16.28, "244": 18.52, "245": 15.52, "246": 14.76, "247": 13.72, "248": 15.32, "249": 18.16, "250": 14.64, "251": 15.4, "252": 15.4, "253": 14.24, "254": 13.84, "255": 14.04, "256": 14.56, "257": 14.56, "258": 19.36, "259": 17.56, "260": 14.36, "261": 15.0, "262": 14.72}, "ave_reward": {"0": -60.0, "1": -48.14, "2": -43.54, "3": -45.12, "4": -46.82, "5": -43.48, "6": -46.64, "7": -43.2, "8": -40.52, "9": -40.68, "10": -41.18, "11": -41.34, "12": -41.2, "13": -40.98, "14": -41.66, "15": -42.0, "16": -42.54, "17": -42.5, "18": -42.32, "19": -42.6, "20": -46.68, "21": -45.78, "22": -48.2, "23": -49.64, "24": -47.58, "25": -51.1, "26": -51.16, "27": -55.44, "28": -53.54, "29": -59.04, "30": -60.0, "31": -60.0, "32": -60.0, "33": -60.0, "34": -60.0, "35": -58.72, "36": -50.96, "37": -52.08, "38": -52.96, "39": -60.0, "40": -60.0, "41": -60.0, "42": -58.4, "43": -57.44, "44": -57.76, "45": -57.44, "46": -58.4, "47": -58.72, "48": -58.72, "49": -59.36, "50": -59.04, "51": -58.72, "52": -58.4, "53": -59.04, "54": -57.12, "55": -58.4, "56": -60.0, "57": -59.84, "58": -19.48, "59": -40.58, "60": -53.56, "61": -51.7, "62": -47.2, "63": -49.26, "64": -10.82, "65": -3.2, "66": -13.58, "67": 1.58, "68": -18.42, "69": -18.0, "70": -1.88, "71": 9.52, "72": -3.46, "73": -6.24, "74": 18.88, "75": -12.94, "76": -1.3, "77": -9.28, "78": 24.18, "79": -3.96, "80": 13.62, "81": 17.14, "82": 17.46, "83": 4.58, "84": -5.08, "85": 2.8, "86": 5.1, "87": 22.9, "88": -7.64, "89": 8.24, "90": -3.32, "91": 7.64, "92": 7.12, "93": 5.88, "94": 
22.7, "95": 27.72, "96": -7.32, "97": 11.4, "98": 5.66, "99": 24.56, "100": 1.5, "101": 3.46, "102": 19.48, "103": 14.06, "104": 6.66, "105": 9.08, "106": 22.7, "107": 19.48, "108": 22.16, "109": 3.76, "110": 17.18, "111": -9.04, "112": 19.48, "113": 11.62, "114": 26.94, "115": 27.66, "116": 3.82, "117": 16.64, "118": 21.94, "119": 21.88, "120": 6.44, "121": 14.76, "122": 4.1, "123": 3.92, "124": 21.74, "125": 11.7, "126": 22.28, "127": -1.46, "128": 14.52, "129": 19.64, "130": 12.26, "131": 30.42, "132": -1.32, "133": 20.24, "134": 14.68, "135": 23.02, "136": -3.64, "137": 14.9, "138": 14.4, "139": 35.74, "140": 14.24, "141": 4.02, "142": 6.82, "143": 6.94, "144": 4.42, "145": 19.34, "146": 37.18, "147": 21.9, "148": 27.8, "149": 33.28, "150": 16.98, "151": 14.12, "152": 25.08, "153": 18.82, "154": 33.52, "155": 40.06, "156": 11.04, "157": 15.52, "158": 20.02, "159": 27.0, "160": 13.16, "161": 26.28, "162": 35.84, "163": 22.9, "164": 27.78, "165": 30.22, "166": 29.08, "167": 22.46, "168": 36.6, "169": 30.46, "170": 17.04, "171": 24.26, "172": 31.24, "173": 26.64, "174": 34.88, "175": 21.1, "176": 33.98, "177": 55.6, "178": 45.78, "179": 31.9, "180": 37.48, "181": 43.04, "182": 38.02, "183": 28.96, "184": 26.32, "185": 35.04, "186": 45.58, "187": 38.42, "188": 42.52, "189": 37.76, "190": 40.38, "191": 48.3, "192": 47.72, "193": 56.62, "194": 61.58, "195": 42.48, "196": 51.64, "197": 48.82, "198": 49.54, "199": 58.88, "200": 43.1, "201": 43.88, "202": 56.2, "203": 38.54, "204": 56.22, "205": 53.18, "206": 53.98, "207": 48.28, "208": 46.16, "209": 51.44, "210": 49.18, "211": 58.34, "212": 40.9, "213": 40.7, "214": 46.88, "215": 46.48, "216": 61.68, "217": 41.64, "218": 46.1, "219": 58.98, "220": 53.72, "221": 63.06, "222": 51.14, "223": 51.62, "224": 48.04, "225": 54.16, "226": 43.84, "227": 53.12, "228": 56.36, "229": 48.76, "230": 56.18, "231": 56.58, "232": 60.92, "233": 63.04, "234": 56.4, "235": 56.8, "236": 56.4, "237": 58.96, "238": 37.58, "239": 54.32, "240": 
53.54, "241": 54.76, "242": 54.66, "243": 46.46, "244": 54.94, "245": 54.04, "246": 49.62, "247": 59.74, "248": 58.94, "249": 47.92, "250": 56.88, "251": 54.1, "252": 61.3, "253": 61.88, "254": 52.48, "255": 57.18, "256": 52.12, "257": 56.92, "258": 47.32, "259": 53.02, "260": 61.82, "261": 63.9, "262": 61.64}, "success_rate": {"0": 0.0, "1": 0.0, "2": 0.0, "3": 0.0, "4": 0.0, "5": 0.0, "6": 0.0, "7": 0.0, "8": 0.0, "9": 0.0, "10": 0.0, "11": 0.0, "12": 0.0, "13": 0.0, "14": 0.0, "15": 0.0, "16": 0.0, "17": 0.0, "18": 0.0, "19": 0.0, "20": 0.0, "21": 0.0, "22": 0.0, "23": 0.0, "24": 0.0, "25": 0.0, "26": 0.0, "27": 0.0, "28": 0.0, "29": 0.0, "30": 0.0, "31": 0.0, "32": 0.0, "33": 0.0, "34": 0.0, "35": 0.0, "36": 0.0, "37": 0.0, "38": 0.0, "39": 0.0, "40": 0.0, "41": 0.0, "42": 0.0, "43": 0.0, "44": 0.0, "45": 0.0, "46": 0.0, "47": 0.0, "48": 0.0, "49": 0.0, "50": 0.0, "51": 0.0, "52": 0.0, "53": 0.0, "54": 0.0, "55": 0.0, "56": 0.0, "57": 0.0, "58": 0.3, "59": 0.14, "60": 0.04, "61": 0.06, "62": 0.1, "63": 0.08, "64": 0.38, "65": 0.44, "66": 0.36, "67": 0.48, "68": 0.32, "69": 0.32, "70": 0.44, "71": 0.52, "72": 0.42, "73": 0.4, "74": 0.6, "75": 0.34, "76": 0.44, "77": 0.38, "78": 0.64, "79": 0.42, "80": 0.56, "81": 0.6, "82": 0.6, "83": 0.5, "84": 0.42, "85": 0.48, "86": 0.5, "87": 0.64, "88": 0.4, "89": 0.52, "90": 0.44, "91": 0.52, "92": 0.54, "93": 0.52, "94": 0.64, "95": 0.68, "96": 0.4, "97": 0.54, "98": 0.5, "99": 0.64, "100": 0.46, "101": 0.48, "102": 0.6, "103": 0.56, "104": 0.5, "105": 0.52, "106": 0.62, "107": 0.6, "108": 0.62, "109": 0.48, "110": 0.58, "111": 0.38, "112": 0.6, "113": 0.54, "114": 0.66, "115": 0.66, "116": 0.48, "117": 0.58, "118": 0.62, "119": 0.62, "120": 0.5, "121": 0.56, "122": 0.48, "123": 0.48, "124": 0.62, "125": 0.54, "126": 0.62, "127": 0.44, "128": 0.56, "129": 0.6, "130": 0.54, "131": 0.68, "132": 0.44, "133": 0.6, "134": 0.56, "135": 0.62, "136": 0.42, "137": 0.56, "138": 0.56, "139": 0.72, "140": 0.56, "141": 0.48, "142": 
0.5, "143": 0.5, "144": 0.48, "145": 0.6, "146": 0.74, "147": 0.62, "148": 0.66, "149": 0.7, "150": 0.58, "151": 0.56, "152": 0.64, "153": 0.6, "154": 0.7, "155": 0.76, "156": 0.52, "157": 0.56, "158": 0.6, "159": 0.64, "160": 0.54, "161": 0.64, "162": 0.72, "163": 0.62, "164": 0.64, "165": 0.66, "166": 0.66, "167": 0.6, "168": 0.72, "169": 0.66, "170": 0.56, "171": 0.62, "172": 0.68, "173": 0.64, "174": 0.7, "175": 0.6, "176": 0.7, "177": 0.86, "178": 0.78, "179": 0.68, "180": 0.72, "181": 0.76, "182": 0.72, "183": 0.66, "184": 0.64, "185": 0.7, "186": 0.78, "187": 0.72, "188": 0.76, "189": 0.72, "190": 0.74, "191": 0.8, "192": 0.8, "193": 0.86, "194": 0.9, "195": 0.76, "196": 0.82, "197": 0.8, "198": 0.8, "199": 0.88, "200": 0.76, "201": 0.76, "202": 0.86, "203": 0.72, "204": 0.86, "205": 0.84, "206": 0.84, "207": 0.8, "208": 0.78, "209": 0.82, "210": 0.8, "211": 0.88, "212": 0.76, "213": 0.74, "214": 0.78, "215": 0.78, "216": 0.9, "217": 0.74, "218": 0.78, "219": 0.88, "220": 0.84, "221": 0.92, "222": 0.82, "223": 0.82, "224": 0.8, "225": 0.84, "226": 0.76, "227": 0.84, "228": 0.86, "229": 0.8, "230": 0.86, "231": 0.86, "232": 0.9, "233": 0.92, "234": 0.86, "235": 0.86, "236": 0.86, "237": 0.88, "238": 0.7, "239": 0.86, "240": 0.84, "241": 0.84, "242": 0.86, "243": 0.78, "244": 0.86, "245": 0.84, "246": 0.8, "247": 0.88, "248": 0.88, "249": 0.8, "250": 0.86, "251": 0.84, "252": 0.9, "253": 0.9, "254": 0.82, "255": 0.86, "256": 0.82, "257": 0.86, "258": 0.8, "259": 0.84, "260": 0.9, "261": 0.92, "262": 0.9}} -------------------------------------------------------------------------------- /src/deep_dialog/data/count_uniq_slots.py: -------------------------------------------------------------------------------- 1 | import json, cPickle 2 | goals = cPickle.load(open('user_goals_first_turn_template.part.movie.v1.p')) 3 | 4 | slots = [] 5 | for i in goals: 6 | for j in i['inform_slots'].keys(): 7 | slots.append(j) 8 | for j in i['request_slots'].keys(): 9 | 
slots.append(j) 10 | 11 | print slots -------------------------------------------------------------------------------- /src/deep_dialog/data/dia_acts.txt: -------------------------------------------------------------------------------- 1 | request 2 | inform 3 | confirm_question 4 | confirm_answer 5 | greeting 6 | closing 7 | multiple_choice 8 | thanks 9 | welcome 10 | deny 11 | not_sure -------------------------------------------------------------------------------- /src/deep_dialog/data/slot_set.txt: -------------------------------------------------------------------------------- 1 | actor 2 | actress 3 | city 4 | closing 5 | critic_rating 6 | date 7 | description 8 | distanceconstraints 9 | genre 10 | greeting 11 | implicit_value 12 | movie_series 13 | moviename 14 | mpaa_rating 15 | numberofpeople 16 | numberofkids 17 | taskcomplete 18 | other 19 | price 20 | seating 21 | starttime 22 | state 23 | theater 24 | theater_chain 25 | video_format 26 | zip 27 | result 28 | ticket 29 | mc_list -------------------------------------------------------------------------------- /src/deep_dialog/data/slot_set_small.txt: -------------------------------------------------------------------------------- 1 | city 2 | closing 3 | date 4 | distanceconstraints 5 | greeting 6 | moviename 7 | numberofpeople 8 | taskcomplete 9 | price 10 | starttime 11 | state 12 | theater 13 | theater_chain 14 | video_format 15 | zip 16 | ticket -------------------------------------------------------------------------------- /src/deep_dialog/data/user_goals_ids.json: -------------------------------------------------------------------------------- 1 | [ 2 | "6cdda7d7-5f47-4e1a-9ac7-062df28eb09e", 3 | "07e18f3d-812d-4148-acca-88b3c0c0661d", 4 | "8e94d520-5209-4e9a-8f24-44e7919b331c", 5 | "eebf8ebf-afd8-4412-96df-9617ba75a4bc", 6 | "ae80d86b-8740-4690-98ed-ceed290046e6", 7 | "ec858812-d9fa-4cbc-93a5-423e5c61b197", 8 | "23890181-ac05-4aae-a2d7-288ec767cf46", 9 | "40aa2972-0dfd-41f0-9549-2f9e920d2aee", 10 
| "20fc87d9-2dc9-4c2c-8c2e-90fe38ca9a47", 11 | "2cfa2b1f-abdf-40bd-b8c9-8e176293ad80", 12 | "b87f65c2-2c90-48a5-92a3-d6eb105a36fe", 13 | "84dc7cbe-9f30-434c-b829-3b5d94c76f64", 14 | "2309641d-29e3-469a-8fcc-e5d2b92c636f", 15 | "bec447a1-3033-4ee0-b426-18fdecd83990", 16 | "483c9908-2bd2-4944-852e-58e8cc91e0ce", 17 | "399c4977-6348-4e0f-9d1f-ac0d40965ad9", 18 | "533cd2af-69eb-48ee-aa26-2be61d269942", 19 | "76522eb0-22e3-444d-90cc-6d6802bb1b48", 20 | "c3ae808c-ee38-4527-b88c-1e2921e29a41", 21 | "59b87661-a61a-44aa-997e-3aa9fca819e2", 22 | "20a9e32d-5d03-4ff0-b01d-5e623cb172ed", 23 | "aa516e27-a65c-4fee-b84e-b3133483ffd0", 24 | "408e5aca-a86e-4c81-aded-996ea3b74b60", 25 | "37420c2f-201d-4e10-99e2-1d3807b4fed5", 26 | "a24601eb-1296-4696-a9c1-b09d671f3ff7", 27 | "286501cb-8fa8-4987-8a74-26ec40c855b4", 28 | "8611f1c0-3873-4838-9d9c-32226f45c632", 29 | "98687258-793e-4711-ab88-3b03ec57415f", 30 | "f0e4ecc2-8fce-4ed5-be0b-a9a78bb857eb", 31 | "ad8d559a-4d41-4d5c-b9f1-ff55472f0ccb", 32 | "d60cff39-e90b-4c2e-af42-3b08dc13384b", 33 | "26015d87-9870-46d7-9d9a-7897e612e5e4", 34 | "5cac12bc-413e-48dc-8704-01aea558bf9b", 35 | "b97251e8-1924-49fa-90d1-78c59cc73f67", 36 | "07739bea-80fa-45c0-a4dc-779aaeaac6cd", 37 | "ad9bd9ce-7d57-4b4c-ab0b-85d93a7e7a83", 38 | "40391646-886c-404d-9fcd-fa68d069b5d4", 39 | "16194e5d-7ef7-438d-9297-847d5d9e5f9d", 40 | "9bcb57cb-e3c9-40f8-bf5c-0cd3cd147e34", 41 | "6d3768b4-4b95-4000-864f-9fdf328d7aa4", 42 | "96fec6f3-d51e-447a-9981-d775f5a41e8e", 43 | "8189a79a-0d4e-4851-a370-0464f9e41c07", 44 | "50534277-eb2e-43d2-80e9-53f43748cfa0", 45 | "e8ab62bc-cb2d-4060-9786-dbd604ba8824", 46 | "ee7ca75a-e7c7-4ffb-bfce-423b6e755c24", 47 | "a9066c3e-0bb5-4179-90f5-5acb615326ee", 48 | "bdd28963-5807-4162-934b-1d78c947a075", 49 | "011ae85d-3f36-48ad-b550-ef31a65869db", 50 | "ae8d13cb-70b1-445f-b7d3-36dd5eeda4f0", 51 | "4472aa88-3475-4463-b36e-3efd84c1359d", 52 | "3517e133-8e3c-4ca2-a3f1-24e3357d2a18", 53 | "67cfcebf-3e8a-47e4-83e8-a8da18661475", 54 | 
"34624da6-07e7-403b-880b-28395201c494", 55 | "d8f82c80-c552-4594-a232-2d6c46ef3fb6", 56 | "6d08c795-6258-4892-968d-c41ca29cb41b", 57 | "43abc4d5-7cb3-432e-bc7e-53136b82280e", 58 | "297c09dd-572d-49d3-ba8e-07c3713d580a", 59 | "9bf116b5-3e1b-4cce-ad37-adb7238784be", 60 | "da21c48d-026d-478f-b297-73dcca348f8f", 61 | "79e71fb1-a8d6-40dd-8175-79ae4699aa44", 62 | "fca4c5f8-7980-4525-b2d0-9d6968b1ce22", 63 | "3ef9b9d1-9cc2-4fff-a3d1-13ffc8f3bcb9", 64 | "2db4ef77-e643-4d43-bb32-344830050cdb", 65 | "b6853aa2-378a-4f1e-bb6b-7b1bf842643b", 66 | "77183042-e1ad-445b-b342-8ab090af5c2e", 67 | "83adb592-fa6e-4a34-a191-b3e81cfc4572", 68 | "0c0965c4-a152-48e9-a06d-fce7ba0d2116", 69 | "2b399703-c66b-43f0-ba9a-225e1e258ee4", 70 | "e34f59d8-6be9-4e3a-b571-ffcc4f807e9f", 71 | "4fbca27c-4b8f-4b24-9414-56c1cb311322", 72 | "fcc562b3-1818-4c9c-a92e-0f2e054f5275", 73 | "7cda218d-7e53-4dd4-aa93-5062decc44c9", 74 | "2b831410-cf4a-4ecf-9f7d-6dba97cee339", 75 | "ce198848-5b9e-4389-a38d-8b02218887bd", 76 | "be39025e-dcc1-4093-a1aa-e7f242216c06", 77 | "b134e80b-7465-4c8a-835e-a70cb0cd15fb", 78 | "20e83af6-12cf-446a-bad1-9c6d57817b70", 79 | "a4ab33ee-8b94-4182-a048-2ae749af61ec", 80 | "03734d33-30ce-4e8c-917a-e9cc73f71170", 81 | "c85b92f7-aea3-4178-add0-1c1302d8b7ae", 82 | "41dc1472-5103-4b46-8fd5-a250ed092d0e", 83 | "b75f9d50-6574-4f38-85ed-310ab3d240f5", 84 | "1d8ee5b9-a286-4d05-945f-0630ec785e0a", 85 | "d816f4c4-1b3f-46ad-bd5f-406a0903b4de", 86 | "8e1f3937-52d5-43ac-b2e3-aa554922248a", 87 | "b1f2159f-086c-481d-90fe-8e90013b3812", 88 | "db091bba-63ad-4017-9f2b-8276f3dd55e2", 89 | "47a86822-2e6e-4344-82e1-1a0e2a7cee98", 90 | "483a53e1-1b3b-4cd4-b34f-4747ecac9a81", 91 | "f48ab6ff-27c8-4af1-b726-5d6eea33846c", 92 | "781979f3-aa8f-47a3-b69c-2926c2648db7", 93 | "2e73cd1c-db13-4510-8c36-a374ee9e0f8d", 94 | "2e29af64-328f-42da-ac5d-212018f7b263", 95 | "2f059cc8-ac3f-4e82-9192-0e8336bd3ba2", 96 | "c6761105-5204-4405-853b-c95cdb976786", 97 | "cab59005-1855-44d6-b74f-4b95e6abb986", 98 | 
"eecb5ad2-f622-4745-ac24-9af997b6af7a", 99 | "babfd616-09aa-4817-834e-b0a1c5487303", 100 | "167115ac-04d5-4ebe-a9c4-c95e6146e98b", 101 | "1c4f60ab-42eb-4563-8414-0be1a85e7b62", 102 | "1820c648-4ba0-489e-ad23-ee8efbbc1596", 103 | "1c0e0dc7-3ec2-46e4-b552-3d8dbdcda588", 104 | "1fc2880b-9531-4d01-ab68-62ae9dbaf1d9", 105 | "a17da9c5-9512-4fc8-94ba-8d8d968659a0", 106 | "ac60f207-9087-4b09-aaa4-0a4c203cb1f5", 107 | "31bb4b13-8467-4e56-9f07-6a2b01930996", 108 | "58162f20-6ad1-44d4-a1c3-7bcc403c90ac", 109 | "298303bd-1b2f-48ff-bce3-bebefdc3e5ff", 110 | "0df2a80f-7964-4d6f-a0d1-a33a9ba0ad84", 111 | "140cb755-c345-43a5-b21a-43c00b4a67d1", 112 | "77992ac0-9900-48c9-9544-4c22b3e368a2", 113 | "c7e0a289-4314-4d01-8e42-b9b0c56668ef", 114 | "2b99e197-070c-4b6c-8b6c-07bf079dee3e", 115 | "07a47ee8-297b-48ab-a83d-772ce51755cb", 116 | "c8c80c17-df18-4b43-a7fd-b3c5477d88d1", 117 | "f4a2dcbe-8860-45ba-93f6-229b59a091fe", 118 | "ad6af7dd-12b0-40e5-be3a-fb83d917f592", 119 | "a17a59c8-351a-4817-8374-0359163b888f", 120 | "ae8c4980-bb65-4433-8b19-21788039ca6a", 121 | "0f29c12a-140f-430d-8776-024e7f6cb9cc", 122 | "fa6cfa89-304f-48c2-b559-a319eec7b0dc", 123 | "72e7d2fc-3107-429f-9ddc-0ca54f5f8d3b", 124 | "eb4bcabb-58ed-42ec-a44c-76d990d6c494", 125 | "8cc30097-7254-4dca-b20d-43a0cf63d11f", 126 | "0b070308-6444-40b1-acf9-896115d1f5ca", 127 | "bd652411-9467-4c1c-8408-931af12211dd", 128 | "80c277f0-fc73-48ff-be09-da04c41eabb3", 129 | "b7ae313f-4dea-4982-9b2b-85f37db53654", 130 | "e62d0172-0880-4235-a6d7-bc4957a34af8", 131 | "89b56763-cdf7-4c90-b6d7-875a87b909ba", 132 | "3c9085d6-3595-4569-a53a-713ab9f2c333", 133 | "53e4b4b4-0b55-46f2-bcb5-2000eba844e3", 134 | "0bd2e714-0579-4bf2-968e-cd5106a1f506", 135 | "58b8fef4-f43d-4a98-ad85-8d747fd881d3", 136 | "96ab6333-fc2c-4839-a84f-2fc45ac6488a", 137 | "f380ab93-0877-4495-9e19-85656e1c7977", 138 | "48249f21-205c-4ebf-8849-3ed9fdc5eaed", 139 | "637b5848-1821-410e-a0da-0e9244937c42", 140 | "5772ce45-4a7f-4a22-93b2-1acbf928534e", 141 | 
"0f8f0572-bb9d-417d-b125-5c0b48c0b5c4", 142 | "442a1b1a-fc6f-43f1-acf4-cfadb3de3207", 143 | "be64db44-edeb-4e54-9670-320abb7ccd3b", 144 | "eb5e4094-0110-4672-bc7c-4c8c05c12bd5", 145 | "bde2a4ed-00a5-45db-8fd5-dd78548944a2", 146 | "fe80cd5d-0211-4506-a286-b2513709d8a2", 147 | "c41c84fd-1b55-4c17-bbe5-21a16c662a46", 148 | "a4d75e93-0a19-4551-b839-acd1ee88e69b", 149 | "0a272523-26e1-45ab-97b4-a1784114f76c", 150 | "cd5cf4c9-e9bd-4a4f-a275-eff6bd2bc526", 151 | "0427a839-fbff-4223-a8cd-476b32895384", 152 | "f81a696c-59ba-4ba6-8676-744b43e177aa", 153 | "773baa08-b203-4233-8b5f-3a2dd3b87c5b", 154 | "9661972d-99ce-43d1-a8eb-576c6ed816c7", 155 | "078842e2-66f7-44c6-add8-05d558677de7", 156 | "30316791-5f5e-4dd6-8f2a-f06b54b6e0d4", 157 | "154893bf-e05a-481f-8c7a-fccfe4d0db70", 158 | "4f514663-1a49-43de-9f07-f572a8420a37", 159 | "b93933b4-90ba-4750-b17f-270e6af7d273", 160 | "b055bec7-c6f9-4ac5-b10d-6025a53b1671", 161 | "856c5979-404f-4bc3-9b1c-92c49131022e", 162 | "043564db-b296-4fd9-b6ca-f201e0a31564", 163 | "192e8c18-b37d-4073-ad15-eaefb8c88116", 164 | "1a225f08-fefe-412a-8d3d-823815a3456d", 165 | "54800e92-65d1-49f4-992c-ca30345a397b", 166 | "5f099cec-56af-42bc-87a0-0ec1a3cfda29", 167 | "40761539-bda4-4377-8925-ad3a4a06b511", 168 | "78762289-7081-4f33-bb76-1342c11547ea", 169 | "7751222e-1c46-4d1e-840d-9a5fee9f2d0a", 170 | "99302780-8073-4203-a924-767e4c5ccdc7", 171 | "719c13c8-d1bc-4c3f-8ff9-171159818a16", 172 | "9b206223-b450-4f3b-9189-ce9dc851691e", 173 | "829f7f20-639f-407f-a7bb-6d8a232eeecd", 174 | "d48a73c9-4902-4c7c-9c2b-e6bfeea8bcfa", 175 | "56e482bb-787f-4161-8daf-8e1533146411", 176 | "57813be1-f901-4a6c-9ed9-45398eaa0200", 177 | "751c4265-569c-407f-a5ea-fa6b24186d57", 178 | "3b343e7b-ccd5-48bb-9376-facf12a5b51b", 179 | "1c9f4917-ebf3-40fd-a5bc-0bbf4bbed528", 180 | "ae0adb98-c55e-4c71-89b2-dd67bc7c1a6c", 181 | "ac277815-7755-4ca4-8574-a9d2b8b576ad", 182 | "762b8509-76c1-4eea-837e-31fe710e47cf", 183 | "e039dadc-92e0-4a01-b45e-fb1dd240ae18", 184 | 
"5fdcf89e-0ebf-455a-a221-0ff93ac0a900", 185 | "fb6ce50c-fa7f-4da6-8740-cae20bd9cd5f", 186 | "03602ac4-60a2-49b9-bb25-1625821eb41e", 187 | "a656ec76-8c45-4f6e-9472-7ec149cb7a82", 188 | "8be66033-fa24-4713-8bb9-f2cabd799f8e", 189 | "0e8d9dd7-95d2-499f-bcde-3b32e014edf3", 190 | "0d7d7ea9-951b-468b-893c-57dc5f242738", 191 | "c8184e14-ff74-4ed8-ab4e-4ec9af1ad13a", 192 | "4fa0d136-507f-40f9-94e4-cac137c84980", 193 | "a6e6cf16-be15-45bf-be2f-3d0e2c37de27", 194 | "130a606c-f3db-4e06-86e4-ef9122b21289", 195 | "26dbb561-cbea-41ee-baa6-b87bf7d2f39b", 196 | "e74bff07-e65a-4a54-a6e3-59d312c75486", 197 | "896f3c02-5f81-4f18-a011-6a0d99564716", 198 | "0951b157-b455-44b3-9d83-b09367d0c88c", 199 | "1e0c7d21-587d-46e4-8052-6b4e20eeaf3a", 200 | "2af0f406-ab97-4d7a-8255-b899377abf71", 201 | "379229ca-bb32-445b-b7f2-acf277dda052" 202 | ] -------------------------------------------------------------------------------- /src/deep_dialog/dialog_config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on May 17, 2016 3 | 4 | @author: xiul, t-zalipt 5 | ''' 6 | 7 | sys_inform_slots_for_user = ['city', 'closing', 'date', 'distanceconstraints', 'greeting', 'moviename', 8 | 'numberofpeople', 'taskcomplete', 'price', 'starttime', 'state', 'theater', 9 | 'theater_chain', 'video_format', 'zip'] 10 | 11 | sys_request_slots = ['moviename', 'theater', 'starttime', 'date', 'numberofpeople', 'state', 'city', 'zip', 12 | 'distanceconstraints', 'video_format', 'theater_chain', 'price'] 13 | sys_inform_slots = ['moviename', 'theater', 'starttime', 'date', 'state', 'city', 'zip', 14 | 'distanceconstraints', 'video_format', 'theater_chain', 'price', 'taskcomplete', 'ticket'] 15 | # 16 | # sys_request_slots = ['moviename', 'theater', 'starttime', 'date', 'numberofpeople', 'genre', 'state', 'city', 'zip', 'critic_rating', 'mpaa_rating', 'distanceconstraints', 'video_format', 'theater_chain', 'price', 'actor', 'description', 'numberofkids'] 17 | # 
sys_inform_slots = ['moviename', 'theater', 'starttime', 'date', 'genre', 'state', 'city', 'zip', 'critic_rating', 'mpaa_rating', 'distanceconstraints', 'video_format', 'theater_chain', 'price', 'actor', 'description', 'numberofkids', 'taskcomplete', 'ticket'] 18 | # 19 | start_dia_acts = { 20 | # 'greeting':[], 21 | 'request': ['moviename', 'starttime', 'theater', 'city', 'state', 'date', 'ticket', 'numberofpeople'] 22 | } 23 | 24 | # sys_request_slots = ['moviename', 'theater', 'starttime', 'date', 'numberofpeople', 'genre', 'state', 'city', 'zip', 25 | # 'critic_rating', 'mpaa_rating', 'distanceconstraints', 'video_format', 'theater_chain', 'price', 26 | # 'actor', 'description', 'other', 'numberofkids'] 27 | # sys_inform_slots = ['moviename', 'theater', 'starttime', 'date', 'genre', 'state', 'city', 'zip', 'critic_rating', 28 | # 'mpaa_rating', 'distanceconstraints', 'video_format', 'theater_chain', 'price', 'actor', 29 | # 'description', 'other', 'numberofkids', 'taskcomplete', 'ticket'] 30 | # 31 | # start_dia_acts = { 32 | # # 'greeting':[], 33 | # 'request': ['moviename', 'starttime', 'theater', 'city', 'state', 'date', 'genre', 'ticket', 'numberofpeople'] 34 | # } 35 | 36 | ################################################################################ 37 | # Dialog status 38 | ################################################################################ 39 | FAILED_DIALOG = -1 40 | SUCCESS_DIALOG = 1 41 | NO_OUTCOME_YET = 0 42 | 43 | # Rewards 44 | SUCCESS_REWARD = 50 45 | FAILURE_REWARD = 0 46 | PER_TURN_REWARD = 0 47 | 48 | ################################################################################ 49 | # Special Slot Values 50 | ################################################################################ 51 | I_DO_NOT_CARE = "I do not care" 52 | NO_VALUE_MATCH = "NO VALUE MATCHES!!!" 
53 | TICKET_AVAILABLE = 'Ticket Available' 54 | 55 | ################################################################################ 56 | # Constraint Check 57 | ################################################################################ 58 | CONSTRAINT_CHECK_FAILURE = 0 59 | CONSTRAINT_CHECK_SUCCESS = 1 60 | 61 | ################################################################################ 62 | # NLG Beam Search 63 | ################################################################################ 64 | nlg_beam_size = 10 65 | 66 | ################################################################################ 67 | # run_mode: 0 for dia-act; 1 for NL; 2 for no output 68 | ################################################################################ 69 | run_mode = 0 70 | auto_suggest = 0 71 | 72 | ################################################################################ 73 | # A Basic Set of Feasible actions to be Consdered By an RL agent 74 | ################################################################################ 75 | feasible_actions = [ 76 | ############################################################################ 77 | # greeting actions 78 | ############################################################################ 79 | # {'diaact':"greeting", 'inform_slots':{}, 'request_slots':{}}, 80 | ############################################################################ 81 | # confirm_question actions 82 | ############################################################################ 83 | {'diaact': "confirm_question", 'inform_slots': {}, 'request_slots': {}}, 84 | ############################################################################ 85 | # confirm_answer actions 86 | ############################################################################ 87 | {'diaact': "confirm_answer", 'inform_slots': {}, 'request_slots': {}}, 88 | ############################################################################ 89 | # thanks actions 90 | 
############################################################################ 91 | {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}}, 92 | ############################################################################ 93 | # deny actions 94 | ############################################################################ 95 | {'diaact': "deny", 'inform_slots': {}, 'request_slots': {}}, 96 | ] 97 | 98 | ############################################################################ 99 | # Adding the inform actions 100 | ############################################################################ 101 | 102 | 103 | sys_inform_slots_for_user = ['city', 'closing', 'date', 'distanceconstraints', 'greeting', 'moviename', 104 | 'numberofpeople', 'taskcomplete', 'price', 'starttime', 'state', 'theater', 105 | 'theater_chain', 'video_format', 'zip', 'description','numberofkids','genre'] 106 | 107 | sys_request_slots_for_user = ['city', 'date', 'moviename', 'numberofpeople', 'starttime', 'state', 'theater', 108 | 'theater_chain', 'video_format', 'zip', 'ticket'] 109 | 110 | for slot in sys_inform_slots: 111 | feasible_actions.append({'diaact': 'inform', 'inform_slots': {slot: "PLACEHOLDER"}, 'request_slots': {}}) 112 | 113 | ############################################################################ 114 | # Adding the request actions 115 | ############################################################################ 116 | for slot in sys_request_slots: 117 | feasible_actions.append({'diaact': 'request', 'inform_slots': {}, 'request_slots': {slot: "UNK"}}) 118 | 119 | feasible_actions_users = [ 120 | {'diaact': "thanks", 'inform_slots': {}, 'request_slots': {}}, 121 | {'diaact': "deny", 'inform_slots': {}, 'request_slots': {}}, 122 | {'diaact': "closing", 'inform_slots': {}, 'request_slots': {}}, 123 | {'diaact': "confirm_answer", 'inform_slots': {}, 'request_slots': {}} 124 | ] 125 | 126 | # for slot in sys_inform_slots_for_user: 127 | for slot in 
sys_inform_slots_for_user: 128 | feasible_actions_users.append({'diaact': 'inform', 'inform_slots': {slot: "PLACEHOLDER"}, 'request_slots': {}}) 129 | 130 | feasible_actions_users.append( 131 | {'diaact': 'inform', 'inform_slots': {'numberofpeople': "PLACEHOLDER"}, 'request_slots': {}}) 132 | 133 | ############################################################################ 134 | # Adding the request actions 135 | ############################################################################ 136 | for slot in sys_request_slots_for_user: 137 | feasible_actions_users.append({'diaact': 'request', 'inform_slots': {}, 'request_slots': {slot: "UNK"}}) 138 | 139 | feasible_actions_users.append({'diaact': 'inform', 'inform_slots': {}, 'request_slots': {}}) 140 | -------------------------------------------------------------------------------- /src/deep_dialog/dialog_system/__init__.py: -------------------------------------------------------------------------------- 1 | from .kb_helper import * 2 | from .state_tracker import * 3 | from .dialog_manager import * 4 | from .dict_reader import * 5 | from .utils import * -------------------------------------------------------------------------------- /src/deep_dialog/dialog_system/dialog_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 17, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | import json 8 | from . 
import StateTracker 9 | from deep_dialog import dialog_config 10 | import copy 11 | 12 | 13 | class DialogManager: 14 | """ A dialog manager to mediate the interaction between an agent and a customer """ 15 | 16 | def __init__(self, agent, user, world_model, act_set, slot_set, movie_dictionary): 17 | self.agent = agent 18 | self.user = user 19 | self.world_model = world_model 20 | self.act_set = act_set 21 | self.slot_set = slot_set 22 | self.state_tracker = StateTracker(act_set, slot_set, movie_dictionary) 23 | self.user_action = None 24 | self.reward = 0 25 | self.episode_over = False 26 | 27 | 28 | self.use_world_model = False 29 | self.running_user = self.user 30 | 31 | def initialize_episode(self, use_environment=False): 32 | """ Refresh state for new dialog """ 33 | 34 | self.reward = 0 35 | self.episode_over = False 36 | 37 | self.state_tracker.initialize_episode() 38 | self.running_user = self.user 39 | self.use_world_model = False 40 | 41 | if not use_environment: 42 | self.running_user = self.world_model 43 | self.use_world_model = True 44 | else: 45 | self.running_user = self.user 46 | self.use_world_model = False 47 | 48 | self.user_action = self.running_user.initialize_episode() 49 | 50 | if use_environment: 51 | self.world_model.sample_goal = self.user.sample_goal 52 | 53 | self.state_tracker.update(user_action=self.user_action) 54 | 55 | if dialog_config.run_mode < 3: 56 | print ("New episode, user goal:") 57 | print json.dumps(self.user.goal, indent=2) 58 | self.print_function(user_action=self.user_action) 59 | 60 | self.agent.initialize_episode() 61 | 62 | def next_turn(self, record_training_data=True, record_training_data_for_user=True): 63 | """ This function initiates each subsequent exchange between agent and user (agent first) """ 64 | 65 | ######################################################################## 66 | # CALL AGENT TO TAKE HER TURN 67 | ######################################################################## 68 | self.state = 
self.state_tracker.get_state_for_agent() 69 | self.agent_action = self.agent.state_to_action(self.state) 70 | 71 | ######################################################################## 72 | # Register AGENT action with the state_tracker 73 | ######################################################################## 74 | self.state_tracker.update(agent_action=self.agent_action) 75 | 76 | self.state_user = self.state_tracker.get_state_for_user() 77 | 78 | self.agent.add_nl_to_action(self.agent_action) # add NL to Agent Dia_Act 79 | self.print_function(agent_action=self.agent_action['act_slot_response']) 80 | 81 | ######################################################################## 82 | # CALL USER TO TAKE HER TURN 83 | ######################################################################## 84 | self.sys_action = self.state_tracker.dialog_history_dictionaries()[-1] 85 | if self.use_world_model: 86 | self.user_action, self.episode_over, self.reward = self.running_user.next(self.state_user, 87 | self.agent.action) 88 | else: 89 | self.user_action, self.episode_over, dialog_status = self.running_user.next(self.sys_action) 90 | self.reward = self.reward_function(dialog_status) 91 | 92 | ######################################################################## 93 | # Update state tracker with latest user action 94 | ######################################################################## 95 | if self.episode_over != True: 96 | self.state_tracker.update(user_action=self.user_action) 97 | self.print_function(user_action=self.user_action) 98 | 99 | self.state_user_next = self.state_tracker.get_state_for_agent() 100 | 101 | ######################################################################## 102 | # Inform agent of the outcome for this timestep (s_t, a_t, r, s_{t+1}, episode_over, s_t_u, user_world_model) 103 | ######################################################################## 104 | if record_training_data: 105 | 
self.agent.register_experience_replay_tuple(self.state, self.agent_action, self.reward, 106 | self.state_tracker.get_state_for_agent(), self.episode_over, 107 | self.state_user, self.use_world_model) 108 | 109 | ######################################################################## 110 | # Inform world model of the outcome for this timestep 111 | # (s_t, a_t, s_{t+1}, r, t, ua_t) 112 | ######################################################################## 113 | 114 | if record_training_data_for_user and not self.use_world_model: 115 | self.world_model.register_experience_replay_tuple(self.state_user, self.agent.action, 116 | self.state_user_next, self.reward, self.episode_over, 117 | self.user_action) 118 | 119 | return (self.episode_over, self.reward) 120 | 121 | def reward_function(self, dialog_status): 122 | """ Reward Function 1: a reward function based on the dialog_status """ 123 | if dialog_status == dialog_config.FAILED_DIALOG: 124 | reward = -self.user.max_turn # 10 125 | elif dialog_status == dialog_config.SUCCESS_DIALOG: 126 | reward = 2 * self.user.max_turn # 20 127 | else: 128 | reward = -1 129 | return reward 130 | 131 | def reward_function_without_penalty(self, dialog_status): 132 | """ Reward Function 2: a reward function without penalty on per turn and failure dialog """ 133 | if dialog_status == dialog_config.FAILED_DIALOG: 134 | reward = 0 135 | elif dialog_status == dialog_config.SUCCESS_DIALOG: 136 | reward = 2 * self.user.max_turn 137 | else: 138 | reward = 0 139 | return reward 140 | 141 | def print_function(self, agent_action=None, user_action=None): 142 | """ Print Function """ 143 | 144 | if agent_action: 145 | if dialog_config.run_mode == 0: 146 | if self.agent.__class__.__name__ != 'AgentCmd': 147 | print ("Turn %d sys: %s" % (agent_action['turn'], agent_action['nl'])) 148 | elif dialog_config.run_mode == 1: 149 | if self.agent.__class__.__name__ != 'AgentCmd': 150 | print("Turn %d sys: %s, inform_slots: %s, request slots: %s" % ( 
151 | agent_action['turn'], agent_action['diaact'], agent_action['inform_slots'], 152 | agent_action['request_slots'])) 153 | elif dialog_config.run_mode == 2: # debug mode 154 | print("Turn %d sys: %s, inform_slots: %s, request slots: %s" % ( 155 | agent_action['turn'], agent_action['diaact'], agent_action['inform_slots'], 156 | agent_action['request_slots'])) 157 | print ("Turn %d sys: %s" % (agent_action['turn'], agent_action['nl'])) 158 | 159 | if dialog_config.auto_suggest == 1: 160 | print( 161 | '(Suggested Values: %s)' % ( 162 | self.state_tracker.get_suggest_slots_values(agent_action['request_slots']))) 163 | elif user_action: 164 | if dialog_config.run_mode == 0: 165 | print ("Turn %d usr: %s" % (user_action['turn'], user_action['nl'])) 166 | elif dialog_config.run_mode == 1: 167 | print ("Turn %s usr: %s, inform_slots: %s, request_slots: %s" % ( 168 | user_action['turn'], user_action['diaact'], user_action['inform_slots'], 169 | user_action['request_slots'])) 170 | elif dialog_config.run_mode == 2: # debug mode, show both 171 | print ("Turn %d usr: %s, inform_slots: %s, request_slots: %s" % ( 172 | user_action['turn'], user_action['diaact'], user_action['inform_slots'], 173 | user_action['request_slots'])) 174 | print ("Turn %d usr: %s" % (user_action['turn'], user_action['nl'])) 175 | 176 | if self.agent.__class__.__name__ == 'AgentCmd': # command line agent 177 | user_request_slots = user_action['request_slots'] 178 | if 'ticket' in user_request_slots.keys(): del user_request_slots['ticket'] 179 | if len(user_request_slots) > 0: 180 | possible_values = self.state_tracker.get_suggest_slots_values(user_action['request_slots']) 181 | for slot in possible_values.keys(): 182 | if len(possible_values[slot]) > 0: 183 | print('(Suggested Values: %s: %s)' % (slot, possible_values[slot])) 184 | elif len(possible_values[slot]) == 0: 185 | print('(Suggested Values: there is no available %s)' % (slot)) 186 | else: 187 | kb_results = 
self.state_tracker.get_current_kb_results() 188 | print ('(Number of movies in KB satisfying current constraints: %s)' % len(kb_results)) 189 | -------------------------------------------------------------------------------- /src/deep_dialog/dialog_system/dict_reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 18, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | 8 | def text_to_dict(path): 9 | """ Read in a text file as a dictionary where keys are text and values are indices (line numbers) """ 10 | 11 | slot_set = {} 12 | with open(path, 'r') as f: 13 | index = 0 14 | for line in f.readlines(): 15 | slot_set[line.strip('\n').strip('\r')] = index 16 | index += 1 17 | return slot_set -------------------------------------------------------------------------------- /src/deep_dialog/dialog_system/kb_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 18, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | import copy 8 | from collections import defaultdict 9 | from deep_dialog import dialog_config 10 | 11 | class KBHelper: 12 | """ An assistant to fill in values for the agent (which knows about slots of values) """ 13 | 14 | def __init__(self, movie_dictionary): 15 | """ Constructor for a KBHelper """ 16 | 17 | self.movie_dictionary = movie_dictionary 18 | self.cached_kb = defaultdict(list) 19 | self.cached_kb_slot = defaultdict(list) 20 | 21 | 22 | def fill_inform_slots(self, inform_slots_to_be_filled, current_slots): 23 | """ Takes unfilled inform slots and current_slots, returns dictionary of filled informed slots (with values) 24 | 25 | Arguments: 26 | inform_slots_to_be_filled -- Something that looks like {starttime:None, theater:None} where starttime and theater are slots that the agent needs filled 27 | current_slots -- Contains a record of all filled slots in the conversation so far - for now, just use 
current_slots['inform_slots'] which is a dictionary of the already filled-in slots 28 | 29 | Returns: 30 | filled_in_slots -- A dictionary of form {slot1:value1, slot2:value2} for each sloti in inform_slots_to_be_filled 31 | """ 32 | 33 | kb_results = self.available_results_from_kb(current_slots) 34 | if dialog_config.auto_suggest == 1: 35 | print 'Number of movies in KB satisfying current constraints: ', len(kb_results) 36 | 37 | filled_in_slots = {} 38 | if 'taskcomplete' in inform_slots_to_be_filled.keys(): 39 | filled_in_slots.update(current_slots['inform_slots']) 40 | 41 | for slot in inform_slots_to_be_filled.keys(): 42 | if slot == 'numberofpeople': 43 | if slot in current_slots['inform_slots'].keys(): 44 | filled_in_slots[slot] = current_slots['inform_slots'][slot] 45 | elif slot in inform_slots_to_be_filled.keys(): 46 | filled_in_slots[slot] = inform_slots_to_be_filled[slot] 47 | continue 48 | 49 | if slot == 'ticket' or slot == 'taskcomplete': 50 | filled_in_slots[slot] = dialog_config.TICKET_AVAILABLE if len(kb_results)>0 else dialog_config.NO_VALUE_MATCH 51 | continue 52 | 53 | if slot == 'closing': continue 54 | 55 | #################################################################### 56 | # Grab the value for the slot with the highest count and fill it 57 | #################################################################### 58 | values_dict = self.available_slot_values(slot, kb_results) 59 | 60 | values_counts = [(v, values_dict[v]) for v in values_dict.keys()] 61 | if len(values_counts) > 0: 62 | filled_in_slots[slot] = sorted(values_counts, key = lambda x: -x[1])[0][0] 63 | else: 64 | filled_in_slots[slot] = dialog_config.NO_VALUE_MATCH #"NO VALUE MATCHES SNAFU!!!" 
65 | 66 | return filled_in_slots 67 | 68 | 69 | def available_slot_values(self, slot, kb_results): 70 | """ Return the set of values available for the slot based on the current constraints """ 71 | 72 | slot_values = {} 73 | for movie_id in kb_results.keys(): 74 | if slot in kb_results[movie_id].keys(): 75 | slot_val = kb_results[movie_id][slot] 76 | if slot_val in slot_values.keys(): 77 | slot_values[slot_val] += 1 78 | else: slot_values[slot_val] = 1 79 | return slot_values 80 | 81 | def available_results_from_kb(self, current_slots): 82 | """ Return the available movies in the movie_kb based on the current constraints """ 83 | 84 | ret_result = [] 85 | current_slots = current_slots['inform_slots'] 86 | constrain_keys = current_slots.keys() 87 | 88 | constrain_keys = filter(lambda k : k != 'ticket' and \ 89 | k != 'numberofpeople' and \ 90 | k!= 'taskcomplete' and \ 91 | k != 'closing' , constrain_keys) 92 | constrain_keys = [k for k in constrain_keys if current_slots[k] != dialog_config.I_DO_NOT_CARE] 93 | 94 | query_idx_keys = frozenset(current_slots.items()) 95 | cached_kb_ret = self.cached_kb[query_idx_keys] 96 | 97 | cached_kb_length = len(cached_kb_ret) if cached_kb_ret != None else -1 98 | if cached_kb_length > 0: 99 | return dict(cached_kb_ret) 100 | elif cached_kb_length == -1: 101 | return dict([]) 102 | 103 | # kb_results = copy.deepcopy(self.movie_dictionary) 104 | for id in self.movie_dictionary.keys(): 105 | kb_keys = self.movie_dictionary[id].keys() 106 | if len(set(constrain_keys).union(set(kb_keys)) ^ (set(constrain_keys) ^ set(kb_keys))) == len( 107 | constrain_keys): 108 | match = True 109 | for idx, k in enumerate(constrain_keys): 110 | if str(current_slots[k]).lower() == str(self.movie_dictionary[id][k]).lower(): 111 | continue 112 | else: 113 | match = False 114 | if match: 115 | self.cached_kb[query_idx_keys].append((id, self.movie_dictionary[id])) 116 | ret_result.append((id, self.movie_dictionary[id])) 117 | 118 | # for slot in 
current_slots['inform_slots'].keys(): 119 | # if slot == 'ticket' or slot == 'numberofpeople' or slot == 'taskcomplete' or slot == 'closing': continue 120 | # if current_slots['inform_slots'][slot] == dialog_config.I_DO_NOT_CARE: continue 121 | # 122 | # if slot not in self.movie_dictionary[movie_id].keys(): 123 | # if movie_id in kb_results.keys(): 124 | # del kb_results[movie_id] 125 | # else: 126 | # if current_slots['inform_slots'][slot].lower() != self.movie_dictionary[movie_id][slot].lower(): 127 | # if movie_id in kb_results.keys(): 128 | # del kb_results[movie_id] 129 | 130 | if len(ret_result) == 0: 131 | self.cached_kb[query_idx_keys] = None 132 | 133 | ret_result = dict(ret_result) 134 | return ret_result 135 | 136 | def available_results_from_kb_for_slots(self, inform_slots): 137 | """ Return the count statistics for each constraint in inform_slots """ 138 | 139 | kb_results = {key:0 for key in inform_slots.keys()} 140 | kb_results['matching_all_constraints'] = 0 141 | return kb_results 142 | 143 | query_idx_keys = frozenset(inform_slots.items()) 144 | cached_kb_slot_ret = self.cached_kb_slot[query_idx_keys] 145 | 146 | if len(cached_kb_slot_ret) > 0: 147 | return cached_kb_slot_ret[0] 148 | 149 | for movie_id in self.movie_dictionary.keys(): 150 | all_slots_match = 1 151 | for slot in inform_slots.keys(): 152 | if slot == 'ticket' or inform_slots[slot] == dialog_config.I_DO_NOT_CARE: 153 | continue 154 | 155 | if slot in self.movie_dictionary[movie_id].keys(): 156 | if inform_slots[slot].lower() == self.movie_dictionary[movie_id][slot].lower(): 157 | kb_results[slot] += 1 158 | else: 159 | all_slots_match = 0 160 | else: 161 | all_slots_match = 0 162 | kb_results['matching_all_constraints'] += all_slots_match 163 | 164 | self.cached_kb_slot[query_idx_keys].append(kb_results) 165 | return kb_results 166 | 167 | 168 | def database_results_for_agent(self, current_slots): 169 | """ A dictionary of the number of results matching each current constraint. 
The agent needs this to decide what to do next. """ 170 | 171 | database_results ={} # { date:100, distanceconstraints:60, theater:30, matching_all_constraints: 5} 172 | database_results = self.available_results_from_kb_for_slots(current_slots['inform_slots']) 173 | return database_results 174 | 175 | def suggest_slot_values(self, request_slots, current_slots): 176 | """ Return the suggest slot values """ 177 | 178 | avail_kb_results = self.available_results_from_kb(current_slots) 179 | return_suggest_slot_vals = {} 180 | for slot in request_slots.keys(): 181 | avail_values_dict = self.available_slot_values(slot, avail_kb_results) 182 | values_counts = [(v, avail_values_dict[v]) for v in avail_values_dict.keys()] 183 | 184 | if len(values_counts) > 0: 185 | return_suggest_slot_vals[slot] = [] 186 | sorted_dict = sorted(values_counts, key = lambda x: -x[1]) 187 | for k in sorted_dict: return_suggest_slot_vals[slot].append(k[0]) 188 | else: 189 | return_suggest_slot_vals[slot] = [] 190 | 191 | return return_suggest_slot_vals -------------------------------------------------------------------------------- /src/deep_dialog/dialog_system/state_tracker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 20, 2016 3 | 4 | state tracker 5 | 6 | @author: xiul, t-zalipt 7 | """ 8 | 9 | from . import KBHelper 10 | import numpy as np 11 | import copy 12 | 13 | 14 | class StateTracker: 15 | """ The state tracker maintains a record of which request slots are filled and which inform slots are filled """ 16 | 17 | def __init__(self, act_set, slot_set, movie_dictionary): 18 | """ constructor for statetracker takes movie knowledge base and initializes a new episode 19 | 20 | Arguments: 21 | act_set -- The set of all acts availavle 22 | slot_set -- The total set of available slots 23 | movie_dictionary -- A representation of all the available movies. 
Generally this object is accessed via the KBHelper class 24 | 25 | Class Variables: 26 | history_vectors -- A record of the current dialog so far in vector format (act-slot, but no values) 27 | history_dictionaries -- A record of the current dialog in dictionary format 28 | current_slots -- A dictionary that keeps a running record of which slots are filled current_slots['inform_slots'] and which are requested current_slots['request_slots'] (but not filed) 29 | action_dimension -- # TODO indicates the dimensionality of the vector representaiton of the action 30 | kb_result_dimension -- A single integer denoting the dimension of the kb_results features. 31 | turn_count -- A running count of which turn we are at in the present dialog 32 | """ 33 | self.movie_dictionary = movie_dictionary 34 | self.initialize_episode() 35 | self.history_vectors = None 36 | self.history_dictionaries = None 37 | self.current_slots = None 38 | self.action_dimension = 10 # TODO REPLACE WITH REAL VALUE 39 | self.kb_result_dimension = 10 # TODO REPLACE WITH REAL VALUE 40 | self.turn_count = 0 41 | self.kb_helper = KBHelper(movie_dictionary) 42 | 43 | 44 | def initialize_episode(self): 45 | """ Initialize a new episode (dialog), flush the current state and tracked slots """ 46 | 47 | self.action_dimension = 10 48 | self.history_vectors = np.zeros((1, self.action_dimension)) 49 | self.history_dictionaries = [] 50 | self.turn_count = 0 51 | self.current_slots = {} 52 | 53 | self.current_slots['inform_slots'] = {} 54 | self.current_slots['request_slots'] = {} 55 | self.current_slots['proposed_slots'] = {} 56 | self.current_slots['agent_request_slots'] = {} 57 | 58 | 59 | def dialog_history_vectors(self): 60 | """ Return the dialog history (both user and agent actions) in vector representation """ 61 | return self.history_vectors 62 | 63 | 64 | def dialog_history_dictionaries(self): 65 | """ Return the dictionary representation of the dialog history (includes values) """ 66 | return 
self.history_dictionaries 67 | 68 | 69 | def kb_results_for_state(self): 70 | """ Return the information about the database results based on the currently informed slots """ 71 | ######################################################################## 72 | # TODO Calculate results based on current informed slots 73 | ######################################################################## 74 | kb_results = self.kb_helper.database_results_for_agent(self.current_slots) # replace this with something less ridiculous 75 | # TODO turn results into vector (from dictionary) 76 | results = np.zeros((0, self.kb_result_dimension)) 77 | return results 78 | 79 | 80 | def get_state_for_agent(self): 81 | """ Get the state representatons to send to agent """ 82 | #state = {'user_action': self.history_dictionaries[-1], 'current_slots': self.current_slots, 'kb_results': self.kb_results_for_state()} 83 | state = {'user_action': self.history_dictionaries[-1], 'current_slots': self.current_slots, #'kb_results': self.kb_results_for_state(), 84 | 'kb_results_dict':self.kb_helper.database_results_for_agent(self.current_slots), 'turn': self.turn_count, 'history': self.history_dictionaries, 85 | 'agent_action': self.history_dictionaries[-2] if len(self.history_dictionaries) > 1 else None} 86 | return copy.deepcopy(state) 87 | 88 | def get_state_for_user(self): 89 | """ Get the state representatons to send to user """ 90 | #state = {'user_action': self.history_dictionaries[-1], 'current_slots': self.current_slots, 'kb_results': self.kb_results_for_state()} 91 | state = {'user_action': self.history_dictionaries[-2], 'current_slots': self.current_slots, #'kb_results': self.kb_results_for_state(), 92 | 'kb_results_dict':self.kb_helper.database_results_for_agent(self.current_slots), 'turn': self.turn_count, 'history': self.history_dictionaries, 93 | 'agent_action': self.history_dictionaries[-1] if len(self.history_dictionaries) > 1 else None} 94 | return copy.deepcopy(state) 95 | 96 | def 
get_suggest_slots_values(self, request_slots): 97 | """ Get the suggested values for request slots """ 98 | 99 | suggest_slot_vals = {} 100 | if len(request_slots) > 0: 101 | suggest_slot_vals = self.kb_helper.suggest_slot_values(request_slots, self.current_slots) 102 | 103 | return suggest_slot_vals 104 | 105 | def get_current_kb_results(self): 106 | """ get the kb_results for current state """ 107 | kb_results = self.kb_helper.available_results_from_kb(self.current_slots) 108 | return kb_results 109 | 110 | 111 | def update(self, agent_action=None, user_action=None): 112 | """ Update the state based on the latest action """ 113 | 114 | ######################################################################## 115 | # Make sure that the function was called properly 116 | ######################################################################## 117 | assert(not (user_action and agent_action)) 118 | assert(user_action or agent_action) 119 | 120 | ######################################################################## 121 | # Update state to reflect a new action by the agent. 
122 | ######################################################################## 123 | if agent_action: 124 | 125 | #################################################################### 126 | # Handles the act_slot response (with values needing to be filled) 127 | #################################################################### 128 | if agent_action['act_slot_response']: 129 | response = copy.deepcopy(agent_action['act_slot_response']) 130 | 131 | inform_slots = self.kb_helper.fill_inform_slots(response['inform_slots'], self.current_slots) # TODO this doesn't actually work yet, remove this warning when kb_helper is functional 132 | agent_action_values = {'turn': self.turn_count, 'speaker': "agent", 'diaact': response['diaact'], 'inform_slots': inform_slots, 'request_slots':response['request_slots']} 133 | 134 | agent_action['act_slot_response'].update({'diaact': response['diaact'], 'inform_slots': inform_slots, 'request_slots':response['request_slots'], 'turn':self.turn_count}) 135 | 136 | elif agent_action['act_slot_value_response']: 137 | agent_action_values = copy.deepcopy(agent_action['act_slot_value_response']) 138 | # print("Updating state based on act_slot_value action from agent") 139 | agent_action_values['turn'] = self.turn_count 140 | agent_action_values['speaker'] = "agent" 141 | 142 | #################################################################### 143 | # This code should execute regardless of which kind of agent produced action 144 | #################################################################### 145 | for slot in agent_action_values['inform_slots'].keys(): 146 | self.current_slots['proposed_slots'][slot] = agent_action_values['inform_slots'][slot] 147 | self.current_slots['inform_slots'][slot] = agent_action_values['inform_slots'][slot] # add into inform_slots 148 | if slot in self.current_slots['request_slots'].keys(): 149 | del self.current_slots['request_slots'][slot] 150 | 151 | for slot in agent_action_values['request_slots'].keys(): 
152 | if slot not in self.current_slots['agent_request_slots']: 153 | self.current_slots['agent_request_slots'][slot] = "UNK" 154 | 155 | self.history_dictionaries.append(agent_action_values) 156 | current_agent_vector = np.ones((1, self.action_dimension)) 157 | self.history_vectors = np.vstack([self.history_vectors, current_agent_vector]) 158 | 159 | ######################################################################## 160 | # Update the state to reflect a new action by the user 161 | ######################################################################## 162 | elif user_action: 163 | 164 | #################################################################### 165 | # Update the current slots 166 | #################################################################### 167 | for slot in user_action['inform_slots'].keys(): 168 | self.current_slots['inform_slots'][slot] = user_action['inform_slots'][slot] 169 | if slot in self.current_slots['request_slots'].keys(): 170 | del self.current_slots['request_slots'][slot] 171 | 172 | for slot in user_action['request_slots'].keys(): 173 | if slot not in self.current_slots['request_slots']: 174 | self.current_slots['request_slots'][slot] = "UNK" 175 | 176 | self.history_vectors = np.vstack([self.history_vectors, np.zeros((1,self.action_dimension))]) 177 | new_move = {'turn': self.turn_count, 'speaker': "user", 'request_slots': user_action['request_slots'], 'inform_slots': user_action['inform_slots'], 'diaact': user_action['diaact']} 178 | self.history_dictionaries.append(copy.deepcopy(new_move)) 179 | 180 | ######################################################################## 181 | # This should never happen if the asserts passed 182 | ######################################################################## 183 | else: 184 | pass 185 | 186 | ######################################################################## 187 | # This code should execute after update code regardless of what kind of action (agent/user) 188 | 
######################################################################## 189 | self.turn_count += 1 -------------------------------------------------------------------------------- /src/deep_dialog/dialog_system/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on May 25, 2016 3 | 4 | @author: xiul, t-zalipt 5 | """ 6 | 7 | import numpy as np 8 | ################################################################################ 9 | # Some helper functions 10 | ################################################################################ 11 | 12 | def unique_states(training_data): 13 | unique = [] 14 | for datum in training_data: 15 | if contains(unique, datum[0]): 16 | pass 17 | else: 18 | unique.append(datum[0].copy()) 19 | return unique 20 | 21 | def contains(unique, candidate_state): 22 | for state in unique: 23 | if np.array_equal(state, candidate_state): 24 | return True 25 | else: 26 | pass 27 | return False 28 | -------------------------------------------------------------------------------- /src/deep_dialog/models/nlg/convert.py: -------------------------------------------------------------------------------- 1 | import cPickle 2 | model=cPickle.load(open('lstm_tanh_relu_[1468202263.38]_2_0.610.p')) 3 | cPickle.dump(model,open('model.bin.nlg','wb')) -------------------------------------------------------------------------------- /src/deep_dialog/models/nlu/convert.py: -------------------------------------------------------------------------------- 1 | import cPickle 2 | model=cPickle.load(open('lstm_[1468447442.91]_39_80_0.921.p')) 3 | cPickle.dump(model,open('model.bin.nlu','wb')) -------------------------------------------------------------------------------- /src/deep_dialog/nlg/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .nlg import * 
-------------------------------------------------------------------------------- /src/deep_dialog/nlg/decoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 13, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | from .utils import * 8 | 9 | 10 | class decoder: 11 | def __init__(self, input_size, hidden_size, output_size): 12 | pass 13 | 14 | def get_struct(self): 15 | return {'model': self.model, 'update': self.update, 'regularize': self.regularize} 16 | 17 | 18 | """ Activation Function: Sigmoid, or tanh, or ReLu""" 19 | def fwdPass(self, Xs, params, **kwargs): 20 | pass 21 | 22 | def bwdPass(self, dY, cache): 23 | pass 24 | 25 | 26 | """ Batch Forward & Backward Pass""" 27 | def batchForward(self, ds, batch, params, predict_mode = False): 28 | caches = [] 29 | Ys = [] 30 | for i,x in enumerate(batch): 31 | Y, out_cache = self.fwdPass(x, params, predict_mode = predict_mode) 32 | caches.append(out_cache) 33 | Ys.append(Y) 34 | 35 | # back up information for efficient backprop 36 | cache = {} 37 | if not predict_mode: 38 | cache['caches'] = caches 39 | 40 | return Ys, cache 41 | 42 | def batchBackward(self, dY, cache): 43 | caches = cache['caches'] 44 | grads = {} 45 | for i in xrange(len(caches)): 46 | single_cache = caches[i] 47 | local_grads = self.bwdPass(dY[i], single_cache) 48 | mergeDicts(grads, local_grads) # add up the gradients wrt model parameters 49 | 50 | return grads 51 | 52 | 53 | """ Cost function, returns cost and gradients for model """ 54 | def costFunc(self, ds, batch, params): 55 | regc = params['reg_cost'] # regularization cost 56 | 57 | # batch forward RNN 58 | Ys, caches = self.batchForward(ds, batch, params, predict_mode = False) 59 | 60 | loss_cost = 0.0 61 | smooth_cost = 1e-15 62 | dYs = [] 63 | 64 | for i,x in enumerate(batch): 65 | labels = np.array(x['labels'], dtype=int) 66 | 67 | # fetch the predicted probabilities 68 | Y = Ys[i] 69 | maxes = np.amax(Y, axis=1, keepdims=True) 
70 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 71 | P = e/np.sum(e, axis=1, keepdims=True) 72 | 73 | # Cross-Entropy Cross Function 74 | loss_cost += -np.sum(np.log(smooth_cost + P[range(len(labels)), labels])) 75 | 76 | for iy,y in enumerate(labels): 77 | P[iy,y] -= 1 # softmax derivatives 78 | dYs.append(P) 79 | 80 | # backprop the RNN 81 | grads = self.batchBackward(dYs, caches) 82 | 83 | # add L2 regularization cost and gradients 84 | reg_cost = 0.0 85 | if regc > 0: 86 | for p in self.regularize: 87 | mat = self.model[p] 88 | reg_cost += 0.5*regc*np.sum(mat*mat) 89 | grads[p] += regc*mat 90 | 91 | # normalize the cost and gradient by the batch size 92 | batch_size = len(batch) 93 | reg_cost /= batch_size 94 | loss_cost /= batch_size 95 | for k in grads: grads[k] /= batch_size 96 | 97 | out = {} 98 | out['cost'] = {'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost} 99 | out['grads'] = grads 100 | return out 101 | 102 | 103 | """ A single batch """ 104 | def singleBatch(self, ds, batch, params): 105 | learning_rate = params.get('learning_rate', 0.0) 106 | decay_rate = params.get('decay_rate', 0.999) 107 | momentum = params.get('momentum', 0) 108 | grad_clip = params.get('grad_clip', 1) 109 | smooth_eps = params.get('smooth_eps', 1e-8) 110 | sdg_type = params.get('sdgtype', 'rmsprop') 111 | 112 | for u in self.update: 113 | if not u in self.step_cache: 114 | self.step_cache[u] = np.zeros(self.model[u].shape) 115 | 116 | cg = self.costFunc(ds, batch, params) 117 | 118 | cost = cg['cost'] 119 | grads = cg['grads'] 120 | 121 | # clip gradients if needed 122 | if params['activation_func'] == 'relu': 123 | if grad_clip > 0: 124 | for p in self.update: 125 | if p in grads: 126 | grads[p] = np.minimum(grads[p], grad_clip) 127 | grads[p] = np.maximum(grads[p], -grad_clip) 128 | 129 | # perform parameter update 130 | for p in self.update: 131 | if p in grads: 132 | if sdg_type == 'vanilla': 133 | if 
momentum > 0: dx = momentum*self.step_cache[p] - learning_rate*grads[p] 134 | else: dx = -learning_rate*grads[p] 135 | self.step_cache[p] = dx 136 | elif sdg_type == 'rmsprop': 137 | self.step_cache[p] = self.step_cache[p]*decay_rate + (1.0-decay_rate)*grads[p]**2 138 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 139 | elif sdg_type == 'adgrad': 140 | self.step_cache[p] += grads[p]**2 141 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 142 | 143 | self.model[p] += dx 144 | 145 | # create output dict and return 146 | out = {} 147 | out['cost'] = cost 148 | return out 149 | 150 | 151 | """ Evaluate on the dataset[split] """ 152 | def eval(self, ds, split, params): 153 | acc = 0 154 | total = 0 155 | 156 | total_cost = 0.0 157 | smooth_cost = 1e-15 158 | perplexity = 0 159 | 160 | for i, ele in enumerate(ds.split[split]): 161 | #ele_reps = self.prepare_input_rep(ds, [ele], params) 162 | #Ys, cache = self.fwdPass(ele_reps[0], params, predict_model=True) 163 | #labels = np.array(ele_reps[0]['labels'], dtype=int) 164 | 165 | Ys, cache = self.fwdPass(ele, params, predict_model=True) 166 | 167 | maxes = np.amax(Ys, axis=1, keepdims=True) 168 | e = np.exp(Ys - maxes) # for numerical stability shift into good numerical range 169 | probs = e/np.sum(e, axis=1, keepdims=True) 170 | 171 | labels = np.array(ele['labels'], dtype=int) 172 | 173 | if np.all(np.isnan(probs)): probs = np.zeros(probs.shape) 174 | 175 | log_perplex = 0 176 | log_perplex += -np.sum(np.log2(smooth_cost + probs[range(len(labels)), labels])) 177 | log_perplex /= len(labels) 178 | 179 | loss_cost = 0 180 | loss_cost += -np.sum(np.log(smooth_cost + probs[range(len(labels)), labels])) 181 | 182 | perplexity += log_perplex #2**log_perplex 183 | total_cost += loss_cost 184 | 185 | pred_words_indices = np.nanargmax(probs, axis=1) 186 | for index, l in enumerate(labels): 187 | if pred_words_indices[index] == l: 188 | acc += 1 189 | 190 | total += len(labels) 
191 | 192 | perplexity /= len(ds.split[split]) 193 | total_cost /= len(ds.split[split]) 194 | accuracy = 0 if total == 0 else float(acc)/total 195 | 196 | #print ("perplexity: %s, total_cost: %s, accuracy: %s" % (perplexity, total_cost, accuracy)) 197 | result = {'perplexity': perplexity, 'cost': total_cost, 'accuracy': accuracy} 198 | return result 199 | 200 | 201 | 202 | """ prediction on dataset[split] """ 203 | def predict(self, ds, split, params): 204 | inverse_word_dict = {ds.data['word_dict'][k]:k for k in ds.data['word_dict'].keys()} 205 | for i, ele in enumerate(ds.split[split]): 206 | pred_ys, pred_words = self.forward(inverse_word_dict, ele, params, predict_model=True) 207 | 208 | sentence = ' '.join(pred_words[:-1]) 209 | real_sentence = ' '.join(ele['sentence'].split(' ')[1:-1]) 210 | 211 | if params['dia_slot_val'] == 2 or params['dia_slot_val'] == 3: 212 | sentence = self.post_process(sentence, ele['slotval'], ds.data['slot_dict']) 213 | 214 | print 'test case', i 215 | print 'real:', real_sentence 216 | print 'pred:', sentence 217 | 218 | """ post_process to fill the slot """ 219 | def post_process(self, pred_template, slot_val_dict, slot_dict): 220 | sentence = pred_template 221 | suffix = "_PLACEHOLDER" 222 | 223 | for slot in slot_val_dict.keys(): 224 | slot_vals = slot_val_dict[slot] 225 | slot_placeholder = slot + suffix 226 | if slot == 'result' or slot == 'numberofpeople': continue 227 | for slot_val in slot_vals: 228 | tmp_sentence = sentence.replace(slot_placeholder, slot_val, 1) 229 | sentence = tmp_sentence 230 | 231 | if 'numberofpeople' in slot_val_dict.keys(): 232 | slot_vals = slot_val_dict['numberofpeople'] 233 | slot_placeholder = 'numberofpeople' + suffix 234 | for slot_val in slot_vals: 235 | tmp_sentence = sentence.replace(slot_placeholder, slot_val, 1) 236 | sentence = tmp_sentence 237 | 238 | for slot in slot_dict.keys(): 239 | slot_placeholder = slot + suffix 240 | tmp_sentence = sentence.replace(slot_placeholder, '') 241 | 
sentence = tmp_sentence 242 | 243 | return sentence -------------------------------------------------------------------------------- /src/deep_dialog/nlg/lstm_decoder_tanh.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 13, 2016 3 | 4 | An LSTM decoder - add tanh after cell before output gate 5 | 6 | @author: xiul 7 | ''' 8 | 9 | from .decoder import decoder 10 | from .utils import * 11 | 12 | 13 | class lstm_decoder_tanh(decoder): 14 | def __init__(self, diaact_input_size, input_size, hidden_size, output_size): 15 | self.model = {} 16 | # connections from diaact to hidden layer 17 | self.model['Wah'] = initWeights(diaact_input_size, 4*hidden_size) 18 | self.model['bah'] = np.zeros((1, 4*hidden_size)) 19 | 20 | # Recurrent weights: take x_t, h_{t-1}, and bias unit, and produce the 3 gates and the input to cell signal 21 | self.model['WLSTM'] = initWeights(input_size + hidden_size + 1, 4*hidden_size) 22 | # Hidden-Output Connections 23 | self.model['Wd'] = initWeights(hidden_size, output_size)*0.1 24 | self.model['bd'] = np.zeros((1, output_size)) 25 | 26 | self.update = ['Wah', 'bah', 'WLSTM', 'Wd', 'bd'] 27 | self.regularize = ['Wah', 'WLSTM', 'Wd'] 28 | 29 | self.step_cache = {} 30 | 31 | """ Activation Function: Sigmoid, or tanh, or ReLu """ 32 | def fwdPass(self, Xs, params, **kwargs): 33 | predict_mode = kwargs.get('predict_mode', False) 34 | feed_recurrence = params.get('feed_recurrence', 0) 35 | 36 | Ds = Xs['diaact'] 37 | Ws = Xs['words'] 38 | 39 | # diaact input layer to hidden layer 40 | Wah = self.model['Wah'] 41 | bah = self.model['bah'] 42 | Dsh = Ds.dot(Wah) + bah 43 | 44 | WLSTM = self.model['WLSTM'] 45 | n, xd = Ws.shape 46 | 47 | d = self.model['Wd'].shape[0] # size of hidden layer 48 | Hin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias 49 | Hout = np.zeros((n, d)) 50 | IFOG = np.zeros((n, 4*d)) 51 | IFOGf = np.zeros((n, 4*d)) # after nonlinearity 52 | Cellin = np.zeros((n, d)) 53 | 
Cellout = np.zeros((n, d)) 54 | 55 | for t in xrange(n): 56 | prev = np.zeros(d) if t==0 else Hout[t-1] 57 | Hin[t,0] = 1 # bias 58 | Hin[t, 1:1+xd] = Ws[t] 59 | Hin[t, 1+xd:] = prev 60 | 61 | # compute all gate activations. dots: 62 | IFOG[t] = Hin[t].dot(WLSTM) 63 | 64 | # add diaact vector here 65 | if feed_recurrence == 0: 66 | if t == 0: IFOG[t] += Dsh[0] 67 | else: 68 | IFOG[t] += Dsh[0] 69 | 70 | IFOGf[t, :3*d] = 1/(1+np.exp(-IFOG[t, :3*d])) # sigmoids; these are three gates 71 | IFOGf[t, 3*d:] = np.tanh(IFOG[t, 3*d:]) # tanh for input value 72 | 73 | Cellin[t] = IFOGf[t, :d] * IFOGf[t, 3*d:] 74 | if t>0: Cellin[t] += IFOGf[t, d:2*d]*Cellin[t-1] 75 | 76 | Cellout[t] = np.tanh(Cellin[t]) 77 | 78 | Hout[t] = IFOGf[t, 2*d:3*d] * Cellout[t] 79 | 80 | Wd = self.model['Wd'] 81 | bd = self.model['bd'] 82 | 83 | Y = Hout.dot(Wd)+bd 84 | 85 | cache = {} 86 | if not predict_mode: 87 | cache['WLSTM'] = WLSTM 88 | cache['Hout'] = Hout 89 | cache['WLSTM'] = WLSTM 90 | cache['Wd'] = Wd 91 | cache['IFOGf'] = IFOGf 92 | cache['IFOG'] = IFOG 93 | cache['Cellin'] = Cellin 94 | cache['Cellout'] = Cellout 95 | cache['Ws'] = Ws 96 | cache['Ds'] = Ds 97 | cache['Hin'] = Hin 98 | cache['Dsh'] = Dsh 99 | cache['Wah'] = Wah 100 | cache['feed_recurrence'] = feed_recurrence 101 | 102 | return Y, cache 103 | 104 | """ Forward pass on prediction """ 105 | def forward(self, dict, Xs, params, **kwargs): 106 | max_len = params.get('max_len', 30) 107 | feed_recurrence = params.get('feed_recurrence', 0) 108 | decoder_sampling = params.get('decoder_sampling', 0) 109 | 110 | Ds = Xs['diaact'] 111 | Ws = Xs['words'] 112 | 113 | # diaact input layer to hidden layer 114 | Wah = self.model['Wah'] 115 | bah = self.model['bah'] 116 | Dsh = Ds.dot(Wah) + bah 117 | 118 | WLSTM = self.model['WLSTM'] 119 | xd = Ws.shape[1] 120 | 121 | d = self.model['Wd'].shape[0] # size of hidden layer 122 | Hin = np.zeros((1, WLSTM.shape[0])) # xt, ht-1, bias 123 | Hout = np.zeros((1, d)) 124 | IFOG = np.zeros((1, 
4*d)) 125 | IFOGf = np.zeros((1, 4*d)) # after nonlinearity 126 | Cellin = np.zeros((1, d)) 127 | Cellout = np.zeros((1, d)) 128 | 129 | Wd = self.model['Wd'] 130 | bd = self.model['bd'] 131 | 132 | Hin[0,0] = 1 # bias 133 | Hin[0,1:1+xd] = Ws[0] 134 | 135 | IFOG[0] = Hin[0].dot(WLSTM) 136 | IFOG[0] += Dsh[0] 137 | 138 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 139 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 140 | 141 | Cellin[0] = IFOGf[0, :d] * IFOGf[0, 3*d:] 142 | Cellout[0] = np.tanh(Cellin[0]) 143 | Hout[0] = IFOGf[0, 2*d:3*d] * Cellout[0] 144 | 145 | pred_y = [] 146 | pred_words = [] 147 | 148 | Y = Hout.dot(Wd) + bd 149 | maxes = np.amax(Y, axis=1, keepdims=True) 150 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 151 | probs = e/np.sum(e, axis=1, keepdims=True) 152 | 153 | if decoder_sampling == 0: # sampling or argmax 154 | pred_y_index = np.nanargmax(Y) 155 | else: 156 | pred_y_index = np.random.choice(Y.shape[1], 1, p=probs[0])[0] 157 | pred_y.append(pred_y_index) 158 | pred_words.append(dict[pred_y_index]) 159 | 160 | time_stamp = 0 161 | while True: 162 | if dict[pred_y_index] == 'e_o_s' or time_stamp >= max_len: break 163 | 164 | X = np.zeros(xd) 165 | X[pred_y_index] = 1 166 | Hin[0,0] = 1 # bias 167 | Hin[0,1:1+xd] = X 168 | Hin[0, 1+xd:] = Hout[0] 169 | 170 | IFOG[0] = Hin[0].dot(WLSTM) 171 | if feed_recurrence == 1: 172 | IFOG[0] += Dsh[0] 173 | 174 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 175 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 176 | 177 | C = IFOGf[0, :d]*IFOGf[0, 3*d:] 178 | Cellin[0] = C + IFOGf[0, d:2*d]*Cellin[0] 179 | Cellout[0] = np.tanh(Cellin[0]) 180 | Hout[0] = IFOGf[0, 2*d:3*d]*Cellout[0] 181 | 182 | Y = Hout.dot(Wd) + bd 183 | maxes = np.amax(Y, axis=1, keepdims=True) 184 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 185 | probs = 
e/np.sum(e, axis=1, keepdims=True) 186 | 187 | if decoder_sampling == 0: 188 | pred_y_index = np.nanargmax(Y) 189 | else: 190 | pred_y_index = np.random.choice(Y.shape[1], 1, p=probs[0])[0] 191 | pred_y.append(pred_y_index) 192 | pred_words.append(dict[pred_y_index]) 193 | 194 | time_stamp += 1 195 | 196 | return pred_y, pred_words 197 | 198 | """ Forward pass on prediction with Beam Search """ 199 | def beam_forward(self, dict, Xs, params, **kwargs): 200 | max_len = params.get('max_len', 30) 201 | feed_recurrence = params.get('feed_recurrence', 0) 202 | beam_size = params.get('beam_size', 10) 203 | decoder_sampling = params.get('decoder_sampling', 0) 204 | 205 | Ds = Xs['diaact'] 206 | Ws = Xs['words'] 207 | 208 | # diaact input layer to hidden layer 209 | Wah = self.model['Wah'] 210 | bah = self.model['bah'] 211 | Dsh = Ds.dot(Wah) + bah 212 | 213 | WLSTM = self.model['WLSTM'] 214 | xd = Ws.shape[1] 215 | 216 | d = self.model['Wd'].shape[0] # size of hidden layer 217 | Hin = np.zeros((1, WLSTM.shape[0])) # xt, ht-1, bias 218 | Hout = np.zeros((1, d)) 219 | IFOG = np.zeros((1, 4*d)) 220 | IFOGf = np.zeros((1, 4*d)) # after nonlinearity 221 | Cellin = np.zeros((1, d)) 222 | Cellout = np.zeros((1, d)) 223 | 224 | Wd = self.model['Wd'] 225 | bd = self.model['bd'] 226 | 227 | Hin[0,0] = 1 # bias 228 | Hin[0,1:1+xd] = Ws[0] 229 | 230 | IFOG[0] = Hin[0].dot(WLSTM) 231 | IFOG[0] += Dsh[0] 232 | 233 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 234 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 235 | 236 | Cellin[0] = IFOGf[0, :d] * IFOGf[0, 3*d:] 237 | Cellout[0] = np.tanh(Cellin[0]) 238 | Hout[0] = IFOGf[0, 2*d:3*d] * Cellout[0] 239 | 240 | # keep a beam here 241 | beams = [] 242 | 243 | Y = Hout.dot(Wd) + bd 244 | maxes = np.amax(Y, axis=1, keepdims=True) 245 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 246 | probs = e/np.sum(e, axis=1, keepdims=True) 247 | 248 | # add beam 
search here 249 | if decoder_sampling == 0: # no sampling 250 | beam_candidate_t = (-probs[0]).argsort()[:beam_size] 251 | else: 252 | beam_candidate_t = np.random.choice(Y.shape[1], beam_size, p=probs[0]) 253 | #beam_candidate_t = (-probs[0]).argsort()[:beam_size] 254 | for ele in beam_candidate_t: 255 | beams.append((np.log(probs[0][ele]), [ele], [dict[ele]], Hout[0], Cellin[0])) 256 | 257 | #beams.sort(key=lambda x:x[0], reverse=True) 258 | #beams.sort(reverse = True) 259 | 260 | time_stamp = 0 261 | while True: 262 | beam_candidates = [] 263 | for b in beams: 264 | log_prob = b[0] 265 | pred_y_index = b[1][-1] 266 | cell_in = b[4] 267 | hout_prev = b[3] 268 | 269 | if b[2][-1] == "e_o_s": # this beam predicted end token. Keep in the candidates but don't expand it out any more 270 | beam_candidates.append(b) 271 | continue 272 | 273 | X = np.zeros(xd) 274 | X[pred_y_index] = 1 275 | Hin[0,0] = 1 # bias 276 | Hin[0,1:1+xd] = X 277 | Hin[0, 1+xd:] = hout_prev 278 | 279 | IFOG[0] = Hin[0].dot(WLSTM) 280 | if feed_recurrence == 1: IFOG[0] += Dsh[0] 281 | 282 | IFOGf[0, :3*d] = 1/(1+np.exp(-IFOG[0, :3*d])) # sigmoids; these are three gates 283 | IFOGf[0, 3*d:] = np.tanh(IFOG[0, 3*d:]) # tanh for input value 284 | 285 | C = IFOGf[0, :d]*IFOGf[0, 3*d:] 286 | cell_in = C + IFOGf[0, d:2*d]*cell_in 287 | cell_out = np.tanh(cell_in) 288 | hout_prev = IFOGf[0, 2*d:3*d]*cell_out 289 | 290 | Y = hout_prev.dot(Wd) + bd 291 | maxes = np.amax(Y, axis=1, keepdims=True) 292 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 293 | probs = e/np.sum(e, axis=1, keepdims=True) 294 | 295 | if decoder_sampling == 0: # no sampling 296 | beam_candidate_t = (-probs[0]).argsort()[:beam_size] 297 | else: 298 | beam_candidate_t = np.random.choice(Y.shape[1], beam_size, p=probs[0]) 299 | #beam_candidate_t = (-probs[0]).argsort()[:beam_size] 300 | for ele in beam_candidate_t: 301 | beam_candidates.append((log_prob+np.log(probs[0][ele]), np.append(b[1], ele), 
np.append(b[2], dict[ele]), hout_prev, cell_in)) 302 | 303 | beam_candidates.sort(key=lambda x:x[0], reverse=True) 304 | #beam_candidates.sort(reverse = True) # decreasing order 305 | beams = beam_candidates[:beam_size] 306 | time_stamp += 1 307 | 308 | if time_stamp >= max_len: break 309 | 310 | return beams[0][1], beams[0][2] 311 | 312 | """ Backward Pass """ 313 | def bwdPass(self, dY, cache): 314 | Wd = cache['Wd'] 315 | Hout = cache['Hout'] 316 | IFOG = cache['IFOG'] 317 | IFOGf = cache['IFOGf'] 318 | Cellin = cache['Cellin'] 319 | Cellout = cache['Cellout'] 320 | Hin = cache['Hin'] 321 | WLSTM = cache['WLSTM'] 322 | Ws = cache['Ws'] 323 | Ds = cache['Ds'] 324 | Dsh = cache['Dsh'] 325 | Wah = cache['Wah'] 326 | feed_recurrence = cache['feed_recurrence'] 327 | 328 | n,d = Hout.shape 329 | 330 | # backprop the hidden-output layer 331 | dWd = Hout.transpose().dot(dY) 332 | dbd = np.sum(dY, axis=0, keepdims = True) 333 | dHout = dY.dot(Wd.transpose()) 334 | 335 | # backprop the LSTM 336 | dIFOG = np.zeros(IFOG.shape) 337 | dIFOGf = np.zeros(IFOGf.shape) 338 | dWLSTM = np.zeros(WLSTM.shape) 339 | dHin = np.zeros(Hin.shape) 340 | dCellin = np.zeros(Cellin.shape) 341 | dCellout = np.zeros(Cellout.shape) 342 | dWs = np.zeros(Ws.shape) 343 | 344 | dDsh = np.zeros(Dsh.shape) 345 | 346 | for t in reversed(xrange(n)): 347 | dIFOGf[t,2*d:3*d] = Cellout[t] * dHout[t] 348 | dCellout[t] = IFOGf[t,2*d:3*d] * dHout[t] 349 | 350 | dCellin[t] += (1-Cellout[t]**2) * dCellout[t] 351 | 352 | if t>0: 353 | dIFOGf[t, d:2*d] = Cellin[t-1] * dCellin[t] 354 | dCellin[t-1] += IFOGf[t,d:2*d] * dCellin[t] 355 | 356 | dIFOGf[t, :d] = IFOGf[t,3*d:] * dCellin[t] 357 | dIFOGf[t,3*d:] = IFOGf[t, :d] * dCellin[t] 358 | 359 | # backprop activation functions 360 | dIFOG[t, 3*d:] = (1-IFOGf[t, 3*d:]**2) * dIFOGf[t, 3*d:] 361 | y = IFOGf[t, :3*d] 362 | dIFOG[t, :3*d] = (y*(1-y)) * dIFOGf[t, :3*d] 363 | 364 | # backprop matrix multiply 365 | dWLSTM += np.outer(Hin[t], dIFOG[t]) 366 | dHin[t] = 
dIFOG[t].dot(WLSTM.transpose()) 367 | 368 | if t > 0: dHout[t-1] += dHin[t,1+Ws.shape[1]:] 369 | 370 | if feed_recurrence == 0: 371 | if t == 0: dDsh[t] = dIFOG[t] 372 | else: 373 | dDsh[0] += dIFOG[t] 374 | 375 | # backprop to the diaact-hidden connections 376 | dWah = Ds.transpose().dot(dDsh) 377 | dbah = np.sum(dDsh, axis=0, keepdims = True) 378 | 379 | return {'Wah':dWah, 'bah':dbah, 'WLSTM':dWLSTM, 'Wd':dWd, 'bd':dbd} 380 | 381 | 382 | """ Batch data representation """ 383 | def prepare_input_rep(self, ds, batch, params): 384 | batch_reps = [] 385 | for i,x in enumerate(batch): 386 | batch_rep = {} 387 | 388 | vec = np.zeros((1, self.model['Wah'].shape[0])) 389 | vec[0][x['diaact_rep']] = 1 390 | for v in x['slotrep']: 391 | vec[0][v] = 1 392 | 393 | word_arr = x['sentence'].split(' ') 394 | word_vecs = np.zeros((len(word_arr), self.model['Wxh'].shape[0])) 395 | labels = [0] * (len(word_arr)-1) 396 | for w_index, w in enumerate(word_arr[:-1]): 397 | if w in ds.data['word_dict'].keys(): 398 | w_dict_index = ds.data['word_dict'][w] 399 | word_vecs[w_index][w_dict_index] = 1 400 | 401 | if word_arr[w_index+1] in ds.data['word_dict'].keys(): 402 | labels[w_index] = ds.data['word_dict'][word_arr[w_index+1]] 403 | 404 | batch_rep['diaact'] = vec 405 | batch_rep['words'] = word_vecs 406 | batch_rep['labels'] = labels 407 | batch_reps.append(batch_rep) 408 | return batch_reps -------------------------------------------------------------------------------- /src/deep_dialog/nlg/nlg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 17, 2016 3 | 4 | --dia_act_nl_pairs.v6.json: agt and usr have their own NL. 
5 | 6 | 7 | @author: xiul 8 | ''' 9 | 10 | import cPickle as pickle 11 | import copy, argparse, json 12 | import numpy as np 13 | 14 | from deep_dialog import dialog_config 15 | from deep_dialog.nlg.lstm_decoder_tanh import lstm_decoder_tanh 16 | 17 | 18 | class nlg: 19 | def __init__(self): 20 | pass 21 | 22 | def post_process(self, pred_template, slot_val_dict, slot_dict): 23 | """ post_process to fill the slot in the template sentence """ 24 | 25 | sentence = pred_template 26 | suffix = "_PLACEHOLDER" 27 | 28 | for slot in slot_val_dict.keys(): 29 | slot_vals = slot_val_dict[slot] 30 | slot_placeholder = slot + suffix 31 | if slot == 'result' or slot == 'numberofpeople': continue 32 | if slot_vals == dialog_config.NO_VALUE_MATCH: continue 33 | tmp_sentence = sentence.replace(slot_placeholder, slot_vals, 1) 34 | sentence = tmp_sentence 35 | 36 | if 'numberofpeople' in slot_val_dict.keys(): 37 | slot_vals = slot_val_dict['numberofpeople'] 38 | slot_placeholder = 'numberofpeople' + suffix 39 | tmp_sentence = sentence.replace(slot_placeholder, slot_vals, 1) 40 | sentence = tmp_sentence 41 | 42 | for slot in slot_dict.keys(): 43 | slot_placeholder = slot + suffix 44 | tmp_sentence = sentence.replace(slot_placeholder, '') 45 | sentence = tmp_sentence 46 | 47 | return sentence 48 | 49 | 50 | def convert_diaact_to_nl(self, dia_act, turn_msg): 51 | """ Convert Dia_Act into NL: Rule + Model """ 52 | 53 | sentence = "" 54 | boolean_in = False 55 | 56 | # remove I do not care slot in task(complete) 57 | if dia_act['diaact'] == 'inform' and 'taskcomplete' in dia_act['inform_slots'].keys() and dia_act['inform_slots']['taskcomplete'] != dialog_config.NO_VALUE_MATCH: 58 | inform_slot_set = dia_act['inform_slots'].keys() 59 | for slot in inform_slot_set: 60 | if dia_act['inform_slots'][slot] == dialog_config.I_DO_NOT_CARE: del dia_act['inform_slots'][slot] 61 | 62 | if dia_act['diaact'] in self.diaact_nl_pairs['dia_acts'].keys(): 63 | for ele in 
self.diaact_nl_pairs['dia_acts'][dia_act['diaact']]: 64 | if set(ele['inform_slots']) == set(dia_act['inform_slots'].keys()) and set(ele['request_slots']) == set(dia_act['request_slots'].keys()): 65 | sentence = self.diaact_to_nl_slot_filling(dia_act, ele['nl'][turn_msg]) 66 | boolean_in = True 67 | break 68 | 69 | if dia_act['diaact'] == 'inform' and 'taskcomplete' in dia_act['inform_slots'].keys() and dia_act['inform_slots']['taskcomplete'] == dialog_config.NO_VALUE_MATCH: 70 | sentence = "Oh sorry, there is no ticket available." 71 | 72 | if boolean_in == False: sentence = self.translate_diaact(dia_act) 73 | return sentence 74 | 75 | 76 | def translate_diaact(self, dia_act): 77 | """ prepare the diaact into vector representation, and generate the sentence by Model """ 78 | 79 | word_dict = self.word_dict 80 | template_word_dict = self.template_word_dict 81 | act_dict = self.act_dict 82 | slot_dict = self.slot_dict 83 | inverse_word_dict = self.inverse_word_dict 84 | 85 | act_rep = np.zeros((1, len(act_dict))) 86 | act_rep[0, act_dict[dia_act['diaact']]] = 1.0 87 | 88 | slot_rep_bit = 2 89 | slot_rep = np.zeros((1, len(slot_dict)*slot_rep_bit)) 90 | 91 | suffix = "_PLACEHOLDER" 92 | if self.params['dia_slot_val'] == 2 or self.params['dia_slot_val'] == 3: 93 | word_rep = np.zeros((1, len(template_word_dict))) 94 | words = np.zeros((1, len(template_word_dict))) 95 | words[0, template_word_dict['s_o_s']] = 1.0 96 | else: 97 | word_rep = np.zeros((1, len(word_dict))) 98 | words = np.zeros((1, len(word_dict))) 99 | words[0, word_dict['s_o_s']] = 1.0 100 | 101 | for slot in dia_act['inform_slots'].keys(): 102 | slot_index = slot_dict[slot] 103 | slot_rep[0, slot_index*slot_rep_bit] = 1.0 104 | 105 | for slot_val in dia_act['inform_slots'][slot]: 106 | if self.params['dia_slot_val'] == 2: 107 | slot_placeholder = slot + suffix 108 | if slot_placeholder in template_word_dict.keys(): 109 | word_rep[0, template_word_dict[slot_placeholder]] = 1.0 110 | elif 
self.params['dia_slot_val'] == 1: 111 | if slot_val in word_dict.keys(): 112 | word_rep[0, word_dict[slot_val]] = 1.0 113 | 114 | for slot in dia_act['request_slots'].keys(): 115 | slot_index = slot_dict[slot] 116 | slot_rep[0, slot_index*slot_rep_bit + 1] = 1.0 117 | 118 | if self.params['dia_slot_val'] == 0 or self.params['dia_slot_val'] == 3: 119 | final_representation = np.hstack([act_rep, slot_rep]) 120 | else: # dia_slot_val = 1, 2 121 | final_representation = np.hstack([act_rep, slot_rep, word_rep]) 122 | 123 | dia_act_rep = {} 124 | dia_act_rep['diaact'] = final_representation 125 | dia_act_rep['words'] = words 126 | 127 | #pred_ys, pred_words = nlg_model['model'].forward(inverse_word_dict, dia_act_rep, nlg_model['params'], predict_model=True) 128 | pred_ys, pred_words = self.model.beam_forward(inverse_word_dict, dia_act_rep, self.params, predict_model=True) 129 | pred_sentence = ' '.join(pred_words[:-1]) 130 | sentence = self.post_process(pred_sentence, dia_act['inform_slots'], slot_dict) 131 | 132 | return sentence 133 | 134 | 135 | def load_nlg_model(self, model_path): 136 | """ load the trained NLG model """ 137 | 138 | model_params = pickle.load(open(model_path)) 139 | 140 | hidden_size = model_params['model']['Wd'].shape[0] 141 | output_size = model_params['model']['Wd'].shape[1] 142 | 143 | if model_params['params']['model'] == 'lstm_tanh': # lstm_tanh 144 | diaact_input_size = model_params['model']['Wah'].shape[0] 145 | input_size = model_params['model']['WLSTM'].shape[0] - hidden_size - 1 146 | rnnmodel = lstm_decoder_tanh(diaact_input_size, input_size, hidden_size, output_size) 147 | 148 | rnnmodel.model = copy.deepcopy(model_params['model']) 149 | model_params['params']['beam_size'] = dialog_config.nlg_beam_size 150 | 151 | self.model = rnnmodel 152 | self.word_dict = copy.deepcopy(model_params['word_dict']) 153 | self.template_word_dict = copy.deepcopy(model_params['template_word_dict']) 154 | self.slot_dict = 
copy.deepcopy(model_params['slot_dict']) 155 | self.act_dict = copy.deepcopy(model_params['act_dict']) 156 | self.inverse_word_dict = {self.template_word_dict[k]:k for k in self.template_word_dict.keys()} 157 | self.params = copy.deepcopy(model_params['params']) 158 | 159 | 160 | def diaact_to_nl_slot_filling(self, dia_act, template_sentence): 161 | """ Replace the slots with its values """ 162 | 163 | sentence = template_sentence 164 | counter = 0 165 | for slot in dia_act['inform_slots'].keys(): 166 | slot_val = dia_act['inform_slots'][slot] 167 | if slot_val == dialog_config.NO_VALUE_MATCH: 168 | sentence = slot + " is not available!" 169 | break 170 | elif slot_val == dialog_config.I_DO_NOT_CARE: 171 | counter += 1 172 | sentence = sentence.replace('$'+slot+'$', '', 1) 173 | continue 174 | 175 | sentence = sentence.replace('$'+slot+'$', slot_val, 1) 176 | 177 | if counter > 0 and counter == len(dia_act['inform_slots']): 178 | sentence = dialog_config.I_DO_NOT_CARE 179 | 180 | return sentence 181 | 182 | 183 | def load_predefine_act_nl_pairs(self, path): 184 | """ Load some pre-defined Dia_Act&NL Pairs from file """ 185 | 186 | self.diaact_nl_pairs = json.load(open(path, 'rb')) 187 | 188 | for key in self.diaact_nl_pairs['dia_acts'].keys(): 189 | for ele in self.diaact_nl_pairs['dia_acts'][key]: 190 | ele['nl']['usr'] = ele['nl']['usr'].encode('utf-8') # encode issue 191 | ele['nl']['agt'] = ele['nl']['agt'].encode('utf-8') # encode issue 192 | 193 | 194 | def main(params): 195 | pass 196 | 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | 201 | args = parser.parse_args() 202 | params = vars(args) 203 | 204 | print ("User Simulator Parameters:") 205 | print (json.dumps(params, indent=2)) 206 | 207 | main(params) 208 | -------------------------------------------------------------------------------- /src/deep_dialog/nlg/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created 
on Jun 13, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | import math 8 | import numpy as np 9 | 10 | 11 | def initWeights(n,d): 12 | """ Initialization Strategy """ 13 | #scale_factor = 0.1 14 | scale_factor = math.sqrt(float(6)/(n + d)) 15 | return (np.random.rand(n,d)*2-1)*scale_factor 16 | 17 | def mergeDicts(d0, d1): 18 | """ for all k in d0, d0 += d1 . d's are dictionaries of key -> numpy array """ 19 | for k in d1: 20 | if k in d0: d0[k] += d1[k] 21 | else: d0[k] = d1[k] -------------------------------------------------------------------------------- /src/deep_dialog/nlu/__init__.py: -------------------------------------------------------------------------------- 1 | from .nlu import nlu 2 | from .bi_lstm import biLSTM 3 | from .lstm import lstm -------------------------------------------------------------------------------- /src/deep_dialog/nlu/bi_lstm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 13, 2016 3 | 4 | An Bidirectional LSTM Seq2Seq model 5 | 6 | @author: xiul 7 | ''' 8 | 9 | from .seq_seq import SeqToSeq 10 | from .utils import * 11 | 12 | 13 | class biLSTM(SeqToSeq): 14 | def __init__(self, input_size, hidden_size, output_size): 15 | self.model = {} 16 | # Recurrent weights: take x_t, h_{t-1}, and bias unit, and produce the 3 gates and the input to cell signal 17 | self.model['WLSTM'] = initWeights(input_size + hidden_size + 1, 4*hidden_size) 18 | self.model['bWLSTM'] = initWeights(input_size + hidden_size + 1, 4*hidden_size) 19 | 20 | # Hidden-Output Connections 21 | self.model['Wd'] = initWeights(hidden_size, output_size)*0.1 22 | self.model['bd'] = np.zeros((1, output_size)) 23 | 24 | # Backward Hidden-Output Connections 25 | self.model['bWd'] = initWeights(hidden_size, output_size)*0.1 26 | self.model['bbd'] = np.zeros((1, output_size)) 27 | 28 | self.update = ['WLSTM', 'bWLSTM', 'Wd', 'bd', 'bWd', 'bbd'] 29 | self.regularize = ['WLSTM', 'bWLSTM', 'Wd', 'bWd'] 30 | 31 | 
self.step_cache = {} 32 | 33 | """ Activation Function: Sigmoid, or tanh, or ReLu """ 34 | def fwdPass(self, Xs, params, **kwargs): 35 | predict_mode = kwargs.get('predict_mode', False) 36 | 37 | Ws = Xs['word_vectors'] 38 | 39 | WLSTM = self.model['WLSTM'] 40 | bWLSTM = self.model['bWLSTM'] 41 | 42 | n, xd = Ws.shape 43 | 44 | d = self.model['Wd'].shape[0] # size of hidden layer 45 | Hin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias 46 | Hout = np.zeros((n, d)) 47 | IFOG = np.zeros((n, 4*d)) 48 | IFOGf = np.zeros((n, 4*d)) # after nonlinearity 49 | Cellin = np.zeros((n, d)) 50 | Cellout = np.zeros((n, d)) 51 | 52 | # backward 53 | bHin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias 54 | bHout = np.zeros((n, d)) 55 | bIFOG = np.zeros((n, 4*d)) 56 | bIFOGf = np.zeros((n, 4*d)) # after nonlinearity 57 | bCellin = np.zeros((n, d)) 58 | bCellout = np.zeros((n, d)) 59 | 60 | for t in xrange(n): 61 | prev = np.zeros(d) if t==0 else Hout[t-1] 62 | Hin[t,0] = 1 # bias 63 | Hin[t, 1:1+xd] = Ws[t] 64 | Hin[t, 1+xd:] = prev 65 | 66 | # compute all gate activations. 
dots: 67 | IFOG[t] = Hin[t].dot(WLSTM) 68 | 69 | IFOGf[t, :3*d] = 1/(1+np.exp(-IFOG[t, :3*d])) # sigmoids; these are three gates 70 | IFOGf[t, 3*d:] = np.tanh(IFOG[t, 3*d:]) # tanh for input value 71 | 72 | Cellin[t] = IFOGf[t, :d] * IFOGf[t, 3*d:] 73 | if t>0: Cellin[t] += IFOGf[t, d:2*d]*Cellin[t-1] 74 | 75 | Cellout[t] = np.tanh(Cellin[t]) 76 | Hout[t] = IFOGf[t, 2*d:3*d] * Cellout[t] 77 | 78 | # backward hidden layer 79 | b_t = n-1-t 80 | bprev = np.zeros(d) if t == 0 else bHout[b_t+1] 81 | bHin[b_t, 0] = 1 82 | bHin[b_t, 1:1+xd] = Ws[b_t] 83 | bHin[b_t, 1+xd:] = bprev 84 | 85 | bIFOG[b_t] = bHin[b_t].dot(bWLSTM) 86 | bIFOGf[b_t, :3*d] = 1/(1+np.exp(-bIFOG[b_t, :3*d])) 87 | bIFOGf[b_t, 3*d:] = np.tanh(bIFOG[b_t, 3*d:]) 88 | 89 | bCellin[b_t] = bIFOGf[b_t, :d] * bIFOGf[b_t, 3*d:] 90 | if t>0: bCellin[b_t] += bIFOGf[b_t, d:2*d] * bCellin[b_t+1] 91 | 92 | bCellout[b_t] = np.tanh(bCellin[b_t]) 93 | bHout[b_t] = bIFOGf[b_t, 2*d:3*d]*bCellout[b_t] 94 | 95 | Wd = self.model['Wd'] 96 | bd = self.model['bd'] 97 | fY = Hout.dot(Wd)+bd 98 | 99 | bWd = self.model['bWd'] 100 | bbd = self.model['bbd'] 101 | bY = bHout.dot(bWd)+bbd 102 | 103 | Y = fY + bY 104 | 105 | cache = {} 106 | if not predict_mode: 107 | cache['WLSTM'] = WLSTM 108 | cache['Hout'] = Hout 109 | cache['Wd'] = Wd 110 | cache['IFOGf'] = IFOGf 111 | cache['IFOG'] = IFOG 112 | cache['Cellin'] = Cellin 113 | cache['Cellout'] = Cellout 114 | cache['Hin'] = Hin 115 | 116 | cache['bWLSTM'] = bWLSTM 117 | cache['bHout'] = bHout 118 | cache['bWd'] = bWd 119 | cache['bIFOGf'] = bIFOGf 120 | cache['bIFOG'] = bIFOG 121 | cache['bCellin'] = bCellin 122 | cache['bCellout'] = bCellout 123 | cache['bHin'] = bHin 124 | 125 | cache['Ws'] = Ws 126 | 127 | return Y, cache 128 | 129 | """ Backward Pass """ 130 | def bwdPass(self, dY, cache): 131 | Wd = cache['Wd'] 132 | Hout = cache['Hout'] 133 | IFOG = cache['IFOG'] 134 | IFOGf = cache['IFOGf'] 135 | Cellin = cache['Cellin'] 136 | Cellout = cache['Cellout'] 137 | Hin = 
cache['Hin'] 138 | WLSTM = cache['WLSTM'] 139 | 140 | Ws = cache['Ws'] 141 | 142 | bWd = cache['bWd'] 143 | bHout = cache['bHout'] 144 | bIFOG = cache['bIFOG'] 145 | bIFOGf = cache['bIFOGf'] 146 | bCellin = cache['bCellin'] 147 | bCellout = cache['bCellout'] 148 | bHin = cache['bHin'] 149 | bWLSTM = cache['bWLSTM'] 150 | 151 | n,d = Hout.shape 152 | 153 | # backprop the hidden-output layer 154 | dWd = Hout.transpose().dot(dY) 155 | dbd = np.sum(dY, axis=0, keepdims = True) 156 | dHout = dY.dot(Wd.transpose()) 157 | 158 | # backprop the backward hidden-output layer 159 | dbWd = bHout.transpose().dot(dY) 160 | dbbd = np.sum(dY, axis=0, keepdims = True) 161 | dbHout = dY.dot(bWd.transpose()) 162 | 163 | # backprop the LSTM (forward layer) 164 | dIFOG = np.zeros(IFOG.shape) 165 | dIFOGf = np.zeros(IFOGf.shape) 166 | dWLSTM = np.zeros(WLSTM.shape) 167 | dHin = np.zeros(Hin.shape) 168 | dCellin = np.zeros(Cellin.shape) 169 | dCellout = np.zeros(Cellout.shape) 170 | 171 | # backward-layer 172 | dbIFOG = np.zeros(bIFOG.shape) 173 | dbIFOGf = np.zeros(bIFOGf.shape) 174 | dbWLSTM = np.zeros(bWLSTM.shape) 175 | dbHin = np.zeros(bHin.shape) 176 | dbCellin = np.zeros(bCellin.shape) 177 | dbCellout = np.zeros(bCellout.shape) 178 | 179 | for t in reversed(xrange(n)): 180 | dIFOGf[t,2*d:3*d] = Cellout[t] * dHout[t] 181 | dCellout[t] = IFOGf[t,2*d:3*d] * dHout[t] 182 | 183 | dCellin[t] += (1-Cellout[t]**2) * dCellout[t] 184 | 185 | if t>0: 186 | dIFOGf[t, d:2*d] = Cellin[t-1] * dCellin[t] 187 | dCellin[t-1] += IFOGf[t,d:2*d] * dCellin[t] 188 | 189 | dIFOGf[t, :d] = IFOGf[t,3*d:] * dCellin[t] 190 | dIFOGf[t,3*d:] = IFOGf[t, :d] * dCellin[t] 191 | 192 | # backprop activation functions 193 | dIFOG[t, 3*d:] = (1-IFOGf[t, 3*d:]**2) * dIFOGf[t, 3*d:] 194 | y = IFOGf[t, :3*d] 195 | dIFOG[t, :3*d] = (y*(1-y)) * dIFOGf[t, :3*d] 196 | 197 | # backprop matrix multiply 198 | dWLSTM += np.outer(Hin[t], dIFOG[t]) 199 | dHin[t] = dIFOG[t].dot(WLSTM.transpose()) 200 | 201 | if t>0: dHout[t-1] += 
dHin[t, 1+Ws.shape[1]:] 202 | 203 | # Backward Layer 204 | b_t = n-1-t 205 | dbIFOGf[b_t, 2*d:3*d] = bCellout[b_t] * dbHout[b_t] # output gate 206 | dbCellout[b_t] = bIFOGf[b_t, 2*d:3*d] * dbHout[b_t] # dCellout 207 | 208 | dbCellin[b_t] += (1-bCellout[b_t]**2) * dbCellout[b_t] 209 | 210 | if t>0: # dcell 211 | dbIFOGf[b_t, d:2*d] = bCellin[b_t+1] * dbCellin[b_t] # forgot gate 212 | dbCellin[b_t+1] += bIFOGf[b_t, d:2*d] * dbCellin[b_t] 213 | 214 | dbIFOGf[b_t, :d] = bIFOGf[b_t, 3*d:] * dbCellin[b_t] # input gate 215 | dbIFOGf[b_t, 3*d:] = bIFOGf[b_t, :d] * dbCellin[b_t] 216 | 217 | # backprop activation functions 218 | dbIFOG[b_t, 3*d:] = (1-bIFOGf[b_t, 3*d:]**2) * dbIFOGf[b_t, 3*d:] 219 | by = bIFOGf[b_t, :3*d] 220 | dbIFOG[b_t, :3*d] = (by*(1-by)) * dbIFOGf[b_t, :3*d] 221 | 222 | dbWLSTM += np.outer(bHin[b_t], dbIFOG[b_t]) 223 | dbHin[b_t] = dbIFOG[b_t].dot(bWLSTM.transpose()) 224 | 225 | if t>0: dbHout[b_t+1] += dbHin[b_t, 1+Ws.shape[1]:] 226 | 227 | return {'WLSTM':dWLSTM, 'Wd':dWd, 'bd':dbd, 'bWLSTM':dbWLSTM, 'bWd':dbWd, 'bbd':dbbd} -------------------------------------------------------------------------------- /src/deep_dialog/nlu/lstm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 13, 2016 3 | 4 | An LSTM decoder - add tanh after cell before output gate 5 | 6 | @author: xiul 7 | ''' 8 | 9 | from seq_seq import SeqToSeq 10 | from .utils import * 11 | 12 | 13 | class lstm(SeqToSeq): 14 | def __init__(self, input_size, hidden_size, output_size): 15 | self.model = {} 16 | # Recurrent weights: take x_t, h_{t-1}, and bias unit, and produce the 3 gates and the input to cell signal 17 | self.model['WLSTM'] = initWeights(input_size + hidden_size + 1, 4*hidden_size) 18 | # Hidden-Output Connections 19 | self.model['Wd'] = initWeights(hidden_size, output_size)*0.1 20 | self.model['bd'] = np.zeros((1, output_size)) 21 | 22 | self.update = ['WLSTM', 'Wd', 'bd'] 23 | self.regularize = ['WLSTM', 'Wd'] 24 | 
25 | self.step_cache = {} 26 | 27 | """ Activation Function: Sigmoid, or tanh, or ReLu """ 28 | def fwdPass(self, Xs, params, **kwargs): 29 | predict_mode = kwargs.get('predict_mode', False) 30 | 31 | Ws = Xs['word_vectors'] 32 | 33 | WLSTM = self.model['WLSTM'] 34 | n, xd = Ws.shape 35 | 36 | d = self.model['Wd'].shape[0] # size of hidden layer 37 | Hin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias 38 | Hout = np.zeros((n, d)) 39 | IFOG = np.zeros((n, 4*d)) 40 | IFOGf = np.zeros((n, 4*d)) # after nonlinearity 41 | Cellin = np.zeros((n, d)) 42 | Cellout = np.zeros((n, d)) 43 | 44 | for t in xrange(n): 45 | prev = np.zeros(d) if t==0 else Hout[t-1] 46 | Hin[t,0] = 1 # bias 47 | Hin[t, 1:1+xd] = Ws[t] 48 | Hin[t, 1+xd:] = prev 49 | 50 | # compute all gate activations. dots: 51 | IFOG[t] = Hin[t].dot(WLSTM) 52 | 53 | IFOGf[t, :3*d] = 1/(1+np.exp(-IFOG[t, :3*d])) # sigmoids; these are three gates 54 | IFOGf[t, 3*d:] = np.tanh(IFOG[t, 3*d:]) # tanh for input value 55 | 56 | Cellin[t] = IFOGf[t, :d] * IFOGf[t, 3*d:] 57 | if t>0: Cellin[t] += IFOGf[t, d:2*d]*Cellin[t-1] 58 | 59 | Cellout[t] = np.tanh(Cellin[t]) 60 | 61 | Hout[t] = IFOGf[t, 2*d:3*d] * Cellout[t] 62 | 63 | Wd = self.model['Wd'] 64 | bd = self.model['bd'] 65 | 66 | Y = Hout.dot(Wd)+bd 67 | 68 | cache = {} 69 | if not predict_mode: 70 | cache['WLSTM'] = WLSTM 71 | cache['Hout'] = Hout 72 | cache['Wd'] = Wd 73 | cache['IFOGf'] = IFOGf 74 | cache['IFOG'] = IFOG 75 | cache['Cellin'] = Cellin 76 | cache['Cellout'] = Cellout 77 | cache['Ws'] = Ws 78 | cache['Hin'] = Hin 79 | 80 | return Y, cache 81 | 82 | """ Backward Pass """ 83 | def bwdPass(self, dY, cache): 84 | Wd = cache['Wd'] 85 | Hout = cache['Hout'] 86 | IFOG = cache['IFOG'] 87 | IFOGf = cache['IFOGf'] 88 | Cellin = cache['Cellin'] 89 | Cellout = cache['Cellout'] 90 | Hin = cache['Hin'] 91 | WLSTM = cache['WLSTM'] 92 | Ws = cache['Ws'] 93 | 94 | n,d = Hout.shape 95 | 96 | # backprop the hidden-output layer 97 | dWd = Hout.transpose().dot(dY) 98 | dbd = 
np.sum(dY, axis=0, keepdims = True) 99 | dHout = dY.dot(Wd.transpose()) 100 | 101 | # backprop the LSTM 102 | dIFOG = np.zeros(IFOG.shape) 103 | dIFOGf = np.zeros(IFOGf.shape) 104 | dWLSTM = np.zeros(WLSTM.shape) 105 | dHin = np.zeros(Hin.shape) 106 | dCellin = np.zeros(Cellin.shape) 107 | dCellout = np.zeros(Cellout.shape) 108 | 109 | for t in reversed(xrange(n)): 110 | dIFOGf[t,2*d:3*d] = Cellout[t] * dHout[t] 111 | dCellout[t] = IFOGf[t,2*d:3*d] * dHout[t] 112 | 113 | dCellin[t] += (1-Cellout[t]**2) * dCellout[t] 114 | 115 | if t>0: 116 | dIFOGf[t, d:2*d] = Cellin[t-1] * dCellin[t] 117 | dCellin[t-1] += IFOGf[t,d:2*d] * dCellin[t] 118 | 119 | dIFOGf[t, :d] = IFOGf[t,3*d:] * dCellin[t] 120 | dIFOGf[t,3*d:] = IFOGf[t, :d] * dCellin[t] 121 | 122 | # backprop activation functions 123 | dIFOG[t, 3*d:] = (1-IFOGf[t, 3*d:]**2) * dIFOGf[t, 3*d:] 124 | y = IFOGf[t, :3*d] 125 | dIFOG[t, :3*d] = (y*(1-y)) * dIFOGf[t, :3*d] 126 | 127 | # backprop matrix multiply 128 | dWLSTM += np.outer(Hin[t], dIFOG[t]) 129 | dHin[t] = dIFOG[t].dot(WLSTM.transpose()) 130 | 131 | if t > 0: dHout[t-1] += dHin[t, 1+Ws.shape[1]:] 132 | 133 | #dXs = dXsh.dot(Wxh.transpose()) 134 | return {'WLSTM':dWLSTM, 'Wd':dWd, 'bd':dbd} -------------------------------------------------------------------------------- /src/deep_dialog/nlu/nlu.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jul 13, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | import cPickle as pickle 8 | import copy 9 | import numpy as np 10 | 11 | from lstm import lstm 12 | from bi_lstm import biLSTM 13 | 14 | 15 | class nlu: 16 | def __init__(self): 17 | pass 18 | 19 | def generate_dia_act(self, annot): 20 | """ generate the Dia-Act with NLU model """ 21 | 22 | if len(annot) > 0: 23 | tmp_annot = annot.strip('.').strip('?').strip(',').strip('!') 24 | 25 | rep = self.parse_str_to_vector(tmp_annot) 26 | Ys, cache = self.model.fwdPass(rep, self.params, predict_model=True) # default: True 
27 | 28 | maxes = np.amax(Ys, axis=1, keepdims=True) 29 | e = np.exp(Ys - maxes) # for numerical stability shift into good numerical range 30 | probs = e/np.sum(e, axis=1, keepdims=True) 31 | if np.all(np.isnan(probs)): probs = np.zeros(probs.shape) 32 | 33 | # special handling with intent label 34 | for tag_id in self.inverse_tag_dict.keys(): 35 | if self.inverse_tag_dict[tag_id].startswith('B-') or self.inverse_tag_dict[tag_id].startswith('I-') or self.inverse_tag_dict[tag_id] == 'O': 36 | probs[-1][tag_id] = 0 37 | 38 | pred_words_indices = np.nanargmax(probs, axis=1) 39 | pred_tags = [self.inverse_tag_dict[index] for index in pred_words_indices] 40 | 41 | diaact = self.parse_nlu_to_diaact(pred_tags, tmp_annot) 42 | return diaact 43 | else: 44 | return None 45 | 46 | 47 | def load_nlu_model(self, model_path): 48 | """ load the trained NLU model """ 49 | 50 | model_params = pickle.load(open(model_path)) 51 | # model_params = pickle.load(open(model_path, 'rb')) 52 | 53 | hidden_size = model_params['model']['Wd'].shape[0] 54 | output_size = model_params['model']['Wd'].shape[1] 55 | 56 | if model_params['params']['model'] == 'lstm': # lstm_ 57 | input_size = model_params['model']['WLSTM'].shape[0] - hidden_size - 1 58 | rnnmodel = lstm(input_size, hidden_size, output_size) 59 | elif model_params['params']['model'] == 'bi_lstm': # bi_lstm 60 | input_size = model_params['model']['WLSTM'].shape[0] - hidden_size - 1 61 | rnnmodel = biLSTM(input_size, hidden_size, output_size) 62 | 63 | rnnmodel.model = copy.deepcopy(model_params['model']) 64 | 65 | self.model = rnnmodel 66 | self.word_dict = copy.deepcopy(model_params['word_dict']) 67 | self.slot_dict = copy.deepcopy(model_params['slot_dict']) 68 | self.act_dict = copy.deepcopy(model_params['act_dict']) 69 | self.tag_set = copy.deepcopy(model_params['tag_set']) 70 | self.params = copy.deepcopy(model_params['params']) 71 | self.inverse_tag_dict = {self.tag_set[k]:k for k in self.tag_set.keys()} 72 | 73 | 74 | def 
parse_str_to_vector(self, string): 75 | """ Parse string into vector representations """ 76 | 77 | tmp = 'BOS ' + string + ' EOS' 78 | words = tmp.lower().split(' ') 79 | 80 | vecs = np.zeros((len(words), len(self.word_dict))) 81 | for w_index, w in enumerate(words): 82 | if w.endswith(',') or w.endswith('?'): w = w[0:-1] 83 | if w in self.word_dict.keys(): 84 | vecs[w_index][self.word_dict[w]] = 1 85 | else: vecs[w_index][self.word_dict['unk']] = 1 86 | 87 | rep = {} 88 | rep['word_vectors'] = vecs 89 | rep['raw_seq'] = string 90 | return rep 91 | 92 | def parse_nlu_to_diaact(self, nlu_vector, string): 93 | """ Parse BIO and Intent into Dia-Act """ 94 | 95 | tmp = 'BOS ' + string + ' EOS' 96 | words = tmp.lower().split(' ') 97 | 98 | diaact = {} 99 | diaact['diaact'] = "inform" 100 | diaact['request_slots'] = {} 101 | diaact['inform_slots'] = {} 102 | 103 | intent = nlu_vector[-1] 104 | index = 1 105 | pre_tag = nlu_vector[0] 106 | pre_tag_index = 0 107 | 108 | slot_val_dict = {} 109 | 110 | while index<(len(nlu_vector)-1): # except last Intent tag 111 | cur_tag = nlu_vector[index] 112 | if cur_tag == 'O' and pre_tag.startswith('B-'): 113 | slot = pre_tag.split('-')[1] 114 | slot_val_str = ' '.join(words[pre_tag_index:index]) 115 | slot_val_dict[slot] = slot_val_str 116 | elif cur_tag.startswith('B-') and pre_tag.startswith('B-'): 117 | slot = pre_tag.split('-')[1] 118 | slot_val_str = ' '.join(words[pre_tag_index:index]) 119 | slot_val_dict[slot] = slot_val_str 120 | elif cur_tag.startswith('B-') and pre_tag.startswith('I-'): 121 | if cur_tag.split('-')[1] != pre_tag.split('-')[1]: 122 | slot = pre_tag.split('-')[1] 123 | slot_val_str = ' '.join(words[pre_tag_index:index]) 124 | slot_val_dict[slot] = slot_val_str 125 | elif cur_tag == 'O' and pre_tag.startswith('I-'): 126 | slot = pre_tag.split('-')[1] 127 | slot_val_str = ' '.join(words[pre_tag_index:index]) 128 | slot_val_dict[slot] = slot_val_str 129 | 130 | if cur_tag.startswith('B-'): pre_tag_index = index 
131 | 132 | pre_tag = cur_tag 133 | index += 1 134 | 135 | if cur_tag.startswith('B-') or cur_tag.startswith('I-'): 136 | slot = cur_tag.split('-')[1] 137 | slot_val_str = ' '.join(words[pre_tag_index:-1]) 138 | slot_val_dict[slot] = slot_val_str 139 | 140 | if intent != 'null': 141 | arr = intent.split('+') 142 | diaact['diaact'] = arr[0] 143 | diaact['request_slots'] = {} 144 | for ele in arr[1:]: 145 | #request_slots.append(ele) 146 | diaact['request_slots'][ele] = 'UNK' 147 | 148 | diaact['inform_slots'] = slot_val_dict 149 | 150 | # add rule here 151 | for slot in diaact['inform_slots'].keys(): 152 | slot_val = diaact['inform_slots'][slot] 153 | if slot_val.startswith('bos'): 154 | slot_val = slot_val.replace('bos', '', 1) 155 | diaact['inform_slots'][slot] = slot_val.strip(' ') 156 | 157 | self.refine_diaact_by_rules(diaact) 158 | return diaact 159 | 160 | def refine_diaact_by_rules(self, diaact): 161 | """ refine the dia_act by rules """ 162 | 163 | # rule for taskcomplete 164 | if 'request_slots' in diaact.keys(): 165 | if 'taskcomplete' in diaact['request_slots'].keys(): 166 | del diaact['request_slots']['taskcomplete'] 167 | diaact['inform_slots']['taskcomplete'] = 'PLACEHOLDER' 168 | 169 | # rule for request 170 | if len(diaact['request_slots'])>0: diaact['diaact'] = 'request' 171 | 172 | if len(diaact['request_slots'])==0 and diaact['diaact'] == 'request': diaact['diaact'] = 'inform' 173 | 174 | 175 | 176 | 177 | def diaact_penny_string(self, dia_act): 178 | """ Convert the Dia-Act into penny string """ 179 | 180 | penny_str = "" 181 | penny_str = dia_act['diaact'] + "(" 182 | for slot in dia_act['request_slots'].keys(): 183 | penny_str += slot + ";" 184 | 185 | for slot in dia_act['inform_slots'].keys(): 186 | slot_val_str = slot + "=" 187 | if len(dia_act['inform_slots'][slot]) == 1: 188 | slot_val_str += dia_act['inform_slots'][slot][0] 189 | else: 190 | slot_val_str += "{" 191 | for slot_val in dia_act['inform_slots'][slot]: 192 | slot_val_str += 
slot_val + "#" 193 | slot_val_str = slot_val_str[:-1] 194 | slot_val_str += "}" 195 | penny_str += slot_val_str + ";" 196 | 197 | if penny_str[-1] == ";": penny_str = penny_str[:-1] 198 | penny_str += ")" 199 | return penny_str -------------------------------------------------------------------------------- /src/deep_dialog/nlu/seq_seq.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 13, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | from .utils import * 8 | import time, os 9 | 10 | 11 | class SeqToSeq: 12 | def __init__(self, input_size, hidden_size, output_size): 13 | pass 14 | 15 | def get_struct(self): 16 | return {'model': self.model, 'update': self.update, 'regularize': self.regularize} 17 | 18 | 19 | """ Forward Function""" 20 | def fwdPass(self, Xs, params, **kwargs): 21 | pass 22 | 23 | def bwdPass(self, dY, cache): 24 | pass 25 | 26 | 27 | """ Batch Forward & Backward Pass""" 28 | def batchForward(self, ds, batch, params, predict_mode = False): 29 | caches = [] 30 | Ys = [] 31 | for i,x in enumerate(batch): 32 | Y, out_cache = self.fwdPass(x, params, predict_mode = predict_mode) 33 | caches.append(out_cache) 34 | Ys.append(Y) 35 | 36 | # back up information for efficient backprop 37 | cache = {} 38 | if not predict_mode: 39 | cache['caches'] = caches 40 | 41 | return Ys, cache 42 | 43 | def batchBackward(self, dY, cache): 44 | caches = cache['caches'] 45 | grads = {} 46 | for i in xrange(len(caches)): 47 | single_cache = caches[i] 48 | local_grads = self.bwdPass(dY[i], single_cache) 49 | mergeDicts(grads, local_grads) # add up the gradients wrt model parameters 50 | 51 | return grads 52 | 53 | 54 | """ Cost function, returns cost and gradients for model """ 55 | def costFunc(self, ds, batch, params): 56 | regc = params['reg_cost'] # regularization cost 57 | 58 | # batch forward RNN 59 | Ys, caches = self.batchForward(ds, batch, params, predict_mode = False) 60 | 61 | loss_cost = 0.0 62 | 
smooth_cost = 1e-15 63 | dYs = [] 64 | 65 | for i,x in enumerate(batch): 66 | labels = np.array(x['tags_rep'], dtype=int) 67 | 68 | # fetch the predicted probabilities 69 | Y = Ys[i] 70 | maxes = np.amax(Y, axis=1, keepdims=True) 71 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 72 | P = e/np.sum(e, axis=1, keepdims=True) 73 | 74 | # Cross-Entropy Cross Function 75 | loss_cost += -np.sum(np.log(smooth_cost + P[range(len(labels)), labels])) 76 | 77 | for iy,y in enumerate(labels): 78 | P[iy,y] -= 1 # softmax derivatives 79 | dYs.append(P) 80 | 81 | # backprop the RNN 82 | grads = self.batchBackward(dYs, caches) 83 | 84 | # add L2 regularization cost and gradients 85 | reg_cost = 0.0 86 | if regc > 0: 87 | for p in self.regularize: 88 | mat = self.model[p] 89 | reg_cost += 0.5*regc*np.sum(mat*mat) 90 | grads[p] += regc*mat 91 | 92 | # normalize the cost and gradient by the batch size 93 | batch_size = len(batch) 94 | reg_cost /= batch_size 95 | loss_cost /= batch_size 96 | for k in grads: grads[k] /= batch_size 97 | 98 | out = {} 99 | out['cost'] = {'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost} 100 | out['grads'] = grads 101 | return out 102 | 103 | 104 | """ A single batch """ 105 | def singleBatch(self, ds, batch, params): 106 | learning_rate = params.get('learning_rate', 0.0) 107 | decay_rate = params.get('decay_rate', 0.999) 108 | momentum = params.get('momentum', 0) 109 | grad_clip = params.get('grad_clip', 1) 110 | smooth_eps = params.get('smooth_eps', 1e-8) 111 | sdg_type = params.get('sdgtype', 'rmsprop') 112 | 113 | for u in self.update: 114 | if not u in self.step_cache: 115 | self.step_cache[u] = np.zeros(self.model[u].shape) 116 | 117 | cg = self.costFunc(ds, batch, params) 118 | 119 | cost = cg['cost'] 120 | grads = cg['grads'] 121 | 122 | # clip gradients if needed 123 | if params['activation_func'] == 'relu': 124 | if grad_clip > 0: 125 | for p in self.update: 126 | if p in 
grads: 127 | grads[p] = np.minimum(grads[p], grad_clip) 128 | grads[p] = np.maximum(grads[p], -grad_clip) 129 | 130 | # perform parameter update 131 | for p in self.update: 132 | if p in grads: 133 | if sdg_type == 'vanilla': 134 | if momentum > 0: dx = momentum*self.step_cache[p] - learning_rate*grads[p] 135 | else: dx = -learning_rate*grads[p] 136 | self.step_cache[p] = dx 137 | elif sdg_type == 'rmsprop': 138 | self.step_cache[p] = self.step_cache[p]*decay_rate + (1.0-decay_rate)*grads[p]**2 139 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 140 | elif sdg_type == 'adgrad': 141 | self.step_cache[p] += grads[p]**2 142 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 143 | 144 | self.model[p] += dx 145 | 146 | # create output dict and return 147 | out = {} 148 | out['cost'] = cost 149 | return out 150 | 151 | 152 | """ Evaluate on the dataset[split] """ 153 | def eval(self, ds, split, params): 154 | acc = 0 155 | total = 0 156 | 157 | total_cost = 0.0 158 | smooth_cost = 1e-15 159 | 160 | if split == 'test': 161 | res_filename = 'res_%s_[%s].txt' % (params['model'], time.time()) 162 | res_filepath = os.path.join(params['test_res_dir'], res_filename) 163 | res = open(res_filepath, 'w') 164 | inverse_tag_dict = {ds.data['tag_set'][k]:k for k in ds.data['tag_set'].keys()} 165 | 166 | for i, ele in enumerate(ds.split[split]): 167 | Ys, cache = self.fwdPass(ele, params, predict_model=True) 168 | 169 | maxes = np.amax(Ys, axis=1, keepdims=True) 170 | e = np.exp(Ys - maxes) # for numerical stability shift into good numerical range 171 | probs = e/np.sum(e, axis=1, keepdims=True) 172 | 173 | labels = np.array(ele['tags_rep'], dtype=int) 174 | 175 | if np.all(np.isnan(probs)): probs = np.zeros(probs.shape) 176 | 177 | loss_cost = 0 178 | loss_cost += -np.sum(np.log(smooth_cost + probs[range(len(labels)), labels])) 179 | total_cost += loss_cost 180 | 181 | pred_words_indices = np.nanargmax(probs, axis=1) 182 | 183 | tokens 
= ele['raw_seq'] 184 | real_tags = ele['tag_seq'] 185 | for index, l in enumerate(labels): 186 | if pred_words_indices[index] == l: acc += 1 187 | 188 | if split == 'test': 189 | res.write('%s %s %s %s\n' % (tokens[index], 'NA', real_tags[index], inverse_tag_dict[pred_words_indices[index]])) 190 | if split == 'test': res.write('\n') 191 | total += len(labels) 192 | 193 | total_cost /= len(ds.split[split]) 194 | accuracy = 0 if total == 0 else float(acc)/total 195 | 196 | #print ("total_cost: %s, accuracy: %s" % (total_cost, accuracy)) 197 | result = {'cost': total_cost, 'accuracy': accuracy} 198 | return result -------------------------------------------------------------------------------- /src/deep_dialog/nlu/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 13, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | import math 8 | import numpy as np 9 | 10 | 11 | def initWeights(n,d): 12 | """ Initialization Strategy """ 13 | #scale_factor = 0.1 14 | scale_factor = math.sqrt(float(6)/(n + d)) 15 | return (np.random.rand(n,d)*2-1)*scale_factor 16 | 17 | def mergeDicts(d0, d1): 18 | """ for all k in d0, d0 += d1 . 
d's are dictionaries of key -> numpy array """ 19 | for k in d1: 20 | if k in d0: d0[k] += d1[k] 21 | else: d0[k] = d1[k] -------------------------------------------------------------------------------- /src/deep_dialog/qlearning/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .dqn_torch import * -------------------------------------------------------------------------------- /src/deep_dialog/qlearning/dqn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 18, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | from .utils import * 8 | 9 | 10 | class DQN: 11 | 12 | def __init__(self, input_size, hidden_size, output_size): 13 | self.model = {} 14 | # input-hidden 15 | self.model['Wxh'] = initWeight(input_size, hidden_size) 16 | self.model['bxh'] = np.zeros((1, hidden_size)) 17 | 18 | # hidden-output 19 | self.model['Wd'] = initWeight(hidden_size, output_size)*0.1 20 | self.model['bd'] = np.zeros((1, output_size)) 21 | 22 | self.update = ['Wxh', 'bxh', 'Wd', 'bd'] 23 | self.regularize = ['Wxh', 'Wd'] 24 | 25 | self.step_cache = {} 26 | 27 | 28 | def getStruct(self): 29 | return {'model': self.model, 'update': self.update, 'regularize': self.regularize} 30 | 31 | 32 | """Activation Function: Sigmoid, or tanh, or ReLu""" 33 | def fwdPass(self, Xs, params, **kwargs): 34 | predict_mode = kwargs.get('predict_mode', False) 35 | active_func = params.get('activation_func', 'relu') 36 | 37 | # input layer to hidden layer 38 | Wxh = self.model['Wxh'] 39 | bxh = self.model['bxh'] 40 | Xsh = Xs.dot(Wxh) + bxh 41 | 42 | hidden_size = self.model['Wd'].shape[0] # size of hidden layer 43 | H = np.zeros((1, hidden_size)) # hidden layer representation 44 | 45 | if active_func == 'sigmoid': 46 | H = 1/(1+np.exp(-Xsh)) 47 | elif active_func == 'tanh': 48 | H = np.tanh(Xsh) 49 | elif active_func == 'relu': # ReLU 50 | H = np.maximum(Xsh, 0) 51 | else: # no 
activation function 52 | H = Xsh 53 | 54 | # decoder at the end; hidden layer to output layer 55 | Wd = self.model['Wd'] 56 | bd = self.model['bd'] 57 | Y = H.dot(Wd) + bd 58 | 59 | # cache the values in forward pass, we expect to do a backward pass 60 | cache = {} 61 | if not predict_mode: 62 | cache['Wxh'] = Wxh 63 | cache['Wd'] = Wd 64 | cache['Xs'] = Xs 65 | cache['Xsh'] = Xsh 66 | cache['H'] = H 67 | 68 | cache['bxh'] = bxh 69 | cache['bd'] = bd 70 | cache['activation_func'] = active_func 71 | 72 | cache['Y'] = Y 73 | 74 | return Y, cache 75 | 76 | def bwdPass(self, dY, cache): 77 | Wd = cache['Wd'] 78 | H = cache['H'] 79 | Xs = cache['Xs'] 80 | Xsh = cache['Xsh'] 81 | Wxh = cache['Wxh'] 82 | 83 | active_func = cache['activation_func'] 84 | n,d = H.shape 85 | 86 | dH = dY.dot(Wd.transpose()) 87 | # backprop the decoder 88 | dWd = H.transpose().dot(dY) 89 | dbd = np.sum(dY, axis=0, keepdims=True) 90 | 91 | dXsh = np.zeros(Xsh.shape) 92 | dXs = np.zeros(Xs.shape) 93 | 94 | if active_func == 'sigmoid': 95 | dH = (H-H**2)*dH 96 | elif active_func == 'tanh': 97 | dH = (1-H**2)*dH 98 | elif active_func == 'relu': 99 | dH = (H>0)*dH # backprop ReLU 100 | else: 101 | dH = dH 102 | 103 | # backprop to the input-hidden connection 104 | dWxh = Xs.transpose().dot(dH) 105 | dbxh = np.sum(dH, axis=0, keepdims = True) 106 | 107 | # backprop to the input 108 | dXsh = dH 109 | dXs = dXsh.dot(Wxh.transpose()) 110 | 111 | return {'Wd': dWd, 'bd': dbd, 'Wxh':dWxh, 'bxh':dbxh} 112 | 113 | 114 | """batch Forward & Backward Pass""" 115 | def batchForward(self, batch, params, predict_mode = False): 116 | caches = [] 117 | Ys = [] 118 | for i,x in enumerate(batch): 119 | Xs = np.array([x['cur_states']], dtype=float) 120 | 121 | Y, out_cache = self.fwdPass(Xs, params, predict_mode = predict_mode) 122 | caches.append(out_cache) 123 | Ys.append(Y) 124 | 125 | # back up information for efficient backprop 126 | cache = {} 127 | if not predict_mode: 128 | cache['caches'] = caches 129 | 130 
| return Ys, cache 131 | 132 | def batchDoubleForward(self, batch, params, clone_dqn, predict_mode = False): 133 | caches = [] 134 | Ys = [] 135 | tYs = [] 136 | 137 | for i,x in enumerate(batch): 138 | Xs = x[0] 139 | Y, out_cache = self.fwdPass(Xs, params, predict_mode = predict_mode) 140 | caches.append(out_cache) 141 | Ys.append(Y) 142 | 143 | tXs = x[3] 144 | tY, t_cache = clone_dqn.fwdPass(tXs, params, predict_mode = False) 145 | 146 | tYs.append(tY) 147 | 148 | # back up information for efficient backprop 149 | cache = {} 150 | if not predict_mode: 151 | cache['caches'] = caches 152 | 153 | return Ys, cache, tYs 154 | 155 | def batchBackward(self, dY, cache): 156 | caches = cache['caches'] 157 | 158 | grads = {} 159 | for i in xrange(len(caches)): 160 | single_cache = caches[i] 161 | local_grads = self.bwdPass(dY[i], single_cache) 162 | mergeDicts(grads, local_grads) # add up the gradients wrt model parameters 163 | 164 | return grads 165 | 166 | 167 | """ cost function, returns cost and gradients for model """ 168 | def costFunc(self, batch, params, clone_dqn): 169 | regc = params.get('reg_cost', 1e-3) 170 | gamma = params.get('gamma', 0.9) 171 | 172 | # batch forward 173 | Ys, caches, tYs = self.batchDoubleForward(batch, params, clone_dqn, predict_mode = False) 174 | 175 | loss_cost = 0.0 176 | dYs = [] 177 | for i,x in enumerate(batch): 178 | Y = Ys[i] 179 | nY = tYs[i] 180 | 181 | action = np.array(x[1], dtype=int) 182 | reward = np.array(x[2], dtype=float) 183 | 184 | n_action = np.nanargmax(nY[0]) 185 | max_next_y = nY[0][n_action] 186 | 187 | eposide_terminate = x[4] 188 | 189 | target_y = reward 190 | if eposide_terminate != True: target_y += gamma*max_next_y 191 | 192 | pred_y = Y[0][action] 193 | 194 | nY = np.zeros(nY.shape) 195 | nY[0][action] = target_y 196 | Y = np.zeros(Y.shape) 197 | Y[0][action] = pred_y 198 | 199 | # Cost Function 200 | loss_cost += (target_y - pred_y)**2 201 | 202 | dY = -(nY - Y) 203 | #dY = np.minimum(dY, 1) 204 | #dY = 
np.maximum(dY, -1) 205 | dYs.append(dY) 206 | 207 | # backprop the RNN 208 | grads = self.batchBackward(dYs, caches) 209 | 210 | # add L2 regularization cost and gradients 211 | reg_cost = 0.0 212 | if regc > 0: 213 | for p in self.regularize: 214 | mat = self.model[p] 215 | reg_cost += 0.5*regc*np.sum(mat*mat) 216 | grads[p] += regc*mat 217 | 218 | # normalize the cost and gradient by the batch size 219 | batch_size = len(batch) 220 | reg_cost /= batch_size 221 | loss_cost /= batch_size 222 | for k in grads: grads[k] /= batch_size 223 | 224 | out = {} 225 | out['cost'] = {'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost} 226 | out['grads'] = grads 227 | return out 228 | 229 | 230 | """ A single batch """ 231 | def singleBatch(self, batch, params, clone_dqn): 232 | learning_rate = params.get('learning_rate', 0.001) 233 | decay_rate = params.get('decay_rate', 0.999) 234 | momentum = params.get('momentum', 0.1) 235 | grad_clip = params.get('grad_clip', -1e-3) 236 | smooth_eps = params.get('smooth_eps', 1e-8) 237 | sdg_type = params.get('sdgtype', 'rmsprop') 238 | activation_func = params.get('activation_func', 'relu') 239 | 240 | for u in self.update: 241 | if not u in self.step_cache: 242 | self.step_cache[u] = np.zeros(self.model[u].shape) 243 | 244 | cg = self.costFunc(batch, params, clone_dqn) 245 | 246 | cost = cg['cost'] 247 | grads = cg['grads'] 248 | 249 | # clip gradients if needed 250 | if activation_func.lower() == 'relu': 251 | if grad_clip > 0: 252 | for p in self.update: 253 | if p in grads: 254 | grads[p] = np.minimum(grads[p], grad_clip) 255 | grads[p] = np.maximum(grads[p], -grad_clip) 256 | 257 | # perform parameter update 258 | for p in self.update: 259 | if p in grads: 260 | if sdg_type == 'vanilla': 261 | if momentum > 0: 262 | dx = momentum*self.step_cache[p] - learning_rate*grads[p] 263 | else: 264 | dx = -learning_rate*grads[p] 265 | self.step_cache[p] = dx 266 | elif sdg_type == 'rmsprop': 267 | 
self.step_cache[p] = self.step_cache[p]*decay_rate + (1.0-decay_rate)*grads[p]**2 268 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 269 | elif sdg_type == 'adgrad': 270 | self.step_cache[p] += grads[p]**2 271 | dx = -(learning_rate*grads[p])/np.sqrt(self.step_cache[p] + smooth_eps) 272 | 273 | self.model[p] += dx 274 | 275 | out = {} 276 | out['cost'] = cost 277 | return out 278 | 279 | """ prediction """ 280 | def predict(self, Xs, params, **kwargs): 281 | Ys, caches = self.fwdPass(Xs, params, predict_model=True) 282 | pred_action = np.argmax(Ys) 283 | 284 | return pred_action 285 | -------------------------------------------------------------------------------- /src/deep_dialog/qlearning/dqn_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class DQN(nn.Module): 10 | def __init__(self, input_size, hidden_size, output_size): 11 | super(DQN, self).__init__() 12 | 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.output_size = output_size 16 | 17 | self.linear_i2h = nn.Linear(self.input_size, self.hidden_size) 18 | self.linear_h2o = nn.Linear(self.hidden_size, self.output_size) 19 | 20 | def forward(self, x): 21 | x = F.tanh(self.linear_i2h(x)) 22 | x = self.linear_h2o(x) 23 | return x 24 | 25 | def predict(self, x): 26 | y = self.forward(x) 27 | return torch.argmax(y, 1) 28 | 29 | -------------------------------------------------------------------------------- /src/deep_dialog/qlearning/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jun 18, 2016 3 | 4 | @author: xiul 5 | ''' 6 | 7 | import numpy as np 8 | import math 9 | 10 | 11 | def initWeight(n,d): 12 | scale_factor = math.sqrt(float(6)/(n + d)) 13 | #scale_factor = 0.1 14 | return 
(np.random.rand(n,d)*2-1)*scale_factor 15 | 16 | """ for all k in d0, d0 += d1 . d's are dictionaries of key -> numpy array """ 17 | def mergeDicts(d0, d1): 18 | for k in d1: 19 | if k in d0: 20 | d0[k] += d1[k] 21 | else: 22 | d0[k] = d1[k] -------------------------------------------------------------------------------- /src/deep_dialog/usersims/__init__.py: -------------------------------------------------------------------------------- 1 | from .usersim_rule import * 2 | from .usersim_model import * -------------------------------------------------------------------------------- /src/deep_dialog/usersims/user_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class SimulatorModel(nn.Module): 10 | def __init__(self, 11 | agent_action_size, 12 | hidden_size, 13 | state_size, 14 | user_action_size, 15 | reward_size=1, 16 | termination_size=1): 17 | super(SimulatorModel, self).__init__() 18 | 19 | self.linear_i2h = nn.Linear(state_size, hidden_size) 20 | self.agent_emb = nn.Embedding(agent_action_size, hidden_size) 21 | self.linear_h2r = nn.Linear(hidden_size, reward_size) 22 | self.linear_h2t = nn.Linear(hidden_size, termination_size) 23 | self.linear_h2a = nn.Linear(hidden_size, user_action_size) 24 | 25 | def forward(self, s, a): 26 | h_s = self.linear_i2h(s) 27 | h_a = self.agent_emb(a).squeeze(1) 28 | h = F.tanh(h_s + h_a) 29 | 30 | reward = self.linear_h2r(h) 31 | term = self.linear_h2t(h) 32 | action = F.log_softmax(self.linear_h2a(h), 1) 33 | 34 | return reward, term, action 35 | 36 | def predict(self, s, a): 37 | h_s = self.linear_i2h(s) 38 | h_a = self.agent_emb(a).squeeze(1) 39 | h = F.tanh(h_s + h_a) 40 | 41 | reward = self.linear_h2r(h) 42 | term = F.sigmoid(self.linear_h2t(h)) 43 | action = F.log_softmax(self.linear_h2a(h), 1) 44 | 45 | return 
reward, term, action.argmax(1) 46 | -------------------------------------------------------------------------------- /src/deep_dialog/usersims/usersim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on June 7, 2016 3 | 4 | a rule-based user simulator 5 | 6 | @author: xiul, t-zalipt 7 | """ 8 | 9 | import random 10 | 11 | 12 | class UserSimulator: 13 | """ Parent class for all user sims to inherit from """ 14 | 15 | def __init__(self, movie_dict=None, act_set=None, slot_set=None, start_set=None, params=None): 16 | """ Constructor shared by all user simulators """ 17 | 18 | self.movie_dict = movie_dict 19 | self.act_set = act_set 20 | self.slot_set = slot_set 21 | self.start_set = start_set 22 | 23 | self.max_turn = params['max_turn'] 24 | self.slot_err_probability = params['slot_err_probability'] 25 | self.slot_err_mode = params['slot_err_mode'] 26 | self.intent_err_probability = params['intent_err_probability'] 27 | 28 | def initialize_episode(self): 29 | """ Initialize a new episode (dialog)""" 30 | 31 | print "initialize episode called, generating goal" 32 | self.goal = random.choice(self.start_set) 33 | self.goal['request_slots']['ticket'] = 'UNK' 34 | episode_over, user_action = self._sample_action() 35 | assert (episode_over != 1), ' but we just started' 36 | return user_action 37 | 38 | def next(self, system_action, *argv): 39 | pass 40 | 41 | def set_nlg_model(self, nlg_model): 42 | self.nlg_model = nlg_model 43 | 44 | def set_nlu_model(self, nlu_model): 45 | self.nlu_model = nlu_model 46 | 47 | def add_nl_to_action(self, user_action): 48 | """ Add NL to User Dia_Act """ 49 | 50 | user_nlg_sentence = self.nlg_model.convert_diaact_to_nl(user_action, 'usr') 51 | user_action['nl'] = user_nlg_sentence 52 | 53 | if self.simulator_act_level == 1: 54 | user_nlu_res = self.nlu_model.generate_dia_act(user_action['nl']) # NLU 55 | if user_nlu_res != None: 56 | # user_nlu_res['diaact'] = user_action['diaact'] # 
or not? 57 | user_action.update(user_nlu_res) 58 | -------------------------------------------------------------------------------- /src/draw_learning_curve.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov 3, 2016 3 | 4 | draw a learning curve 5 | 6 | @author: xiul 7 | ''' 8 | 9 | import argparse, json 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def read_performance_records(path): 14 | """ load the performance score (.json) file """ 15 | 16 | data = json.load(open(path, 'rb')) 17 | for key in data['success_rate'].keys(): 18 | if int(key) > -1: 19 | print("%s\t%s\t%s\t%s" % (key, data['success_rate'][key], data['ave_turns'][key], data['ave_reward'][key])) 20 | 21 | 22 | def load_performance_file(path): 23 | """ load the performance score (.json) file """ 24 | 25 | data = json.load(open(path, 'rb')) 26 | numbers = {'x': [], 'success_rate':[], 'ave_turns':[], 'ave_rewards':[]} 27 | keylist = [int(key) for key in data['success_rate'].keys()] 28 | keylist.sort() 29 | 30 | for key in keylist: 31 | if int(key) > -1: 32 | numbers['x'].append(int(key)) 33 | numbers['success_rate'].append(data['success_rate'][str(key)]) 34 | numbers['ave_turns'].append(data['ave_turns'][str(key)]) 35 | numbers['ave_rewards'].append(data['ave_reward'][str(key)]) 36 | return numbers 37 | 38 | def draw_learning_curve(numbers): 39 | """ draw the learning curve """ 40 | 41 | plt.xlabel('Simulation Epoch') 42 | plt.ylabel('Success Rate') 43 | plt.title('Learning Curve') 44 | plt.grid(True) 45 | 46 | plt.plot(numbers['x'], numbers['success_rate'], 'r', lw=1) 47 | plt.show() 48 | 49 | 50 | 51 | def main(params): 52 | cmd = params['cmd'] 53 | 54 | if cmd == 0: 55 | numbers = load_performance_file(params['result_file']) 56 | draw_learning_curve(numbers) 57 | elif cmd == 1: 58 | read_performance_records(params['result_file']) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | 64 | 
parser.add_argument('--cmd', dest='cmd', type=int, default=1, help='cmd') 65 | 66 | parser.add_argument('--result_file', dest='result_file', type=str, default='./deep_dialog/checkpoints/rl_agent/11142016/noe2e/agt_9_performance_records.json', help='path to the result file') 67 | 68 | args = parser.parse_args() 69 | params = vars(args) 70 | print json.dumps(params, indent=2) 71 | 72 | main(params) --------------------------------------------------------------------------------