├── src ├── __init__.py ├── classifier │ ├── __init__.py │ ├── run │ │ ├── __init__.py │ │ └── run_cla.py │ ├── symptom_as_feature │ │ ├── __init__.py │ │ └── svm_kliao.py │ └── self_report_as_feature │ │ ├── __init__.py │ │ └── report_classifier.py ├── dialogue_system │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── symptom_dist.pdf │ │ ├── goal_to_slot.py │ │ ├── plot_slot_distribution.py │ │ ├── slot_distribution.py │ │ ├── plot_slot_dist.py │ │ ├── goal_action_slots_dumper.py │ │ ├── plot_single_dist.py │ │ ├── draw_curve_each.py │ │ ├── plot_curve_each.py │ │ └── draw_curve_std.py │ ├── policy_learning │ │ ├── __init__.py │ │ ├── dqn_with_goal_joint.py │ │ └── dqn_with_goal.py │ ├── run │ │ ├── __init__.py │ │ └── utils.py │ ├── state_tracker │ │ ├── __init__.py │ │ └── state_tracker.py │ ├── user_simulator │ │ ├── __init__.py │ │ └── user_rule.py │ ├── dialogue_manager │ │ └── __init__.py │ ├── memory │ │ ├── __init__.py │ │ ├── util.py │ │ ├── base.py │ │ ├── prioritized.py │ │ ├── onpolicy.py │ │ └── replay.py │ ├── agent │ │ ├── __init__.py │ │ ├── agent_random.py │ │ ├── prioritized_new.py │ │ ├── agent_with_goal_joint.py │ │ └── agent_rule.py │ ├── dialogue_configuration.py │ └── disease_classifier.py └── dqn_gym.py ├── preprocess ├── __init__.py ├── label │ ├── __init__.py │ ├── get_slot_from_goal.py │ ├── svm_class.py │ ├── frequency.py │ └── preprocess_label.py ├── README.txt ├── match_disease.py ├── symptom_liking.py ├── run_pre.py ├── extract_symptom.py ├── top_disease.py ├── statistics.py ├── kliao │ └── svm_class.py └── aligned_symptoms_extracting.py ├── log ├── slot_set.p ├── IOHandler.py └── utils.py ├── .gitignore ├── README.md ├── draw_finals.py └── draw_distribution.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/label/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/classifier/run/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dialogue_system/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dialogue_system/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dialogue_system/policy_learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dialogue_system/run/__init__.py: -------------------------------------------------------------------------------- 1 | from .running_steward import * -------------------------------------------------------------------------------- 
/src/dialogue_system/state_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .state_tracker import * -------------------------------------------------------------------------------- /src/classifier/symptom_as_feature/__init__.py: -------------------------------------------------------------------------------- 1 | from .symptom_classifier import * -------------------------------------------------------------------------------- /src/classifier/self_report_as_feature/__init__.py: -------------------------------------------------------------------------------- 1 | from .report_classifier import * -------------------------------------------------------------------------------- /log/slot_set.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnbay/MeicalChatbot-HRL/HEAD/log/slot_set.p -------------------------------------------------------------------------------- /src/dialogue_system/user_simulator/__init__.py: -------------------------------------------------------------------------------- 1 | from .user import * 2 | from .user_rule import * -------------------------------------------------------------------------------- /src/dialogue_system/dialogue_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .dialogue_manager import * 2 | from .dialogue_manager_hrl import * 3 | -------------------------------------------------------------------------------- /src/dialogue_system/utils/symptom_dist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnbay/MeicalChatbot-HRL/HEAD/src/dialogue_system/utils/symptom_dist.pdf -------------------------------------------------------------------------------- /preprocess/README.txt: -------------------------------------------------------------------------------- 1 | Data preprocessing: build the user goal file that drives the dialogues from the raw files. It consists of the following steps. 2 | Prerequisite: the top diseases must be curated manually beforehand, together with the mapping between spoken symptom expressions and normalized symptoms for those diseases, i.e. the top_disease_symptom_aligned.json 3 | file. 4 | 1. Run top_disease.py, which defines the names of the diseases to be extracted; 5 | 2. Run match_disease.py to extract the self-report contents and corresponding symptoms of the top diseases from the file that contains self-reported symptoms. 6 | 3. Run extract_symptom.py to extract disease symptoms from the self-report texts and from the conversation contents respectively; all extracted symptoms are in spoken form; -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | resources/* 3 | others/ 4 | src/dialogue_system/data/* 5 | src/dialogue_system/model/* 6 | log/train/* 7 | **/*.logs 8 | **/*.log 9 | **/*.py[co] 10 | ./log 11 | 12 | test/* 13 | 14 | **/**/**/*.py[co] 15 | **/__pycache__/ 16 | **/**/__pycache__/ 17 | **/.DS_Store 18 | **/**/.DS_Store 19 | **/**/**/.DS_Store 20 | **/**/**/**/.DS_Store 21 | -------------------------------------------------------------------------------- /src/dialogue_system/memory/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified by Microsoft Corporation. 2 | # Licensed under the MIT license.
3 | 4 | ''' 5 | The memory module 6 | Contains different ways of storing an agent's experiences and sampling from them 7 | ''' 8 | 9 | from .onpolicy import * 10 | from .prioritized import * 11 | # expose all the classes 12 | from .replay import * 13 | -------------------------------------------------------------------------------- /src/dialogue_system/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_random import * 2 | from .agent_dqn import * 3 | from .agent import * 4 | from .agent_rule import * 5 | from .agent_hrl import * 6 | from .agent_with_goal_joint import * 7 | from .agent_with_goal import * 8 | from .agent_hrl_new import * 9 | from .agent_hrl_new2 import * 10 | from .agent_hrl_joint import * 11 | from .agent_hrl_joint2 import * -------------------------------------------------------------------------------- /src/dialogue_system/user_simulator/user_rule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | User simulator which is based on rules. 4 | """ 5 | 6 | import sys, os 7 | sys.path.append(os.getcwd().replace("src/dialogue_system/user_simulator","")) 8 | from src.dialogue_system.user_simulator.user import User 9 | 10 | class UserRule(User): 11 | def __init__(self, goal_set, disease_symptom, parameter): 12 | super(UserRule,self).__init__(goal_set=goal_set, 13 | disease_symptom=disease_symptom, 14 | parameter=parameter) -------------------------------------------------------------------------------- /src/dialogue_system/memory/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | import operator 4 | import pydash as ps 5 | 6 | 7 | def batch_get(arr, idxs): 8 | '''Get multiple indexes from an array, depending on whether it's a python list or np.array''' 9 | if isinstance(arr, (list, deque)): 10 | return np.array(operator.itemgetter(*idxs)(arr)) 11 | else: 12 | return arr[idxs] 13 | 14 | def set_attr(obj, attr_dict, keys=None): 15 | '''Set attributes of an object from a dict''' 16 | if keys is not None: 17 | attr_dict = ps.pick(attr_dict, keys) 18 | for attr, val in attr_dict.items(): 19 | setattr(obj, attr, val) 20 | return obj -------------------------------------------------------------------------------- /log/IOHandler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | """ 3 | A helper class for file read and write operations. 4 | """ 5 | 6 | import datetime 7 | import csv 8 | import os 9 | 10 | 11 | class FileIO(object): 12 | @staticmethod 13 | def writeToFile(text, filename): 14 | file = open(filename, 'a+',encoding='utf8') 15 | file.write(text + '\n') 16 | file.close() 17 | 18 | @staticmethod 19 | def writeToCsvFile(list_msg, filename, mode='a+'): 20 | file = open('./' + filename, mode=mode,encoding='utf8') 21 | writer = csv.writer(file) 22 | writer.writerow(list_msg) 23 | file.close() 24 | 25 | @staticmethod 26 | def exceptionHandler(message, url=''): 27 | FileIO.writeToFile(text='[' + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ']: ' + url + '\n' 28 | + message, filename='./../../logs/error_log.logs') -------------------------------------------------------------------------------- /src/dialogue_system/dialogue_configuration.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | # Max turn.
4 | MAX_TURN = 22 5 | 6 | # DIALOGUE STATUS 7 | DIALOGUE_STATUS_FAILED = 0 8 | DIALOGUE_STATUS_SUCCESS = 1 9 | DIALOGUE_STATUS_NOT_COME_YET = -1 10 | DIALOGUE_STATUS_NOT_GET_ALL_SYMPTOMS = -2 11 | DIALOGUE_STATUS_INFORM_WRONG_DISEASE = 2 12 | DIALOGUE_STATUS_INFORM_RIGHT_SYMPTOM = 3 13 | DIALOGUE_STATUS_REACH_MAX_TURN = -3 14 | 15 | # Special Actions. 16 | CLOSE_DIALOGUE = "closing" 17 | THANKS = "thanks" 18 | 19 | # Slot value for unknown, placeholder and no value matches. 20 | VALUE_UNKNOWN = "UNK" 21 | VALUE_PLACEHOLDER = "placeholder" 22 | VALUE_NO_MATCH = "No value matches." 23 | 24 | # RESPONSE 25 | I_DO_NOT_CARE = "I don't care." 26 | I_DO_NOT_KNOW = "I don't know." 27 | I_DENY = "No" 28 | 29 | # Constraint Check 30 | CONSTRAINT_CHECK_SUCCESS = 1 31 | CONSTRAINT_CHECK_FAILURE = 0 32 | 33 | # Update condition 34 | SUCCESS_RATE_THRESHOLD = 0.15 35 | AVERAGE_WRONG_DISEASE = 7 -------------------------------------------------------------------------------- /src/dialogue_system/utils/goal_to_slot.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | 4 | class Goal2Slot(object): 5 | def __init__(self): 6 | pass 7 | 8 | def load_goal(self,goal_file): 9 | slot_set = set() 10 | goal_set = pickle.load(open(goal_file,"rb")) 11 | for key in goal_set.keys(): 12 | for goal in goal_set[key]: 13 | for symptom in goal["goal"]["explicit_inform_slots"].keys(): 14 | slot_set.add(symptom) 15 | for symptom in goal["goal"]["implicit_inform_slots"].keys(): 16 | slot_set.add(symptom) 17 | self.slot_set = list(slot_set) 18 | print(len(self.slot_set)) 19 | 20 | 21 | if __name__ == "__main__": 22 | goal_file = "./../data/dataset/1200/0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN/goal_set_2.p" 23 | goal2slot = Goal2Slot() 24 | goal2slot.load_goal(goal_file) 25 | -------------------------------------------------------------------------------- /preprocess/match_disease.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Extract the self-report contents and symptoms of the top diseases from the file (self_report_extracted_symptom.csv) that contains all self-report contents and symptoms. 4 | """ 5 | import pandas as pd 6 | import csv 7 | 8 | class DiseaseMatch(object): 9 | def __init__(self, top_disease_list, self_report_extracted_symptom_file): 10 | self.report_extracted_symptom_file = self_report_extracted_symptom_file 11 | self.top_disease_list = top_disease_list 12 | 13 | def match(self, save_file_name): 14 | report_symptom = open(self.report_extracted_symptom_file,mode='r',encoding="utf-8") 15 | report_symptom_reader = csv.reader(report_symptom) 16 | 17 | save_file = open(save_file_name,encoding="utf-8",mode="w") 18 | writer = csv.writer(save_file) 19 | 20 | index = 0 21 | for line in report_symptom_reader: 22 | if line[5] in self.top_disease_list: 23 | print(line) 24 | writer.writerow(line) 25 | index += 1 26 | print(index) 27 | save_file.close() 28 | report_symptom.close() 29 | -------------------------------------------------------------------------------- /src/dialogue_system/agent/agent_random.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An agent that randomly chooses an action from the action set.
4 | """ 5 | import random 6 | 7 | import sys, os 8 | sys.path.append(os.getcwd().replace("src/dialogue_system/agent","")) 9 | 10 | from src.dialogue_system.agent.agent import Agent 11 | 12 | 13 | class AgentRandom(Agent): 14 | def __init__(self, action_set, slot_set, disease_symptom, parameter): 15 | super(AgentRandom, self).__init__(action_set=action_set, slot_set=slot_set,disease_symptom=disease_symptom,parameter=parameter) 16 | self.max_turn = parameter["max_turn"] 17 | 18 | def next(self, state,turn,greedy_strategy,**kwargs): 19 | self.agent_action["turn"] = turn 20 | action_index = random.randint(0, len(self.action_space)-1) 21 | agent_action = self.action_space[action_index] 22 | agent_action["turn"] = turn 23 | agent_action["speaker"] = "agent" 24 | return agent_action, action_index 25 | 26 | def train_mode(self): 27 | pass 28 | 29 | def eval_mode(self): 30 | pass -------------------------------------------------------------------------------- /src/dialogue_system/memory/base.py: -------------------------------------------------------------------------------- 1 | # Modified by Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class Memory(ABC): 8 | '''Abstract Memory class to define the API methods''' 9 | 10 | def __init__(self, paramter): 11 | ''' 12 | @param {*} body is the unit that stores its experience in this memory. Each body has a distinct memory. 13 | ''' 14 | self.parameter = paramter 15 | # declare what data keys to store 16 | self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones', 'priorities'] 17 | 18 | @abstractmethod 19 | def reset(self): 20 | '''Method to fully reset the memory storage and related variables''' 21 | raise NotImplementedError 22 | 23 | @abstractmethod 24 | def update(self, state, action, reward, next_state, done): 25 | '''Implement memory update given the full info from the latest timestep. NOTE: guard for np.nan reward and done when individual env resets.''' 26 | raise NotImplementedError 27 | 28 | @abstractmethod 29 | def sample(self): 30 | '''Implement memory sampling mechanism''' 31 | raise NotImplementedError 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MedicalChatbot-HRL 2 | 3 | This is the code of a task-oriented dialogue system for automatic diagnosis, where a hierarchical reinforcement learning are implemented as the dialogue policy. The low level RL consists of multi-agents and each agent is the specific policy in terms of a type of disease. While the high level RL is responsible for selecting one of the low level agent or the disease classifier, then the selected agent will interact with patient in this dialogue turn. 4 | 5 | The datasets can be downloaded in http://www.sdspeople.fudan.edu.cn/zywei/data/Fudan-Medical-Dialogue2.0 6 | 7 | The paper draft is available at https://arxiv.org/abs/2004.14254 8 | 9 | # How to run the code 10 | 11 | 1. download the datasets and unzip this folder in src/data. 12 | 2. Using the following command to run the code. 
13 | ``` 14 | cd src/dialogue_system/run 15 | python run.py --help 16 | ``` 17 | 18 | 19 | # Cite 20 | ``` 21 | @article{liao2020task, 22 | title={Task-oriented Dialogue System for Automatic Disease Diagnosis via Hierarchical Reinforcement Learning}, 23 | author={Liao, Kangenbei and Liu, Qianlong and Wei, Zhongyu and Peng, Baolin and Chen, Qin and Sun, Weijian and Huang, Xuanjing}, 24 | journal={arXiv preprint arXiv:2004.14254}, 25 | year={2020} 26 | } 27 | ``` 28 | -------------------------------------------------------------------------------- /draw_finals.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri May 17 20:26:03 2019 4 | 5 | @author: DELL 6 | """ 7 | 8 | import pickle 9 | import os 10 | import copy 11 | 12 | #file0='./src/data/real_world' 13 | file0='./src/data/simulated/label13' 14 | disease_symptom = pickle.load(open(file0+'/disease_symptom.p','rb')) 15 | slot_set = pickle.load(open(file0+'/slot_set.p','rb')) 16 | slot_set.pop('disease') 17 | disease2id = {} 18 | for disease,value in disease_symptom.items(): 19 | index = value['index'] 20 | disease2id[disease] = index 21 | sorts = sorted(disease2id.items(),key = lambda x:x[1],reverse = False) 22 | diseases = [x[0] for x in sorts] 23 | id2disease = {value:key for key,value in disease2id.items()} 24 | id2slot = {value:key for key,value in slot_set.items()} 25 | 26 | dirs = os.listdir('./visit/') 27 | result_dir = os.path.join('./visit',dirs[-1]) 28 | 29 | result = pickle.load(open(result_dir,'rb')) 30 | 31 | symptom_disease = {} 32 | for s in slot_set.values(): 33 | symptom_disease[id2slot[s]] = {} 34 | for d in diseases: 35 | symptom_disease[id2slot[s]][d]=0 36 | 37 | temp = copy.deepcopy(result['disease']) 38 | for d,values in temp.items(): 39 | for slot,count in values.items(): 40 | if slot>=(len(slot_set)): 41 | pass 42 | else: 43 | symptom_disease[id2slot[slot]][id2disease[d]]+=count 44 | 45 | with open('./symptom_disease_final_s.csv','w') as f: 46 | f.writelines(','+','.join(diseases)) 47 | f.writelines('\n') 48 | for d,s_dict in symptom_disease.items(): 49 | s_value = list(s_dict.values()) 50 | s_value = [str(x) for x in s_value] 51 | f.writelines(d+','+','.join(s_value)) 52 | f.writelines('\n') 53 | 54 | #with open('./resource/MedicalChatbotMultiAgent1/visit') 55 | -------------------------------------------------------------------------------- /preprocess/symptom_liking.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | conversation.txt contains the conversation contents of all diseases (without the self-reports). The code below extracts the data of the diseases under study from this file so that the next processing step can run on it. 5 | In the actual pipeline from the raw files to a usable user goal file, this code should not be needed. 6 | """ 7 | 8 | import pandas as pd 9 | 10 | 11 | class ReportConversation(object): 12 | def __init__(self): 13 | pass 14 | 15 | def match(self, conversation_file_name, save_file_name,report_file_name=None,consult_id_list=None): 16 | assert (report_file_name is not None or consult_id_list is not None), "no consult id is provided."
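        # If a report file is given, the consult ids are read from its '咨询ID' column below; otherwise the consult_id_list passed by the caller is used directly.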
17 | if report_file_name != None: 18 | project_data = pd.read_csv(report_file_name, sep="\t") 19 | print(project_data) 20 | consult_id_list = list(project_data['咨询ID']) 21 | 22 | return self.__match_based_on_id__(conversation_file_name,save_file_name,consult_id_list) 23 | 24 | def __match_based_on_id__(self, conversation_file_name, save_file_name,consult_id_list): 25 | conversation_file = open(conversation_file_name, 'r', encoding="utf-8") 26 | save_conversation_file = open(save_file_name, mode="w", encoding="utf-8") 27 | found_consult_id_list = [] 28 | write = False 29 | for line in conversation_file: 30 | line = line.replace("\n", "") 31 | if "consult_id" in line: 32 | line = line.replace(' ', "") 33 | consult_id = str(line.split(":")[1]) 34 | line = "\n" + line 35 | if consult_id in consult_id_list: 36 | found_consult_id_list.append(consult_id) 37 | write = True 38 | print(consult_id) 39 | if len(line) == 0: 40 | write = False 41 | if write: 42 | save_conversation_file.write(line + "\n") 43 | conversation_file.close() 44 | save_conversation_file.close() 45 | return found_consult_id_list 46 | 47 | 48 | -------------------------------------------------------------------------------- /preprocess/label/get_slot_from_goal.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | 5 | 6 | class GoalReader(object): 7 | def __init__(self): 8 | pass 9 | 10 | def load(self, goal_file): 11 | self.goal_file = goal_file 12 | goal_set = pickle.load(open(goal_file, 'rb')) 13 | self.goal_set = goal_set['train'] + goal_set['test'] + goal_set['validate'] 14 | 15 | def dump(self, slot_file, disease_symptom_file): 16 | self.slot_set = {} 17 | self.disease_symptom = {} 18 | for goal in self.goal_set: 19 | disease = goal['disease_tag'] 20 | self.disease_symptom.setdefault(disease, {'index':len(self.disease_symptom), 'symptom':dict()}) 21 | slot_set = goal['goal']['explicit_inform_slots'] 22 | slot_set.update(goal['goal']['implicit_inform_slots']) 23 | for slot, value in slot_set.items(): 24 | self.slot_set.setdefault(slot, len(self.slot_set)) 25 | self.disease_symptom[disease]['symptom'].setdefault(slot,0) 26 | self.disease_symptom[disease]['symptom'][slot] += 1 27 | self.slot_set['disease'] = len(self.slot_set) 28 | # for key in self.disease_symptom.keys(): 29 | # self.disease_symptom[key]['symptom'] = list(self.disease_symptom[key]['symptom']) 30 | 31 | pickle.dump(obj=self.slot_set, file=open(slot_file, 'wb')) 32 | pickle.dump(obj=self.disease_symptom, file=open(disease_symptom_file, 'wb')) 33 | print(len(self.slot_set), self.slot_set) 34 | for key, value in self.disease_symptom.items(): 35 | print(key, len(value['symptom'])) 36 | print(len(self.disease_symptom), self.disease_symptom) 37 | 38 | 39 | if __name__ == '__main__': 40 | path = './../../resources/label/used/' 41 | reader = GoalReader() 42 | reader.load(path + 'goal_set.p') 43 | reader.dump(path + 'slot_set.p', path + 'disease_symptom.p') -------------------------------------------------------------------------------- /draw_distribution.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu May 16 15:44:49 2019 4 | 5 | @author: DELL 6 | """ 7 | 8 | import pickle 9 | 10 | file0='./src/data/real_world' 11 | #file0='./src/data/simulated/label13' 12 | goal_set=pickle.load(open(file0+'/goal_set.p','rb')) 13 | slot_set = pickle.load(open(file0+'/slot_set.p','rb')) 14 | disease_symptom = 
pickle.load(open(file0+'/disease_symptom.p','rb')) 15 | slot_set.pop('disease') 16 | goals=[] 17 | for goal in goal_set.values(): 18 | goals += goal 19 | 20 | diseases = list(set([x['disease_tag'] for x in goals])) 21 | disease_symptom2 = {} 22 | disease_symptom3 = {} 23 | for i in diseases: 24 | disease_symptom2[i] = {} 25 | disease_symptom3[i] = {} 26 | 27 | symptom_disease = {} 28 | symptom_disease2 = {} 29 | for s in slot_set.keys(): 30 | symptom_disease[s] = {} 31 | symptom_disease2[s] = {} 32 | for d in diseases: 33 | symptom_disease[s][d]=0 34 | symptom_disease2[s][d] = 0 35 | 36 | 37 | 38 | for a in goals: 39 | disease = a['disease_tag'] 40 | explicit = a['goal']['explicit_inform_slots'] 41 | implicit = a['goal']['implicit_inform_slots'] 42 | dict_combined = dict( explicit, **implicit) 43 | for s,value in dict_combined.items(): 44 | symptom_disease[s][disease]+=1 45 | if s not in disease_symptom2[disease].keys(): 46 | disease_symptom2[disease][s]=1 47 | else: 48 | disease_symptom2[disease][s]+=1 49 | if value==True: 50 | symptom_disease2[s][disease]+=1 51 | if s not in disease_symptom3[disease].keys(): 52 | disease_symptom3[disease][s]=1 53 | else: 54 | disease_symptom3[disease][s]+=1 55 | 56 | with open('./symptom_disease2.csv','w') as f: 57 | f.writelines(','+','.join(diseases)) 58 | f.writelines('\n') 59 | for d,s_dict in symptom_disease2.items(): 60 | s_value = list(s_dict.values()) 61 | s_value = [str(x) for x in s_value] 62 | f.writelines(d+','+','.join(s_value)) 63 | f.writelines('\n') 64 | -------------------------------------------------------------------------------- /log/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding : utf8 -*- 2 | import os 3 | 4 | 5 | def get_dir_list(path, key_word_list=None, no_key_word_list=None): 6 | file_name_list = os.listdir(path) # get all file names in the directory of the original json files 7 | if key_word_list == None and no_key_word_list == None: 8 | temp_file_list = file_name_list 9 | elif key_word_list != None and no_key_word_list == None: 10 | temp_file_list = [] 11 | for file_name in file_name_list: 12 | have_key_words = True 13 | for key_word in key_word_list: 14 | if key_word not in file_name: 15 | have_key_words = False 16 | break 17 | else: 18 | pass 19 | if have_key_words == True: 20 | temp_file_list.append(file_name) 21 | elif key_word_list == None and no_key_word_list != None: 22 | temp_file_list = [] 23 | for file_name in file_name_list: 24 | have_no_key_word = False 25 | for no_key_word in no_key_word_list: 26 | if no_key_word in file_name: 27 | have_no_key_word = True 28 | break 29 | if have_no_key_word == False: 30 | temp_file_list.append(file_name) 31 | elif key_word_list != None and no_key_word_list != None: 32 | temp_file_list = [] 33 | for file_name in file_name_list: 34 | have_key_words = True 35 | for key_word in key_word_list: 36 | if key_word not in file_name: 37 | have_key_words = False 38 | break 39 | else: 40 | pass 41 | have_no_key_word = False 42 | for no_key_word in no_key_word_list: 43 | if no_key_word in file_name: 44 | have_no_key_word = True 45 | break 46 | else: 47 | pass 48 | if have_key_words == True and have_no_key_word == False: 49 | temp_file_list.append(file_name) 50 | print(key_word_list, len(temp_file_list)) 51 | # time.sleep(2) 52 | return temp_file_list -------------------------------------------------------------------------------- /src/dqn_gym.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import gym
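# Sanity-check script for the actor-critic implementation: it trains ActorCritic on a Gym control task (Acrobot-v1 by default), collecting a pool of 100 full-episode trajectories per iteration and then training on each trajectory.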
4 | import argparse 5 | from src.dialogue_system.policy_learning.actor_critic import ActorCritic 6 | import gym.spaces.box 7 | 8 | parser = argparse.ArgumentParser() 9 | # For Actor-critic 10 | parser.add_argument("--actor_learning_rate", dest="actor_learning_rate", type=float, default=0.001, help="the learning rate of actor") 11 | parser.add_argument("--critic_learning_rate", dest="critic_learning_rate", type=float, default=0.001, help="the learning rate of critic") 12 | parser.add_argument("--trajectory_pool_size", dest="trajectory_pool_size", type=int, default=100, help="the size of trajectory pool") 13 | parser.add_argument("--gamma", dest="gamma", type=float, default=0.9, help="The discount factor of immediate reward.") 14 | 15 | args = parser.parse_args() 16 | parameter = vars(args) 17 | 18 | # env = gym.make("MountainCar-v0") 19 | env = gym.make("Acrobot-v1") 20 | print(type(env.observation_space)) 21 | print(env.observation_space.shape) 22 | input_size = env.observation_space.shape[0] 23 | output_size = env.action_space.n 24 | hidden_size = 10 25 | 26 | 27 | actor_critic = ActorCritic(input_size,hidden_size,output_size,parameter) 28 | 29 | # state, agent_action, reward, next_state, episode_over 30 | def simulate(): 31 | total_reward = 0 32 | trajectory_pool = [] 33 | episode_size = 100 34 | for i_episode in range(episode_size): 35 | observation = env.reset() 36 | trajectory = [] 37 | done = False 38 | while done == False: 39 | # env.render() 40 | # action = env.action_space.sample() 41 | action = actor_critic.take_action(observation) 42 | next_observation, reward, done, info = env.step(action) 43 | total_reward += reward 44 | trajectory.append((observation, action,reward,next_observation, done)) 45 | observation = next_observation 46 | if done: 47 | # print("Episode finished after {} timesteps".format(t + 1)) 48 | break 49 | trajectory_pool.append(trajectory) 50 | return trajectory_pool, total_reward/float(episode_size) 51 | 52 | def train(trajectory_pool): 53 | for trajectory in trajectory_pool: 54 | actor_critic.train(trajectory) 55 | 56 | def run(): 57 | for _i in range(0, 2000,1): 58 | trajectory_pool, average_reward = simulate() 59 | print("%3d, average reward: %4f"%(_i, average_reward)) 60 | # if average_reward >= -110.0: 61 | # break 62 | train(trajectory_pool) 63 | 64 | if __name__ == "__main__": 65 | run() -------------------------------------------------------------------------------- /preprocess/run_pre.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | from top_disease import TopDiseaseReporter 4 | from match_disease import DiseaseMatch 5 | from aligned_symptoms_extracting import DataLoader 6 | 7 | # 10 diseases 8 | # top_disease_list = ["上呼吸道感染", "小儿消化不良", "小儿支气管炎","小儿腹泻","小儿感冒", 9 | # "小儿咳嗽","新生儿黄疸","小儿便秘","急性支气管炎","小儿支气管肺炎"] 10 | 11 | # 8 diseases 12 | # top_disease_list = ["上呼吸道感染", "小儿支气管炎","小儿腹泻","小儿感冒" 13 | # ,"新生儿黄疸","小儿便秘","急性支气管炎","小儿支气管肺炎"] 14 | 15 | # 7 diseases. 16 | # top_disease_list = ["上呼吸道感染", "小儿支气管炎", "小儿腹泻", "小儿感冒", 17 | # "小儿咳嗽", "急性支气管炎", "小儿支气管肺炎"] 18 | 19 | # 4 diseases. 
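# Only one of the disease lists in this file should be active at a time; the list chosen here drives all three preprocessing steps below (top_disease, match_top_self_report, symptom_normalization).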
20 | top_disease_list = ["上呼吸道感染", "小儿支气管炎", "小儿腹泻", "小儿消化不良",] 21 | 22 | # TODO: first step 23 | # Extract the diseases to be analyzed from the raw file; each saved record contains (date, level-1 department, level-2 department, consult ID, qid, question content, standard disease name). 24 | # Find the diseases with relatively many samples in the file /Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/儿科咨询疾病标注数据.xlsx; 25 | # the result contains the self-report contents and disease information, but no extracted symptom information yet. 26 | def top_disease(): 27 | disease_file = "/Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/儿科咨询疾病标注数据.xlsx" 28 | top_self_report_file_save = "./../resources/top_self_report_text.csv" 29 | 30 | report = TopDiseaseReporter(top_disease_list=top_disease_list, disease_file=disease_file) 31 | report.load_data() 32 | report.save(save_file=top_self_report_file_save) 33 | 34 | 35 | # TODO: second step 36 | # Extract the self-report contents and symptoms of the top diseases from the file that contains all self-reports and symptoms (self_report_extracted_symptom.csv). 37 | def match_top_self_report(): 38 | self_report_extracted_symptom_file = "./../resources/self_report_extracted_symptom.csv" # fixed file, never change it 39 | save_file_name = "./../resources/top_self_report_extracted_symptom.csv" 40 | 41 | match = DiseaseMatch(top_disease_list=top_disease_list, 42 | self_report_extracted_symptom_file=self_report_extracted_symptom_file) 43 | match.match(save_file_name=save_file_name) 44 | 45 | 46 | # TODO: third step 47 | # Normalize the self-report symptoms and the Q&A symptoms. 48 | # Uses simple string-similarity matching. 49 | def symptom_normalization(): 50 | threshold = 0.2 51 | disease_symptom_aligned_file = "./../resources/top_disease_symptom_aligned.json" 52 | top_self_report_extracted_symptom_file = "./../resources/top_self_report_extracted_symptom.csv" 53 | conversation_symptom_file = "/Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/conversation_symptom.txt" 54 | goal_spoken_writing_save = "./../resources/goal_spoken_writing_" + str(threshold) + ".json" 55 | goal_slot_value_save = "./../resources/goal_slot_value_" + str(threshold) + ".json" 56 | hand_crafted_symptom = True # whether symptom normalization uses the hand-crafted mapping; if so, the corresponding file is top_disease_symptom_aligned.json, otherwise the raw file has to be used. 57 | 58 | report_loader = DataLoader(threshold=threshold, disease_symptom_aligned_file=disease_symptom_aligned_file, 59 | hand_crafted_symptom=hand_crafted_symptom, 60 | top_disease_list=top_disease_list) 61 | report_loader.load_self_report(self_report_file=top_self_report_extracted_symptom_file) 62 | print("Conversation:") 63 | report_loader.load_conversation(conversation_file=conversation_symptom_file) 64 | 65 | slot_file = "./../resources/slot_set.txt" 66 | report_loader.write(file_name=goal_spoken_writing_save) 67 | report_loader.write_slot_value(file_name=goal_slot_value_save) 68 | report_loader.write_slots(file_name=slot_file) 69 | 70 | 71 | if __name__ == "__main__": 72 | top_disease() 73 | # match_top_self_report() 74 | # symptom_normalization() -------------------------------------------------------------------------------- /src/classifier/run/run_cla.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import time 4 | import argparse 5 | import pickle 6 | import sys, os 7 | sys.path.append(os.getcwd().replace("src/classifier/run","")) 8 | 9 | from src.classifier.symptom_as_feature.symptom_classifier import SymptomClassifier 10 | from src.classifier.self_report_as_feature.report_classifier import ReportClassifier 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--goal_set", dest="goal_set", type=str, default="./../../dialogue_system/data/dataset/label/goal_set.p", help='path and filename of user goal') 16 | parser.add_argument("--slot_set", dest="slot_set",
type=str, default='./../../dialogue_system/data/dataset/label/slot_set.p', help='path and filename of the slots set') 17 | parser.add_argument("--disease_symptom", dest="disease_symptom", type=str, default="./../../dialogue_system/data/dataset/label/disease_symptom.p", help="path and filename of the disease_symptom file") 18 | 19 | 20 | parser.add_argument("--explicit_number", dest="explicit_number", type=int, default=0, help="the number of explicit symptoms of used sample") 21 | parser.add_argument("--implicit_number", dest="implicit_number", type=int, default=0, help="the number of implicit symptoms of used sample") 22 | 23 | 24 | parser.add_argument("--batch_size", dest="batch_size",type=int, default=32, help="the batch size for training.") 25 | parser.add_argument("--hidden_size", dest="hidden_size",type=int, default=40, help="the hidden size of classifier.") 26 | parser.add_argument("--train_feature", dest="train_feature", type=str, default="ex&im", help="only use explicit symptom for classification? ex:yes, ex&im:no") 27 | parser.add_argument("--test_feature", dest="test_feature", type=str, default="ex&im", help="only use explicit symptom for testing? ex:yes, ex&im:no") 28 | parser.add_argument("--checkpoint_path",dest="checkpoint_path", type=str, default="./../model/checkpoint/", help="the folder where models save to, ending with /.") 29 | parser.add_argument("--saved_model", dest="saved_model", type=str, default="./../model/dqn/checkpoint_d4_agt1_dqn1/model_d4_agent1_dqn1_s0.602_r17.036_t4.326_wd0.0_e214.ckpt") 30 | parser.add_argument("--learning_rate", dest="learning_rate", type=float, default=0.2,help="the learning rate when training the model.") 31 | 32 | args = parser.parse_args() 33 | parameter = vars(args) 34 | 35 | def run(): 36 | slot_set = pickle.load(file=open(parameter["slot_set"], "rb")) 37 | goal_set = pickle.load(file=open(parameter["goal_set"], "rb")) 38 | disease_symptom = pickle.load(file=open(parameter["disease_symptom"], "rb")) 39 | hidden_size = parameter.get("hidden_size") 40 | 41 | print("##"*30+"\nSymptom as features\n"+"##"*30) 42 | classifier = SymptomClassifier(goal_set=goal_set,symptom_set=slot_set,disease_symptom=disease_symptom,hidden_size=hidden_size,parameter=parameter,k_fold=5) 43 | classifier.train_sklearn_svm() 44 | print(classifier.disease_sample_count) 45 | # classifier.sample_to_file("./../data/goal_set.json") 46 | # classifier.dump_goal_set("/Volumes/LIUQL/dataset/goal_set_6.p") 47 | 48 | 49 | # print("##"*30+"\nSelf-report as features\n"+"##"*30) 50 | # data_file = "./../../../resources/top_self_report_extracted_symptom.csv" 51 | # stop_words = "./../data/stopwords.txt" 52 | # report_classifier = ReportClassifier(stop_words=stop_words,data_file=data_file) 53 | # report_classifier.train_tf() 54 | # report_classifier.evaluate_tf() 55 | # report_classifier.train_sklearn_svm() 56 | # report_classifier.evaluate_sklearn_svm() 57 | 58 | 59 | if __name__ == "__main__": 60 | run() -------------------------------------------------------------------------------- /preprocess/extract_symptom.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Based on the consult ids in the top-disease file, extract spoken-form symptoms from the self-report symptom file and from the conversation contents. 4 | """ 5 | 6 | import csv 7 | import pandas as pd 8 | 9 | 10 | class SelfReportSymptomExtractor(object): 11 | def __init__(self, disease_list): 12 | self.symptom = {} 13 | for disease in disease_list: 14 | self.symptom[disease] = set() 15 | 16 | def extract(self, file_name): 17 | reader =
csv.reader(open(file_name, encoding="utf-8", mode="r")) 18 | for line in reader: 19 | if line[5] in self.symptom.keys(): 20 | self._extract_symptom(line) 21 | 22 | def _extract_symptom(self,line): 23 | qid = line[0] 24 | consult_id = line[4] 25 | for index in range(7, len(line)): 26 | print(line) 27 | if len(line[index]) > 0: 28 | temp_symptom = line[index] 29 | self.symptom[line[5]].add((qid, consult_id, temp_symptom)) 30 | 31 | def save(self,save_file): 32 | writer = csv.writer(open(save_file, encoding="utf-8", mode="w")) 33 | for key in self.symptom.keys(): 34 | for symptom in self.symptom[key]: 35 | # writer.writerow([key] + list(self.symptom[key])) 36 | writer.writerow([symptom[0],symptom[1], key, symptom[2]]) 37 | 38 | 39 | class ConversationSymptomExtractor(object): 40 | def __init__(self, disease_list): 41 | self.symptom={} 42 | for disease in disease_list: 43 | self.symptom[disease] = set() 44 | 45 | def extract(self,consult_id_file, from_file): 46 | self.data = pd.read_csv(consult_id_file,header=None,) 47 | self.data.index = self.data[4] 48 | 49 | data_file = open(from_file,mode="r", encoding="utf-8") 50 | for line in data_file: 51 | self._extract(line) 52 | data_file.close() 53 | 54 | def _extract(self, line): 55 | line = line.replace("\n", "").split('\t') 56 | consult_id = int(line[0]) 57 | try: 58 | disease = self.data.loc[consult_id,5] 59 | print("*" * 30 + "\n", line) 60 | if disease in self.symptom.keys(): 61 | consult_id = line[0] 62 | for symptom in line[3:len(line)]: 63 | self.symptom[disease].add((consult_id,symptom)) 64 | print(disease, symptom) 65 | print("add symptom") 66 | except: 67 | pass 68 | 69 | def save(self,save_file): 70 | writer = csv.writer(open(save_file, encoding="utf-8", mode="w")) 71 | for key in self.symptom.keys(): 72 | for symptom in self.symptom[key]: 73 | writer.writerow([symptom[0],key,symptom[1]]) 74 | 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | from run_pre import top_disease_list 80 | 81 | # # Extracting symptoms from self-report. 82 | top_disease_symptom_file = "./../resources/top_self_report_extracted_symptom.csv" 83 | save_file = "./../resources/top_symptom_self_report.csv" 84 | extractor = SelfReportSymptomExtractor(top_disease_list) 85 | extractor.extract(file_name=top_disease_symptom_file) 86 | extractor.save(save_file) 87 | 88 | 89 | # Extracting symptoms from conversations. 
90 | consult_id_file = "./../resources/top_self_report_extracted_symptom.csv" 91 | conversation_file = "/Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/conversation_symptom.txt" 92 | save_file = "./../resources/top_symptom_conversation.csv" 93 | extractor = ConversationSymptomExtractor(top_disease_list) 94 | extractor.extract(consult_id_file=consult_id_file,from_file=conversation_file) 95 | extractor.save(save_file=save_file) 96 | -------------------------------------------------------------------------------- /preprocess/top_disease.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Extract the diseases to be analyzed from the raw file; each saved record contains (date, level-1 department, level-2 department, consult ID, qid, question content, standard disease name). 4 | Find the diseases with relatively many samples in the file /Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/儿科咨询疾病标注数据.xlsx; 5 | the result contains the self-report contents and disease information, but no extracted symptom information yet. 6 | """ 7 | import pandas as pd 8 | 9 | class TopDiseaseReporter(object): 10 | def __init__(self, top_disease_list, disease_file): 11 | self.disease_file = disease_file 12 | self.top_disease_list = top_disease_list 13 | self.data = pd.read_excel(io=disease_file, sheet_name="儿科主诉列表 (修正)") 14 | self.top_disease_data = pd.DataFrame(columns=self.data.columns) 15 | 16 | def load_data(self): 17 | key = self.data.columns[6] 18 | for index in self.data.index: 19 | if self.data.loc[index, key] in self.top_disease_list: 20 | print(index, self.data.loc[index, key]) 21 | self.top_disease_data = self.top_disease_data.append(self.data.loc[index]) 22 | # if len(top_disease_data.index) > 10: 23 | # break 24 | del self.top_disease_data["Unnamed: 7"] 25 | del self.top_disease_data["Unnamed: 8"] 26 | del self.top_disease_data["Unnamed: 9"] 27 | del self.top_disease_data["Unnamed: 10"] 28 | print(self.top_disease_data) 29 | 30 | def load_data_consult_id(self, consult_id_list): 31 | """ 32 | Match records by consult id.
33 | :param consult_id_list: 34 | :return: 35 | """ 36 | key = self.data.columns[3] 37 | for index in self.data.index: 38 | if str(self.data.loc[index, key]) in consult_id_list: 39 | print(index, self.data.loc[index, key]) 40 | self.top_disease_data = self.top_disease_data.append(self.data.loc[index]) 41 | # if len(top_disease_data.index) > 10: 42 | # break 43 | del self.top_disease_data["Unnamed: 7"] 44 | del self.top_disease_data["Unnamed: 8"] 45 | del self.top_disease_data["Unnamed: 9"] 46 | del self.top_disease_data["Unnamed: 10"] 47 | print(self.top_disease_data) 48 | 49 | def save(self, save_file): 50 | self.top_disease_data.to_csv(save_file, sep="\t", index=True, header=True) 51 | 52 | 53 | if __name__ == "__main__": 54 | import pickle 55 | import json 56 | import random 57 | from symptom_liking import ReportConversation 58 | 59 | top_disease_list = ["上呼吸道感染", "小儿支气管炎", "小儿腹泻", "小儿消化不良", ] 60 | 61 | disease_file = "/Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/儿科咨询疾病标注数据.xlsx" 62 | top_self_report_file_save = "./../resources/880/top_self_report_text.csv" 63 | goal_set = pickle.load(open("./../src/dialogue_system/data/4_diseases/both/goal_set.p","rb")) 64 | consult_id_dict = {} 65 | for key in goal_set.keys(): 66 | print(key, len(goal_set[key])) 67 | for goal in goal_set[key]: 68 | if len(goal["goal"]["explicit_inform_slots"].keys()) >= 0 and\ 69 | len(goal["goal"]["implicit_inform_slots"].keys()) >= 1: 70 | consult_id_dict.setdefault(goal["disease_tag"],list()) 71 | consult_id_dict[goal["disease_tag"]].append(goal["consult_id"]) 72 | print(consult_id_dict.keys()) 73 | 74 | consult_id_list = [] 75 | for key in consult_id_dict.keys(): 76 | print(key, len(consult_id_dict[key])) 77 | consult_id_list = consult_id_list + list(random.sample(consult_id_dict[key], 250)) 78 | print(len(consult_id_list)) 79 | print(type(consult_id_list[0])) 80 | 81 | 82 | conversation_file_name = "/Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/conversation.txt" 83 | top_conversation_save = "./../resources/880/top_conversation.txt" 84 | conversation = ReportConversation() 85 | consult_id_list = conversation.match(conversation_file_name=conversation_file_name, save_file_name=top_conversation_save,consult_id_list=consult_id_list) 86 | 87 | 88 | 89 | reporter = TopDiseaseReporter(top_disease_list=top_disease_list,disease_file=disease_file) 90 | reporter.load_data_consult_id(consult_id_list=consult_id_list) 91 | reporter.save(save_file=top_self_report_file_save) 92 | pickle.dump(file=open("./../resources/880/consult_id_list.p","wb"),obj=consult_id_list) 93 | print(len(consult_id_list)) 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /src/dialogue_system/agent/prioritized_new.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | 4 | class ReplayBuffer: 5 | """ 6 | Simple replay buffer to store and sample transition experiences 7 | """ 8 | def __init__(self, size): 9 | """ 10 | Constructor function 11 | args: 12 | size (int) : Maximum size of replay buffer 13 | """ 14 | self._maxsize = size 15 | self._storage = deque(maxlen=size) 16 | 17 | def __len__(self): 18 | return len(self._storage) 19 | 20 | def add(self, state, action, reward, next_state, done): 21 | """ 22 | Add transition data to the replay buffer 23 | args: 24 | state : Current state 25 | action : Action taken 26 | reward (float) : Received reward 27 | 
next_state : Next state 28 | done (bool) : Episode done 29 | """ 30 | data = (state, action, reward, next_state, done) 31 | self._storage.append(data) 32 | 33 | def _encode_sample(self, idxes): 34 | """ 35 | Sample data from given indexes 36 | args: 37 | idxes (list/np.array) : List with indexes of data to sample 38 | returns: 39 | np.array, np.array, np.array, np.array, np.array : Sampled states, actions, rewards, next_states and dones 40 | """ 41 | states, actions, rewards, next_states, dones = [], [], [], [], [] 42 | for i in idxes: 43 | obs_t, action, reward, obs_tp1, done = self._storage[i] 44 | states.append(np.array(obs_t, copy=False)) 45 | actions.append(np.array(action, copy=False)) 46 | rewards.append(reward) 47 | next_states.append(np.array(obs_tp1, copy=False)) 48 | dones.append(done) 49 | return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones) 50 | 51 | def sample(self, batch_size): 52 | """ 53 | Sample data from the replay buffer 54 | args: 55 | batch_size (int) : Maximum batch size to sample 56 | returns: 57 | tuple of 5 lists : Sampled batch of transitions 58 | """ 59 | batch_size = min(len(self), batch_size) 60 | idxes = np.random.randint(0, len(self), size=batch_size) 61 | return self._encode_sample(idxes) 62 | 63 | def clear(self): 64 | """ 65 | Clear the contents of replay buffer 66 | """ 67 | self._storage.clear() 68 | 69 | 70 | class PrioritizedReplayBuffer(object): 71 | 72 | def __init__(self, buffer_size): 73 | self._priorities = deque(maxlen=buffer_size) 74 | 75 | def __len__(self): 76 | return len(self._priorities) 77 | 78 | def add(self, state, action, reward, next_state, episode_over, error): 79 | self._priorities.append((state, action, reward, next_state, episode_over, error )) 80 | 81 | def sample(self, batch_size, priority_scale=1.0): 82 | batch_size = min(len(self._priorities), batch_size) 83 | batch_probs = self.get_probabilities(priority_scale) 84 | #print(len(self._priorities),len(batch_probs)) 85 | #print(batch_probs) 86 | batch_indices = np.random.choice(range(len(self._priorities)), size=batch_size, p=batch_probs) 87 | #batch_importance = self.get_importance(batch_probs[batch_indices]) 88 | batch = [self._priorities[x][:5] for x in batch_indices] 89 | 90 | return batch 91 | 92 | def get_probabilities(self, priority_scale): 93 | td_errors = np.array([abs(x[5]) for x in self._priorities]) 94 | #print(td_errors) 95 | scaled_priorities = td_errors ** priority_scale 96 | batch_probabilities = scaled_priorities / sum(scaled_priorities) 97 | return batch_probabilities 98 | 99 | def get_importance(self, probabilities): 100 | importance = 1 / (len(self._priorities) * probabilities+0.001) # TODO: The change here might create problem 101 | importance_normalized = importance / max(importance) 102 | return importance_normalized 103 | 104 | def set_priorities(self, indices, errors, offset=0.1): 105 | for i, e in zip(indices, errors): 106 | self._priorities[i] = self._priorities[i][:5] + (abs(e) + offset,) # keep the stored transition and update only its error term -------------------------------------------------------------------------------- /src/dialogue_system/utils/plot_slot_distribution.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import matplotlib.pyplot as plt 4 | import pickle 5 | 6 | # name_list = ['Monday', 'Tuesday', 'Friday', 'Sunday'] 7 | # num_list = [1.5, 0.6, 7.8, 6] 8 | # num_list1 = [1, 2, 3, 1] 9 | # x = list(range(len(num_list))) 10 | # total_width, n = 0.8, 3 11 | # width = total_width / n 12 | # 13 | # 
plt.bar(x, num_list, width=width, label='boy', fc='y') 14 | # for i in range(len(x)): 15 | # x[i] = x[i] + width 16 | # plt.bar(x, num_list1, width=width, label='girl', tick_label=name_list, fc='r') 17 | # 18 | # x = [i + width for i in x] 19 | # plt.bar(x, num_list, width=width, label='men', tick_label=name_list, fc='b') 20 | # 21 | # plt.legend() 22 | # plt.show() 23 | # 24 | # 25 | # 26 | # name_list = ['Monday','Tuesday','Friday','Sunday'] * 20 27 | # num_list = [1.5,0.6,7.8,6] * 20 28 | # num_list1 = [1,2,3,1] * 20 29 | # plt.bar(range(len(num_list)), num_list, label='boy',fc = 'y') 30 | # plt.bar(range(len(num_list)), num_list1, bottom=num_list, label='girl',tick_label = name_list,fc = 'r') 31 | # bottom = [num_list[i] + num_list1[i] for i in range(len(num_list))] 32 | # plt.bar(range(len(num_list)), num_list1, bottom=bottom, label='girl',tick_label = name_list,fc = 'b') 33 | # plt.legend() 34 | # plt.show() 35 | 36 | 37 | 38 | class DistributionPloter(object): 39 | def __init__(self, goal_set_file): 40 | self.goal_set = pickle.load(open(goal_set_file, 'rb')) 41 | self.symptom2id, self.id2disease, self.symptom_dist_by_disease = self.__distribution__() 42 | self.disease_to_english = { 43 | '小儿腹泻': 'Infantile diarrhea', 44 | '小儿支气管炎': 'Children’s bronchitis', 45 | '小儿消化不良': 'Children functional dyspepsia', 46 | '上呼吸道感染': 'Upper respiratory infection' 47 | } 48 | 49 | def __distribution__(self): 50 | symptom2id = dict() 51 | id2disease = dict() 52 | disease2id = dict() 53 | 54 | for goal in self.goal_set['train'] + self.goal_set['test'] + self.goal_set['validate']: 55 | id = len(disease2id) 56 | disease2id.setdefault(goal['disease_tag'], id) 57 | for symptom in goal['goal']['explicit_inform_slots'].keys(): 58 | id = len(symptom2id) 59 | symptom2id.setdefault(symptom, id) 60 | for symptom in goal['goal']['implicit_inform_slots'].keys(): 61 | id = len(symptom2id) 62 | symptom2id.setdefault(symptom, id) 63 | 64 | symptom_dist_by_disease = {} 65 | for goal in self.goal_set['train'] + self.goal_set['test'] + self.goal_set['validate']: 66 | symptom_dist_by_disease.setdefault(goal['disease_tag'], [0] * len(symptom2id)) 67 | for symptom in goal['goal']['explicit_inform_slots'].keys(): 68 | symptom_dist_by_disease[goal['disease_tag']][symptom2id[symptom]] += 1 69 | for symptom in goal['goal']['implicit_inform_slots'].keys(): 70 | symptom_dist_by_disease[goal['disease_tag']][symptom2id[symptom]] += 1 71 | 72 | for key, v in disease2id.items(): 73 | id2disease[v] = key 74 | print(key, v) 75 | return symptom2id, id2disease, symptom_dist_by_disease 76 | 77 | def plot(self): 78 | colors = ['#2f79c0', '#278b18', '#ff5186', '#8660a4', '#D49E0F', '#a8d40f'] 79 | print(self.symptom2id) 80 | bottom = [0]* len(self.symptom2id) 81 | 82 | disease = self.id2disease[0] 83 | symptom_dist = self.symptom_dist_by_disease[disease] 84 | plt.bar(range(len(symptom_dist)), symptom_dist, label=self.disease_to_english[disease], fc=colors[0]) 85 | for index in range(1, len(self.id2disease)): 86 | disease = self.id2disease[index] 87 | symptom_dist = self.symptom_dist_by_disease[disease] 88 | print(disease,len(symptom_dist), symptom_dist) 89 | plt.bar(range(len(symptom_dist)), symptom_dist, bottom=bottom, label=self.disease_to_english[disease], fc=colors[index]) 90 | # plt.bar(range(len(symptom_dist)), symptom_dist, bottom=bottom, label=disease, tick_label=self.symptom2id.keys(), fc=colors[index]) 91 | bottom = [bottom[i] + symptom_dist[i] for i in range(len(self.symptom2id))] 92 | plt.legend() 93 | 
plt.savefig('symptom_dist.pdf') 94 | plt.show() 95 | 96 | 97 | if __name__ == '__main__': 98 | goal_set_file = './../../data/goal_set_2.p' 99 | ploter = DistributionPloter(goal_set_file) 100 | ploter.plot() -------------------------------------------------------------------------------- /src/dialogue_system/run/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import time 4 | import os 5 | 6 | def verify_params(params): 7 | dqn_type = params.get("dqn_type") 8 | if dqn_type not in ['DQN', 'DoubleDQN', 'DuelingDQN']: 9 | raise ValueError("dqn_type should be one of ['DQN', 'DoubleDQN','DuelingDQN']") 10 | 11 | return construct_info(params) 12 | 13 | def construct_info(params): 14 | """ 15 | Construct a string which contains the primary hyper-parameters. 16 | Args: 17 | params: the hyper-parameter dict 18 | 19 | Returns: 20 | A dict, the updated parameter. 21 | """ 22 | os.environ["CUDA_VISIBLE_DEVICES"] = params["gpu"] 23 | gpu_str = os.environ.get("CUDA_VISIBLE_DEVICES") 24 | gpu_str = gpu_str.replace(' ', '') 25 | if len(gpu_str.split(',')) > 1: 26 | params.setdefault("multi_GPUs",True) 27 | else: 28 | params.setdefault("multi_GPUs", False) 29 | 30 | agent_id = params.get("agent_id") 31 | disease_number = params.get("disease_number") 32 | lr = params.get("dqn_learning_rate") 33 | reward_for_success = params.get("reward_for_success") 34 | reward_for_fail = params.get("reward_for_fail") 35 | reward_for_not_come_yet = params.get("reward_for_not_come_yet") 36 | reward_for_inform_right_symptom = params.get("reward_for_inform_right_symptom") 37 | reward_for_repeated_action = params.get("reward_for_repeated_action") 38 | reward_for_reach_max_turn = params.get("reward_for_reach_max_turn") 39 | allow_wrong_disease = params.get("allow_wrong_disease") 40 | check_related_symptoms = params.get("check_related_symptoms") 41 | 42 | max_turn = params.get("max_turn") 43 | minus_left_slots = params.get("minus_left_slots") 44 | gamma = params.get("gamma") 45 | gamma_worker = params.get('gamma_worker') 46 | epsilon = params.get("epsilon") 47 | data_set_name = params.get("goal_set").split("/")[-2] 48 | run_id = params.get('run_id') 49 | multi_gpu = params.get("multi_GPUs") 50 | dqn_type = params["dqn_type"] 51 | hrl_with_goal = params["hrl_with_goal"] 52 | weight_correction = params["weight_correction"] 53 | value_as_reward = params["value_as_reward"] 54 | symptom_dist_as_input = params["symptom_dist_as_input"] 55 | weight_for_reward_shaping = params["weight_for_reward_shaping"] 56 | disease_tag_for_terminating = params["disease_tag_for_terminating"] 57 | simulation_size = params["simulation_size"] 58 | is_relational_dqn = params["is_relational_dqn"] 59 | upper_bound_critic = params["upper_bound_critic"] 60 | lower_bound_critic = params["lower_bound_critic"] 61 | run_time = time.strftime('%m%d%H%M%S', time.localtime(time.time())) 62 | info = run_time + \ 63 | "_" + agent_id + \ 64 | "_T" + str(max_turn) + \ 65 | "_ss" + str(simulation_size) + \ 66 | "_lr" + str(lr) + \ 67 | "_RFS" + str(reward_for_success) + \ 68 | "_RFF" + str(reward_for_fail) + \ 69 | "_RFNCY" + str(reward_for_not_come_yet) + \ 70 | "_RFIRS" + str(reward_for_inform_right_symptom) +\ 71 | "_RFRA" + str(reward_for_repeated_action) +\ 72 | "_RFRMT" + str(reward_for_reach_max_turn) +\ 73 | "_mls" + str(int(minus_left_slots)) + \ 74 | "_gamma" + str(gamma) + \ 75 | "_gammaW" + str(gamma_worker) + \ 76 | "_epsilon" + str(epsilon) + \ 77 | "_awd" + str(int(allow_wrong_disease)) + \ 78 |
"_crs" + str(int(check_related_symptoms)) + \ 79 | "_hwg" + str(int(hrl_with_goal)) + \ 80 | "_wc" + str(int(weight_correction)) + \ 81 | "_var" + str(int(value_as_reward)) + \ 82 | "_sdai" + str(int(symptom_dist_as_input)) + \ 83 | "_wfrs" + str(weight_for_reward_shaping) + \ 84 | "_dtft" + str(int(disease_tag_for_terminating)) + \ 85 | "_ird" + str(int(is_relational_dqn)) + \ 86 | "_ubc" + str(upper_bound_critic) + \ 87 | "_lbc" + str(lower_bound_critic) + \ 88 | "_data" + str(data_set_name.title()) + \ 89 | "_RID" + str(run_id) 90 | params['run_info'] = info 91 | 92 | checkpoint_path = "./../../model/" + dqn_type + "/checkpoint/" + info 93 | params["checkpoint_path"] = checkpoint_path 94 | 95 | performance_save_path = "./../../model/" + dqn_type + "/performance_new/" 96 | params["performance_save_path"] = performance_save_path 97 | 98 | visit_save_path = "./../../model/" + dqn_type + "/visit/" 99 | params["visit_save_path"] = visit_save_path 100 | 101 | return params -------------------------------------------------------------------------------- /src/dialogue_system/utils/slot_distribution.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import json 4 | import pickle 5 | import pandas as pd 6 | 7 | 8 | class SlotDistributor(object): 9 | def __init__(self, goal_set, slot_set, disease_symptom): 10 | self.goal_set = pickle.load(open(goal_set, "rb")) 11 | self.slot_set = pickle.load(open(slot_set, "rb")) 12 | self.disease_symptom = pickle.load(open(disease_symptom, "rb")) 13 | self.symptom_distribution = {} 14 | self.slot_set.pop("disease") 15 | for symptom in self.slot_set.keys(): 16 | if symptom != "disease" : 17 | self.symptom_distribution[symptom] = {} 18 | self.symptom_distribution[symptom]["total"] = 0.0 19 | self.symptom_distribution[symptom]["total_explicit"] = 0.0 20 | self.symptom_distribution[symptom]["total_implicit"] = 0.0 21 | for disease in self.disease_symptom.keys(): 22 | print(disease) 23 | self.symptom_distribution[symptom][disease] = {} 24 | self.symptom_distribution[symptom][disease]["total"] = 0.0 25 | self.symptom_distribution[symptom][disease]["implicit"] = 0.0 26 | self.symptom_distribution[symptom][disease]["explicit"] = 0.0 27 | 28 | def calculate(self): 29 | """ 30 | 统计每一个symptom在每种疾病下出现的个数,分为总的个数,在explicit的个数,在implicit的个数。 31 | :return: 32 | """ 33 | key_list = self.goal_set.keys() 34 | for key in key_list: 35 | for goal in self.goal_set[key]: 36 | disease = goal["disease_tag"] 37 | if goal["consult_id"] == "10000894": 38 | print(goal) 39 | exit(0) 40 | 41 | for symptom in goal["goal"]["explicit_inform_slots"].keys(): 42 | self.symptom_distribution[symptom]["total"] += 1 43 | self.symptom_distribution[symptom]["total_explicit"] += 1 44 | self.symptom_distribution[symptom][disease]["total"] += 1 45 | self.symptom_distribution[symptom][disease]["explicit"] += 1 46 | for symptom in goal["goal"]["implicit_inform_slots"].keys(): 47 | self.symptom_distribution[symptom]["total"] += 1 48 | self.symptom_distribution[symptom]["total_implicit"] += 1 49 | self.symptom_distribution[symptom][disease]["total"] += 1 50 | self.symptom_distribution[symptom][disease]["implicit"] += 1 51 | 52 | def write(self,file_name): 53 | 54 | pickle.dump(file=open(file_name, "wb"), obj=self.symptom_distribution) 55 | 56 | def to_dataframe(self): 57 | 58 | explicit_data = pd.DataFrame(index=list(self.slot_set.keys()), columns=list(self.disease_symptom.keys())) 59 | implicit_data = 
60 |         total_data = pd.DataFrame(index=list(self.slot_set.keys()), columns=list(self.disease_symptom.keys()))
61 |         for symptom, value in self.symptom_distribution.items():
62 |             for disease in self.disease_symptom.keys():
63 |                 if value["total_explicit"] > 0.0:
64 |                     explicit_data.loc[symptom, disease] = value[disease]["explicit"] / value["total_explicit"]
65 |                 else:
66 |                     explicit_data.loc[symptom, disease] = 0.0
67 |                 if value["total_implicit"] > 0.0:
68 |                     implicit_data.loc[symptom, disease] = value[disease]["implicit"] / value["total_implicit"]
69 |                 else:
70 |                     implicit_data.loc[symptom, disease] = 0.0
71 |                 if value["total"] > 0.0:
72 |                     total_data.loc[symptom, disease] = value[disease]["total"] / value["total"]
73 |                 else:
74 |                     total_data.loc[symptom, disease] = 0.0
75 | 
76 |         explicit_data.to_excel("./../../../resources/symptom_distribution/explicit.xlsx", sheet_name="explicit_distribution" )
77 |         implicit_data.to_excel("./../../../resources/symptom_distribution/implicit.xlsx", sheet_name="implicit_distribution" )
78 |         total_data.to_excel("./../../../resources/symptom_distribution/total.xlsx", sheet_name="total_distribution" )
79 | 
80 | 
81 | 
82 | 
83 | 
84 | if __name__ == "__main__":
85 |     slot_set = "./../data/10_diseases/slot_set.p"
86 |     goal_set = "./../data/10_diseases/goal_set_2.p"
87 |     disease_symptom = "./../data/10_diseases/disease_symptom.p"
88 |     distributor = SlotDistributor(goal_set,slot_set,disease_symptom)
89 |     distributor.calculate()
90 |     distributor.write("./../../../resources/symptom_distribution/symptom_distribution.p")
91 |     distributor.to_dataframe()
--------------------------------------------------------------------------------
/src/dialogue_system/agent/agent_with_goal_joint.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | The agent maintains two ranked lists of candidate diseases and symptoms; both lists are updated every turn based on
4 | the information the agent has collected, and they influence each other through disease-symptom pairs.
5 | The agent requests the top-ranked symptom as its action, asking whether the user has that symptom. The ranking
6 | model is adjusted if the user answers no several turns in a row.
7 | """
8 | 
9 | import random
10 | import sys, os
11 | import copy
12 | import json
13 | import numpy as np
14 | sys.path.append(os.getcwd().replace("src/dialogue_system/agent",""))
15 | from src.dialogue_system.agent.agent_dqn import AgentDQN
16 | from src.dialogue_system.policy_learning.dqn_with_goal_joint import DQNWithGoalJoint
17 | from src.dialogue_system.agent.utils import state_to_representation_last
18 | from src.dialogue_system import dialogue_configuration
19 | 
20 | 
21 | class AgentWithGoalJoint(AgentDQN):
22 |     def __init__(self, action_set, slot_set, disease_symptom, parameter):
23 |         super(AgentWithGoalJoint, self).__init__(action_set=action_set, slot_set=slot_set,disease_symptom=disease_symptom, parameter=parameter)
24 |         input_size = parameter.get("input_size_dqn")
25 |         hidden_size = parameter.get("hidden_size_dqn", 100)
26 |         output_size = len(self.action_space)
27 |         del self.dqn
28 | 
29 |         # symptom distribution by diseases.
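        # Build one count vector per disease over all symptom slots, then normalize
        # each entry by the total count of that symptom across all diseases, so that
        # disease_to_symptom_dist[d][s] approximates P(disease = d | symptom = s).
        # E.g., if a symptom occurs 30 times under disease A and 70 times under
        # disease B, the normalized entries are 0.3 and 0.7. These vectors are passed
        # to DQNWithGoalJoint below as the initial goal-embedding values.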
30 |         temp_slot_set = copy.deepcopy(slot_set)
31 |         temp_slot_set.pop('disease')
32 |         self.disease_to_symptom_dist = {}
33 |         self.id2disease = {}
34 |         total_count = np.zeros(len(temp_slot_set))
35 |         for disease, v in disease_symptom.items():
36 |             dist = np.zeros(len(temp_slot_set))
37 |             self.id2disease[v['index']] = disease
38 |             for symptom, count in v['symptom'].items():
39 |                 dist[temp_slot_set[symptom]] = count
40 |                 total_count[temp_slot_set[symptom]] += count
41 |             self.disease_to_symptom_dist[disease] = dist
42 | 
43 |         goal_embed_value = [0] * len(disease_symptom)
44 |         for disease in self.disease_to_symptom_dist.keys():
45 |             # Normalize exactly once by the per-symptom totals so that each entry
46 |             # approximates P(disease | symptom).
47 |             self.disease_to_symptom_dist[disease] = self.disease_to_symptom_dist[disease] / total_count
48 |             goal_embed_value[disease_symptom[disease]['index']] = list(self.disease_to_symptom_dist[disease])
49 |         self.dqn = DQNWithGoalJoint(input_size=input_size, hidden_size=hidden_size, output_size=output_size, goal_embedding_value=goal_embed_value, parameter=parameter)
50 | 
51 |     def record_training_sample(self, state, agent_action, reward, next_state, episode_over, **kwargs):
52 |         shaping = self.reward_shaping(state, next_state)
53 |         alpha = self.parameter.get("weight_for_reward_shaping")
54 |         # if True:
55 |         #     print('shaping', shaping)
56 | 
57 |         # Reward shaping only when non-terminal state.
58 |         if episode_over is True:
59 |             pass
60 |         else:
61 |             reward = reward + alpha * shaping
62 |         state_rep = state_to_representation_last(state=state, action_set=self.action_set, slot_set=self.slot_set, disease_symptom=self.disease_symptom, max_turn=self.parameter["max_turn"])
63 |         next_state_rep = state_to_representation_last(state=next_state, action_set=self.action_set, slot_set=self.slot_set, disease_symptom=self.disease_symptom, max_turn=self.parameter["max_turn"])
64 |         self.experience_replay_pool.append((state_rep, agent_action, reward, next_state_rep, episode_over))
65 | 
66 |     def reward_shaping(self, state, next_state):
67 |         def delete_item_from_dict(item, value):
68 |             new_item = {}
69 |             for k, v in item.items():
70 |                 if v != value: new_item[k] = v
71 |             return new_item
72 | 
73 |         # slot number in state.
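        # Potential-based shaping: phi(s) is the number of informed slots in s after
        # dropping "I don't know" answers, and the value returned below is
        # F(s, s') = gamma * phi(s') - phi(s), rewarding turns that elicit new
        # symptom information. E.g., with gamma = 0.95, phi(s) = 3 and phi(s') = 5,
        # the shaping term is 0.95 * 5 - 3 = 1.75.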
74 |         slot_dict = copy.deepcopy(state["current_slots"]["inform_slots"])
75 |         slot_dict.update(state["current_slots"]["explicit_inform_slots"])
76 |         slot_dict.update(state["current_slots"]["implicit_inform_slots"])
77 |         slot_dict.update(state["current_slots"]["proposed_slots"])
78 |         slot_dict.update(state["current_slots"]["agent_request_slots"])
79 |         slot_dict = delete_item_from_dict(slot_dict, dialogue_configuration.I_DO_NOT_KNOW)
80 | 
81 |         next_slot_dict = copy.deepcopy(next_state["current_slots"]["inform_slots"])
82 |         next_slot_dict.update(next_state["current_slots"]["explicit_inform_slots"])
83 |         next_slot_dict.update(next_state["current_slots"]["implicit_inform_slots"])
84 |         next_slot_dict.update(next_state["current_slots"]["proposed_slots"])
85 |         next_slot_dict.update(next_state["current_slots"]["agent_request_slots"])
86 |         next_slot_dict = delete_item_from_dict(next_slot_dict, dialogue_configuration.I_DO_NOT_KNOW)
87 |         gamma = self.parameter.get("gamma")
88 |         return gamma * len(next_slot_dict) - len(slot_dict)
--------------------------------------------------------------------------------
/src/dialogue_system/agent/agent_rule.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Rule-based agent.
4 | """
5 | 
6 | import copy
7 | import random
8 | import sys, os
9 | sys.path.append(os.getcwd().replace("src/dialogue_system/agent",""))
10 | 
11 | from src.dialogue_system.agent import Agent
12 | from src.dialogue_system import dialogue_configuration
13 | 
14 | 
15 | class AgentRule(Agent):
16 |     """
17 |     Rule-based agent.
18 |     """
19 |     def __init__(self,action_set, slot_set, disease_symptom, parameter):
20 |         super(AgentRule,self).__init__(action_set=action_set,slot_set=slot_set,disease_symptom=disease_symptom,parameter=parameter)
21 | 
22 |     def next(self, state, turn, greedy_strategy, **kwargs):
23 |         candidate_disease_symptoms = self._get_candidate_disease_symptoms(state=state)
24 |         disease = candidate_disease_symptoms["disease"]
25 |         candidate_symptoms = candidate_disease_symptoms["candidate_symptoms"]
26 | 
27 |         self.agent_action["request_slots"].clear()
28 |         self.agent_action["explicit_inform_slots"].clear()
29 |         self.agent_action["implicit_inform_slots"].clear()
30 |         self.agent_action["inform_slots"].clear()
31 |         self.agent_action["turn"] = turn
32 | 
33 |         if len(candidate_symptoms) == 0:
34 |             self.agent_action["action"] = "inform"
35 |             self.agent_action["inform_slots"]["disease"] = disease
36 |         else:
37 |             symptom = random.choice(candidate_symptoms)
38 |             self.agent_action["action"] = "request"
39 |             self.agent_action["request_slots"].clear()
40 |             self.agent_action["request_slots"][symptom] = dialogue_configuration.VALUE_UNKNOWN
41 |         agent_action = copy.deepcopy(self.agent_action)
42 |         agent_action.pop("turn")
43 |         agent_action.pop("speaker")
44 |         agent_index = self.action_space.index(agent_action)
45 |         return self.agent_action, agent_index
46 | 
47 |     def _get_candidate_disease_symptoms(self, state):
48 |         """
49 |         Compare state["current_slots"] with disease_symptom to identify which disease the user may have.
50 |         :param state: a dict, the current dialogue state obtained from the dialogue state tracker.
51 |         :return: a dict with two keys: "disease" (the most probable disease) and "candidate_symptoms" (symptoms of that disease not informed yet).
52 | """ 53 | inform_slots = state["current_slots"]["inform_slots"] 54 | inform_slots.update(state["current_slots"]["explicit_inform_slots"]) 55 | inform_slots.update(state["current_slots"]["implicit_inform_slots"]) 56 | wrong_diseases = state["current_slots"]["wrong_diseases"] 57 | 58 | # Calculate number of informed symptom for each disease. 59 | disease_match_number = {} 60 | for disease in self.disease_symptom.keys(): 61 | disease_match_number[disease] = {} 62 | disease_match_number[disease]["yes"] = 0 63 | disease_match_number[disease]["not_sure"] = 0 64 | disease_match_number[disease]["deny"] = 0 65 | 66 | for slot in inform_slots.keys(): 67 | for disease in disease_match_number.keys(): 68 | if inform_slots[slot] in self.disease_symptom[disease]["symptom"] and inform_slots[slot] == True: 69 | disease_match_number[disease]["yes"] += 1 70 | elif inform_slots[slot] in self.disease_symptom[disease]["symptom"] and inform_slots[slot] == dialogue_configuration.I_DO_NOT_KNOW: 71 | disease_match_number[disease]["not_sure"] += 1 72 | elif inform_slots[slot] in self.disease_symptom[disease]["symptom"] and inform_slots[slot] == dialogue_configuration.I_DENY: 73 | disease_match_number[disease]["deny"] += 1 74 | 75 | # Get the ratio of informed symptom number to the number of symptoms of each disease. 76 | disease_score = {} 77 | for disease in disease_match_number.keys(): 78 | yes_score = float(disease_match_number[disease]["yes"]) / len(self.disease_symptom[disease]["symptom"]) 79 | not_sure_score = float(disease_match_number[disease]["not_sure"]) / len(self.disease_symptom[disease]["symptom"]) 80 | deny_score = float(disease_match_number[disease]["deny"]) / len(self.disease_symptom[disease]["symptom"]) 81 | disease_score[disease] = yes_score - 0.5*not_sure_score - deny_score 82 | 83 | # Get the most probable disease that has not been wrongly informed 84 | sorted_diseases = sorted(disease_score.items(), key=lambda d: d[1], reverse=True) 85 | for disease in sorted_diseases: 86 | if disease[0] not in wrong_diseases: 87 | match_disease = disease[0] 88 | break 89 | # match_disease = max(disease_score.items(), key=lambda x: x[0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN])[0] # Get the most probable disease that the user have. 90 | # Candidate symptom list of symptoms that belong to the most probable disease but have't been informed yet. 
91 |         candidate_symptoms = []
92 |         for symptom in self.disease_symptom[match_disease]["symptom"]:
93 |             if symptom not in inform_slots.keys():
94 |                 candidate_symptoms.append(symptom)
95 |         return {"disease":match_disease,"candidate_symptoms":candidate_symptoms}
96 | 
97 |     def train_mode(self):
98 |         pass
99 | 
100 |     def eval_mode(self):
101 |         pass
--------------------------------------------------------------------------------
/src/dialogue_system/utils/plot_slot_dist.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import matplotlib.pyplot as plt
4 | import pickle
5 | 
6 | # name_list = ['Monday', 'Tuesday', 'Friday', 'Sunday']
7 | # num_list = [1.5, 0.6, 7.8, 6]
8 | # num_list1 = [1, 2, 3, 1]
9 | # x = list(range(len(num_list)))
10 | # total_width, n = 0.8, 3
11 | # width = total_width / n
12 | # 
13 | # plt.bar(x, num_list, width=width, label='boy', fc='y')
14 | # for i in range(len(x)):
15 | #     x[i] = x[i] + width
16 | # plt.bar(x, num_list1, width=width, label='girl', tick_label=name_list, fc='r')
17 | # 
18 | # x = [i + width for i in x]
19 | # plt.bar(x, num_list, width=width, label='men', tick_label=name_list, fc='b')
20 | # 
21 | # plt.legend()
22 | # plt.show()
23 | # 
24 | # 
25 | # 
26 | # name_list = ['Monday','Tuesday','Friday','Sunday'] * 20
27 | # num_list = [1.5,0.6,7.8,6] * 20
28 | # num_list1 = [1,2,3,1] * 20
29 | # plt.bar(range(len(num_list)), num_list, label='boy',fc = 'y')
30 | # plt.bar(range(len(num_list)), num_list1, bottom=num_list, label='girl',tick_label = name_list,fc = 'r')
31 | # bottom = [num_list[i] + num_list1[i] for i in range(len(num_list))]
32 | # plt.bar(range(len(num_list)), num_list1, bottom=bottom, label='girl',tick_label = name_list,fc = 'b')
33 | # plt.legend()
34 | # plt.show()
35 | 
36 | 
37 | 
38 | class DistributionPloter(object):
39 |     def __init__(self, goal_set_file):
40 |         self.goal_set = pickle.load(open(goal_set_file, 'rb'))
41 |         self.symptom2id, self.id2disease, self.symptom_dist_by_disease = self.__distribution__()
42 |         self.disease_to_english = {
43 |             '小儿腹泻': 'Infantile diarrhea',
44 |             '小儿支气管炎': 'Children’s bronchitis',
45 |             '小儿消化不良': 'Children functional dyspepsia',
46 |             '上呼吸道感染': 'Upper respiratory infection'
47 |         }
48 | 
49 |     def __distribution__(self):
50 |         symptom2id = dict()
51 |         id2disease = dict()
52 |         disease2id = dict()
53 | 
54 |         for goal in self.goal_set['train'] + self.goal_set['test'] + self.goal_set['validate']:
55 |             id = len(disease2id)
56 |             disease2id.setdefault(goal['disease_tag'], id)
57 |             for symptom in goal['goal']['explicit_inform_slots'].keys():
58 |                 id = len(symptom2id)
59 |                 symptom2id.setdefault(symptom, id)
60 |             for symptom in goal['goal']['implicit_inform_slots'].keys():
61 |                 id = len(symptom2id)
62 |                 symptom2id.setdefault(symptom, id)
63 | 
64 |         symptom_dist_by_disease = {}
65 |         for goal in self.goal_set['train'] + self.goal_set['test'] + self.goal_set['validate']:
66 |             symptom_dist_by_disease.setdefault(goal['disease_tag'], [0] * len(symptom2id))
67 |             for symptom in goal['goal']['explicit_inform_slots'].keys():
68 |                 symptom_dist_by_disease[goal['disease_tag']][symptom2id[symptom]] += 1
69 |             for symptom in goal['goal']['implicit_inform_slots'].keys():
70 |                 symptom_dist_by_disease[goal['disease_tag']][symptom2id[symptom]] += 1
71 | 
72 |         for key, v in disease2id.items():
73 |             id2disease[v] = key
74 |             print(key, v)
75 |         return symptom2id, id2disease, symptom_dist_by_disease
76 | 
77 |     def plot(self):
78 |         colors = ['#2f79c0', '#278b18', '#ff5186', '#8660a4', '#D49E0F', '#a8d40f']
79 |         print(self.symptom2id)
80 |         bottom = [0] * len(self.symptom2id)
81 | 
82 |         # Stack one bar segment per disease; `bottom` accumulates the heights of
83 |         # the segments already drawn so that every disease (including the second
84 |         # one) starts on top of the previous segments instead of overlapping them.
85 |         for index in range(0, len(self.id2disease)):
86 |             disease = self.id2disease[index]
87 |             symptom_dist = self.symptom_dist_by_disease[disease]
88 |             print(disease, len(symptom_dist), symptom_dist)
89 |             plt.bar(range(len(symptom_dist)), symptom_dist, bottom=bottom, label=self.disease_to_english[disease], fc=colors[index])
90 |             # plt.bar(range(len(symptom_dist)), symptom_dist, bottom=bottom, label=disease, tick_label=self.symptom2id.keys(), fc=colors[index])
91 |             bottom = [bottom[i] + symptom_dist[i] for i in range(len(self.symptom2id))]
92 |         plt.legend()
93 |         plt.savefig('symptom_dist.pdf')
94 |         plt.show()
95 | 
96 | 
97 | if __name__ == '__main__':
98 |     goal_set_file = './../../data/real_world/goal_set.p'
99 |     ploter = DistributionPloter(goal_set_file)
100 |     ploter.plot()
--------------------------------------------------------------------------------
/src/dialogue_system/utils/goal_action_slots_dumper.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | Convert the txt action file and the json action/symptom files into lists that contain only actions and symptoms, and persist them to disk
4 | so that later runs can load the persisted files directly. Each symptom is treated as a slot here.
5 | """
6 | import pickle
7 | import json
8 | import random
9 | 
10 | 
11 | class ActionDumper(object):
12 |     """
13 |     Process the action file: save the actions as a list and persist it.
14 |     """
15 |     def __init__(self, action_set_file):
16 |         self.file_name = action_set_file
17 | 
18 |     def dump(self, dump_file_name):
19 |         data_file = open(self.file_name, "r")
20 |         action_set = []
21 |         for line in data_file:
22 |             action_set.append(line.replace("\n",""))
23 |         data_file.close()
24 |         action_set_dict = {}
25 |         for index in range(0, len(action_set), 1):
26 |             action_set_dict[action_set[index]] = index
27 |         pickle.dump(file=open(dump_file_name,"wb"), obj=action_set_dict)
28 | 
29 | 
30 | class SlotDumper(object):
31 |     """
32 |     Process the disease_symptom file: treat every symptom in it as a slot and persist the result.
33 |     """
34 |     def __init__(self, slots_file, hand_crafted_symptom=True):
35 |         self.file_name = slots_file
36 |         self.hand_crafted_symptom = hand_crafted_symptom
37 | 
38 |     def dump(self, slot_dump_file_name, disease_dump_file_name):
39 |         self._load_slot()
40 |         self.slot_set.add("disease")
41 |         # self.slot_set.add("taskcomplete")
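        # Assign each slot a stable integer index; "disease" is added above as a
        # special slot so that the final diagnosis can also be handled as a slot.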
42 | 
43 |         slot_set = list(self.slot_set)
44 |         slot_set_dict = {}
45 |         for index in range(0, len(slot_set), 1):
46 |             slot_set_dict[slot_set[index]] = index
47 |         pickle.dump(file=open(slot_dump_file_name,"wb"), obj=slot_set_dict)
48 |         pickle.dump(file=open(disease_dump_file_name, "wb"), obj=self.disease_symptom)
49 | 
50 |     def _load_slot(self):
51 |         self.slot_set = set()
52 |         self.disease_symptom = {}
53 |         data_file = open(file=self.file_name, mode="r",encoding="utf-8")
54 |         if self.hand_crafted_symptom == True:
55 |             index = 0
56 |             for line in data_file:
57 |                 line = json.loads(line)
58 |                 self.disease_symptom[line["name"]] = {}
59 |                 self.disease_symptom[line["name"]]["index"] = index
60 |                 self.disease_symptom[line["name"]]["symptom"] = list(line["symptom"].keys())
61 |                 for key in line["symptom"].keys():
62 |                     self.slot_set.add(key)
63 |                 index += 1
64 |         else:
65 |             index = 0
66 |             for line in data_file:
67 |                 line = json.loads(line)
68 |                 self.disease_symptom[line["name"]] = {}
69 |                 self.disease_symptom[line["name"]]["index"] = index
70 |                 self.disease_symptom[line["name"]]["symptom"] = line["symptom"]
71 |                 for symptom in line["symptom"]:
72 |                     self.slot_set.add(symptom)
73 |                 index += 1
74 | 
75 |         data_file.close()
76 | 
77 | 
78 | class GoalDumper(object):
79 |     def __init__(self, goal_file):
80 |         self.file_name = goal_file
81 |         self.slot_set = set()
82 | 
83 |     def dump(self, dump_file_name, train=0.8, test=0.2, validate=0.0):
84 |         assert (train*100+test*100+validate*100==100), "train + test + validate not equals to 1.0."
85 |         self.goal_set = []
86 |         data_file = open(file=self.file_name, mode="r")
87 |         for line in data_file:
88 |             line = json.loads(line)
89 |             self.goal_set.append(line)
90 |         data_file.close()
91 |         goal_number = len(self.goal_set)
92 |         data_set = {
93 |             "train":[],
94 |             "test":[],
95 |             "validate":[]
96 |         }
97 | 
98 |         for goal in self.goal_set:
99 |             random_float = random.random()
100 |             if random_float <= train:
101 |                 data_set["train"].append(goal)
102 |             elif train < random_float and random_float <= train+test:
103 |                 data_set["test"].append(goal)
104 |             else:
105 |                 data_set["validate"].append(goal)
106 | 
107 |             for slot, value in goal["goal"]["explicit_inform_slots"].items():
108 |                 if value == False: print(goal)
109 |                 break
110 |             for slot, value in goal["goal"]["implicit_inform_slots"].items():
111 |                 if value == False: print(goal)
112 |                 break
113 | 
114 |             # for slot.
115 |             for symptom in goal["goal"]["explicit_inform_slots"].keys(): self.slot_set.add(symptom)
116 |             for symptom in goal["goal"]["implicit_inform_slots"].keys(): self.slot_set.add(symptom)
117 | 
118 |         pickle.dump(file=open(dump_file_name,"wb"), obj=data_set)
119 | 
120 |     def dump_slot(self,slot_file):
121 |         self.slot_set.add("disease")
122 |         slot_set_dict = {}
123 |         slot_set = list(self.slot_set)
124 |         for index in range(0, len(slot_set), 1):
125 |             slot_set_dict[slot_set[index]] = index
126 |         pickle.dump(file=open(slot_file,"wb"),obj=slot_set_dict)
127 | 
128 | 
129 | 
130 | 
131 | 
132 | if __name__ == "__main__":
133 |     # Action
134 |     # action_file = "./../../../resources/action_set.txt"
135 |     # action_dump_file = "./../data/action_set.p"
136 |     # 
137 |     # action_dumper = ActionDumper(action_set_file=action_file)
138 |     # action_dumper.dump(dump_file_name=action_dump_file)
139 | 
140 |     # Slots.
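    # Usage sketch: build slot_set.p (symptom -> index, plus the special "disease"
    # slot) and disease_symptom.p (disease -> {"index", "symptom"}) from the
    # aligned disease-symptom json below.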
141 |     slots_file = "./../../../resources/top_disease_symptom_aligned.json"
142 |     slots_dump_file = "./../data/slot_set.p"
143 |     disease_dump_file = "./../data/disease_symptom.p"
144 |     slots_dumper = SlotDumper(slots_file=slots_file)
145 |     slots_dumper.dump(slot_dump_file_name=slots_dump_file,disease_dump_file_name=disease_dump_file)
146 | 
147 |     # Goal
148 |     goal_file = "./../../../resources/goal_slot_value_0.2.json"
149 |     goal_dump_file = "./../data/goal_set_2.p"
150 |     slots_dump_file = "./../data/slot_set_2.p"
151 |     goal_dumper = GoalDumper(goal_file=goal_file)
152 |     goal_dumper.dump(dump_file_name=goal_dump_file)
--------------------------------------------------------------------------------
/src/dialogue_system/utils/plot_single_dist.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import matplotlib.pyplot as plt
4 | import os
5 | import pickle
6 | import argparse
7 | import numpy as np
8 | 
9 | # name_list = ['Monday', 'Tuesday', 'Friday', 'Sunday']
10 | # num_list = [.5, 0.6, 7.8, 6]
11 | # num_list1 = [2, 3, 1,4]
12 | # x = list(range(len(num_list)))
13 | # total_width, n = 0.8, 3
14 | # width = total_width / n
15 | # 
16 | # plt.bar(x, num_list, width=width, label='boy', fc='y')
17 | # for i in range(len(x)):
18 | #     x[i] = x[i] + width
19 | # plt.bar(x, num_list1, width=width, label='girl', tick_label=name_list, fc='r')
20 | # 
21 | # x = [i + width for i in x]
22 | # plt.bar(x, num_list, width=width, label='men', tick_label=name_list, fc='b')
23 | # 
24 | # plt.legend()
25 | # plt.show()
26 | 
27 | 
28 | class PlotDistribution(object):
29 |     def __init__(self, params):
30 |         self.params = params
31 | 
32 |     def get_visitation_mean(self, path, key_word_list, no_key_word_list):
33 |         file_list = PlotDistribution.get_dir_list(path=path, key_word_list=key_word_list, no_key_word_list=no_key_word_list)
34 |         all_run_visitation = []
35 |         for file_name in file_list:
36 |             visitation_count = pickle.load(open(os.path.join(path, file_name), 'rb'))
37 |             visitation_list = [visitation_count[key] for key in sorted(visitation_count.keys())]
38 |             all_run_visitation.append(visitation_list)
39 |         count = np.array(all_run_visitation)
40 |         return count.mean(axis=0)
41 | 
42 |     def plot(self):
43 |         colors = ['#2f79c0', '#278b18', '#ff5186', '#8660a4', '#D49E0F', '#a8d40f', '#b4546f', '#6495ED', '#778899', '#48D1CC', '#00FA9A','#F4A460', '#8FBC8F','#C0C0C0']
44 |         no_key_word_list = ['.DS_Store','.pdf','RID9']
45 |         key_word_list = ['AgentDQN', '4599.p']
46 |         mean_point = self.get_visitation_mean(path=self.params['result_path'],
47 |                                               key_word_list=key_word_list,
48 |                                               no_key_word_list=no_key_word_list)
49 |         mean_point = mean_point[0:len(mean_point) - 4]
50 |         name_list = [i for i in range(len(mean_point))]
51 |         plt.bar(range(len(mean_point)), mean_point, label='Flat-DQN', fc=colors[0])
52 |         plt.legend()
53 |         plt.show()
54 | 
55 |         no_key_word_list = ['.DS_Store','.pdf','RID9']
56 |         key_word_list = ['AgentWithGoal2', '4599.p']
57 |         mean_point = self.get_visitation_mean(path=self.params['result_path'],
58 |                                               key_word_list=key_word_list,
59 |                                               no_key_word_list=no_key_word_list)
60 |         mean_point = mean_point[0:len(mean_point) - 4]
61 |         name_list = [i for i in range(len(mean_point))]
62 |         plt.bar(range(len(mean_point)), mean_point, label='HRL, ex', fc=colors[0])
63 |         plt.legend()
64 |         plt.show()
65 | 
66 |         no_key_word_list = ['.DS_Store','.pdf']
67 |         key_word_list = ['AgentWithGoal2', '4599.p','RID9']
68 |         mean_point = self.get_visitation_mean(path=self.params['result_path'],
69 |                                               key_word_list=key_word_list,
70 |                                               no_key_word_list=no_key_word_list)
71 |         mean_point = mean_point[0:len(mean_point) - 4]
72 |         name_list = [i for i in range(len(mean_point))]
73 |         plt.bar(range(len(mean_point)), mean_point, label='HRL, ex&im', fc=colors[0])
74 |         plt.legend()
75 |         plt.show()
76 | 
77 |     @staticmethod
78 |     def get_dir_list(path, key_word_list=None, no_key_word_list=None):
79 |         file_name_list = os.listdir(path)  # get all file names in the directory that contains the raw json files.
80 |         if key_word_list == None and no_key_word_list == None:
81 |             temp_file_list = file_name_list
82 |         elif key_word_list != None and no_key_word_list == None:
83 |             temp_file_list = []
84 |             for file_name in file_name_list:
85 |                 have_key_words = True
86 |                 for key_word in key_word_list:
87 |                     if key_word not in file_name:
88 |                         have_key_words = False
89 |                         break
90 |                     else:
91 |                         pass
92 |                 if have_key_words == True:
93 |                     temp_file_list.append(file_name)
94 |         elif key_word_list == None and no_key_word_list != None:
95 |             temp_file_list = []
96 |             for file_name in file_name_list:
97 |                 have_no_key_word = False
98 |                 for no_key_word in no_key_word_list:
99 |                     if no_key_word in file_name:
100 |                         have_no_key_word = True
101 |                         break
102 |                 if have_no_key_word == False:
103 |                     temp_file_list.append(file_name)
104 |         elif key_word_list != None and no_key_word_list != None:
105 |             temp_file_list = []
106 |             for file_name in file_name_list:
107 |                 have_key_words = True
108 |                 for key_word in key_word_list:
109 |                     if key_word not in file_name:
110 |                         have_key_words = False
111 |                         break
112 |                     else:
113 |                         pass
114 |                 have_no_key_word = False
115 |                 for no_key_word in no_key_word_list:
116 |                     if no_key_word in file_name:
117 |                         have_no_key_word = True
118 |                         break
119 |                     else:
120 |                         pass
121 |                 if have_key_words == True and have_no_key_word == False:
122 |                     temp_file_list.append(file_name)
123 |         print(key_word_list, len(temp_file_list))
124 |         # time.sleep(2)
125 |         return temp_file_list
126 | 
127 | if __name__ == "__main__":
128 |     parser = argparse.ArgumentParser()
129 | 
130 |     parser.add_argument('--result_path', dest='result_path', type=str, default='/Users/qianlong/Desktop/visit2/', help='the directory of the results.')
131 | 
132 |     parser.add_argument('--metric', dest='metric', type=str, default='recall', help='the metric to show')
133 | 
134 |     args = parser.parse_args()
135 |     params = vars(args)
136 |     drawer = PlotDistribution(params)
137 |     drawer.plot()
138 | 
--------------------------------------------------------------------------------
/src/dialogue_system/memory/prioritized.py:
--------------------------------------------------------------------------------
1 | # Modified by Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | import random
5 | 
6 | import numpy as np
7 | 
8 | from src.dialogue_system.memory import Replay, util
9 | 
10 | 
11 | 
12 | class SumTree:
13 |     '''
14 |     Helper class for PrioritizedReplay
15 | 
16 |     This implementation is, with minor adaptations, Jaromír Janisch's. The license is reproduced below.
17 |     For more information see his excellent blog series "Let's make a DQN" https://jaromiru.com/2016/09/27/lets-make-a-dqn-theory/
18 | 
19 |     MIT License
20 | 
21 |     Copyright (c) 2018 Jaromír Janisch
22 | 
23 |     Permission is hereby granted, free of charge, to any person obtaining a copy
24 |     of this software and associated documentation files (the "Software"), to deal
25 |     in the Software without restriction, including without limitation the rights
26 |     to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
27 |     copies of the Software, and to permit persons to whom the Software is
28 |     furnished to do so, subject to the following conditions:
29 |     '''
30 |     write = 0
31 | 
32 |     def __init__(self, capacity):
33 |         self.capacity = capacity
34 |         self.tree = np.zeros(2 * capacity - 1)  # Stores the priorities and sums of priorities
35 |         self.indices = np.zeros(capacity)  # Stores the indices of the experiences
36 | 
37 |     def _propagate(self, idx, change):
38 |         parent = (idx - 1) // 2
39 | 
40 |         self.tree[parent] += change
41 | 
42 |         if parent != 0:
43 |             self._propagate(parent, change)
44 | 
45 |     def _retrieve(self, idx, s):
46 |         left = 2 * idx + 1
47 |         right = left + 1
48 | 
49 |         if left >= len(self.tree):
50 |             return idx
51 | 
52 |         if s <= self.tree[left]:
53 |             return self._retrieve(left, s)
54 |         else:
55 |             return self._retrieve(right, s - self.tree[left])
56 | 
57 |     def total(self):
58 |         return self.tree[0]
59 | 
60 |     def add(self, p, index):
61 |         idx = self.write + self.capacity - 1
62 | 
63 |         self.indices[self.write] = index
64 |         self.update(idx, p)
65 | 
66 |         self.write += 1
67 |         if self.write >= self.capacity:
68 |             self.write = 0
69 | 
70 |     def update(self, idx, p):
71 |         change = p - self.tree[idx]
72 | 
73 |         self.tree[idx] = p
74 |         self._propagate(idx, change)
75 | 
76 |     def get(self, s):
77 |         assert s <= self.total()
78 |         idx = self._retrieve(0, s)
79 |         indexIdx = idx - self.capacity + 1
80 | 
81 |         return (idx, self.tree[idx], self.indices[indexIdx])
82 | 
83 |     def print_tree(self):
84 |         for i in range(len(self.indices)):
85 |             j = i + self.capacity - 1
86 |             print(f'Idx: {i}, Data idx: {self.indices[i]}, Prio: {self.tree[j]}')
87 | 
88 | 
89 | class PrioritizedReplay(Replay):
90 |     '''
91 |     Prioritized Experience Replay
92 | 
93 |     Implementation follows the approach in the paper "Prioritized Experience Replay" (Schaul et al. 2015) https://arxiv.org/pdf/1511.05952.pdf and is Jaromír Janisch's with minor adaptations.
94 |     See memory_util.py for the license and link to Jaromír's excellent blog
95 | 
96 |     Stores agent experiences and samples from them for agent training according to each experience's priority
97 | 
98 |     The memory has the same behaviour and storage structure as Replay memory with the addition of a SumTree to store and sample the priorities.
99 | 
100 |     e.g. memory_spec
101 |     "memory": {
102 |         "name": "PrioritizedReplay",
103 |         "alpha": 1,
104 |         "epsilon": 0,
105 |         "batch_size": 32,
106 |         "max_size": 10000,
107 |         "use_cer": true
108 |     }
109 |     '''
110 | 
111 |     def __init__(self, parameter):
112 |         super().__init__(parameter)
113 |         self.epsilon = 0.001
114 |         self.alpha = 0.9
115 |         self.epsilon = np.full((1,), self.epsilon)
116 |         self.alpha = np.full((1,), self.alpha)
117 |         # adds a 'priorities' scalar to the data_keys and call reset again
118 |         self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones', 'priorities']
119 |         self.reset()
120 | 
121 |     def reset(self):
122 |         super().reset()
123 |         self.tree = SumTree(self.max_size)
124 | 
125 |     def add_experience(self, state, action, reward, next_state, done, error=100000):
126 |         '''
127 |         Implementation for update() to add experience to memory, expanding the memory size if necessary.
128 |         All experiences are added with a high priority to increase the likelihood that they are sampled at least once.
129 |         '''
130 |         super().add_experience(state, action, reward, next_state, done)
131 |         priority = self.get_priority(error)
132 |         self.priorities[self.head] = priority
133 |         self.tree.add(priority, self.head)
134 | 
135 |     def get_priority(self, error):
136 |         '''Takes in the error of one or more examples and returns the proportional priority'''
137 |         return np.power(error + self.epsilon, self.alpha).squeeze()
138 | 
139 |     def sample_idxs(self, batch_size):
140 |         '''Samples batch_size indices from memory in proportion to their priority.'''
141 |         batch_idxs = np.zeros(batch_size)
142 |         tree_idxs = np.zeros(batch_size, dtype=int)  # np.int was removed in newer NumPy; the builtin int works everywhere.
143 | 
144 |         for i in range(batch_size):
145 |             s = random.uniform(0, self.tree.total())
146 |             (tree_idx, p, idx) = self.tree.get(s)
147 |             batch_idxs[i] = idx
148 |             tree_idxs[i] = tree_idx
149 | 
150 |         batch_idxs = np.asarray(batch_idxs).astype(int)
151 |         self.tree_idxs = tree_idxs
152 |         if self.use_cer:  # add the latest sample
153 |             batch_idxs[-1] = self.head
154 |         return batch_idxs
155 | 
156 |     def update_priorities(self, errors):
157 |         '''
158 |         Updates the priorities from the most recent batch
159 |         Assumes the relevant batch indices are stored in self.batch_idxs
160 |         '''
161 |         priorities = self.get_priority(errors)
162 |         assert len(priorities) == self.batch_idxs.size
163 |         for idx, p in zip(self.batch_idxs, priorities):
164 |             self.priorities[idx] = p
165 |         for p, i in zip(priorities, self.tree_idxs):
166 |             self.tree.update(i, p)
167 | 
--------------------------------------------------------------------------------
/preprocess/statistics.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | Used for statistics of the user goals: the number of diseases, the amount of data for each disease,
4 | and the average number of (explicit and implicit) symptoms the user reports per dialogue.
5 | """
6 | 
7 | import json
8 | import pickle
9 | import csv
10 | import pandas as pd
11 | 
12 | 
13 | class StatisticsOfUserGoal(object):
14 |     def __init__(self, data_file):
15 |         self.file_name = data_file
16 |         goal_set = pickle.load(file=open(self.file_name, "rb"))
17 |         self.goal_set = goal_set["train"] + goal_set["test"] + goal_set["validate"]
18 |         self.information = {}
19 | 
20 | 
21 |     # """
22 |     # {
23 |     #     "consult_id":123,
24 |     #     "request_slots":{
25 |     #         "disease": "UNK"
26 |     #     },
27 |     #     "explicit_inform_slots":{
28 |     #         "咳嗽":true
29 |     #     },
30 |     #     "implicit_inform_slots":{
31 |     #     }
32 |     # }
33 |     # """
34 | 
35 |     def statistics(self):
36 | 
37 |         self.information['all'] = {}
38 |         self.information['all']["user_number"] = 0
39 |         self.information['all']["explicit_number"] = 0
40 |         self.information['all']["implicit_number"] = 0
41 |         self.information['all']["symptom_number"] = list()
42 | 
43 |         for goal in self.goal_set:
44 |             print(json.dumps(goal,indent=2))
45 |             disease = goal["disease_tag"]
46 |             if disease not in self.information.keys():
47 |                 self.information[disease] = {}
48 |                 self.information[disease]["user_number"] = 0
49 |                 self.information[disease]["explicit_number"] = 0
50 |                 self.information[disease]["implicit_number"] = 0
51 |                 self.information[disease]["symptom_number"] = list()
52 | 
53 |         for goal in self.goal_set:
54 |             disease = goal["disease_tag"]
55 |             explicit_inform_slots = goal["goal"]["explicit_inform_slots"]
56 |             implicit_inform_slots = goal["goal"]["implicit_inform_slots"]
57 |             if len(goal["goal"]["explicit_inform_slots"].keys()) >= 0 and \
58 |                 len(goal["goal"]["implicit_inform_slots"].keys()) >= 0:
59 | 
60 |                 self.information[disease]["user_number"] += 1
61 |                 self.information[disease]["explicit_number"] += len(explicit_inform_slots.keys())
62 |                 self.information[disease]["implicit_number"] += len(implicit_inform_slots.keys())
63 |                 self.information[disease]["symptom_number"] = self.information[disease]["symptom_number"] + list(goal["goal"]["explicit_inform_slots"].keys()) + list(goal["goal"]["implicit_inform_slots"].keys())
64 | 
65 |                 self.information['all']["user_number"] += 1
66 |                 self.information['all']["explicit_number"] += len(explicit_inform_slots.keys())
67 |                 self.information['all']["implicit_number"] += len(implicit_inform_slots.keys())
68 |                 self.information['all']["symptom_number"] = self.information['all']["symptom_number"] + list(goal["goal"]["explicit_inform_slots"].keys()) + list(goal["goal"]["implicit_inform_slots"].keys())
69 | 
70 |         disease_list = list(self.information.keys())
71 |         for disease in disease_list:
72 |             explicit_number = self.information[disease]["explicit_number"]
73 |             implicit_number = self.information[disease]["implicit_number"]
74 |             self.information[disease]["explicit_number"] = float(explicit_number) / self.information[disease]["user_number"]
75 |             self.information[disease]["implicit_number"] = float(implicit_number) / self.information[disease]["user_number"]
76 |             self.information[disease]["symptom_number"] = len(set(self.information[disease]["symptom_number"]))
77 |         print(json.dumps(self.information))
78 | 
79 |     def write_file(self, save_file):
80 |         data_file = open(save_file, "w",encoding="utf-8")
81 |         writer = csv.writer(data_file)
82 |         writer.writerow(["disease", "user_number", "explicit_number", "implicit_number", "symptom_number"])
83 |         for disease in self.information.keys():
84 |             writer.writerow([disease, self.information[disease]["user_number"], self.information[disease]["explicit_number"],self.information[disease]["implicit_number"], self.information[disease]["symptom_number"]])
85 |         data_file.close()
86 | 
87 | 
88 | class StatisticsOfDiseaseSymptom(object):
89 |     def __init__(self,disease_symptom_file):
90 |         self.disease_symptom_file = disease_symptom_file
91 |         self.disease_list = []
92 | 
93 |     def statistics(self):
94 |         disease_symptoms = {}
95 |         data_file = open(file=self.disease_symptom_file,mode="r", encoding="utf-8")
96 |         for line in data_file:
97 |             line = json.loads(line)
98 |             self.disease_list.append(line["name"])
99 |             disease_symptoms[line["name"]] = line["symptom"].keys()
100 |             print(line)
101 | 
102 |         result = pd.DataFrame(index=self.disease_list,columns=self.disease_list)
103 | 
104 |         for index1 in range(0, len(self.disease_list), 1):
105 |             for index2 in range(index1, len(self.disease_list), 1):
106 |                 count = 0
107 |                 for symptom1 in disease_symptoms[self.disease_list[index1]]:
108 |                     if symptom1 in disease_symptoms[self.disease_list[index2]]: count += 1
109 |                 result.loc[self.disease_list[index2], self.disease_list[index1]] = count
110 |         data_file.close()
111 |         self.result = result
112 | 
113 |     def save(self, file_name):
114 |         # self.result.to_csv(file_name,encoding="utf-8")
115 |         self.result.to_excel(file_name, sheet_name="Sheet1")
116 | 
117 | 
118 | 
119 | if __name__ == "__main__":
120 |     # statistics for the goal set, e.g., the average number of explicit symptoms, the average number of implicit
121 |     # symptoms, and the number of user goals for each disease.
122 | 
123 |     data_file = "./../src/dialogue_system/data/goal_set.p"
124 |     save_file = "./../resources/goal_set_statistics.csv"
125 |     save_file = "/Users/qianlong/Desktop/goal_set_statistics.csv"
126 | 
127 |     stata = StatisticsOfUserGoal(data_file=data_file)
128 |     stata.statistics()
129 |     stata.write_file(save_file=save_file)
130 | 
131 | 
132 |     # statistics of overlap in symptoms for different diseases.
133 |     # data_file = "./../resources/top_disease_symptom_aligned.json"
134 |     # save_to = "./../resources/overlap_disease_symptom.xlsx"
135 |     # statistics = StatisticsOfDiseaseSymptom(data_file)
136 |     # statistics.statistics()
137 |     # statistics.save(save_to)
--------------------------------------------------------------------------------
/src/dialogue_system/utils/draw_curve_each.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | """
4 | For parameter tuning: drawing the learning curve for each combination of parameters.
5 | """
6 | 
7 | import matplotlib.pyplot as plt
8 | import pickle
9 | import os
10 | import time
11 | 
12 | 
13 | class Ploter(object):
14 |     def __init__(self, performance_file):
15 |         self.performance_file = performance_file
16 |         self.epoch_size = 0
17 |         self.success_rate = {}
18 |         self.average_reward = {}
19 |         self.average_wrong_disease = {}
20 |         self.average_turn = {}
21 | 
22 |     def load_data(self, performance_file, label):
23 |         performance = pickle.load(file=open(performance_file, "rb"))
24 |         self.epoch_size = max(self.epoch_size, len(performance.keys()))
25 |         sr, ar, awd,at = self.__load_data(performance=performance)
26 |         self.success_rate[label] = sr
27 |         self.average_reward[label] = ar
28 |         self.average_wrong_disease[label] = awd
29 |         self.average_turn[label] = at
30 | 
31 |     def __load_data(self, performance):
32 |         success_rate = []
33 |         average_reward = []
34 |         average_wrong_disease = []
35 |         average_turn = []
36 |         for index in range(0, len(performance.keys()),1):
37 |             print(performance[index].keys())
38 |             success_rate.append(performance[index]["success_rate"])
39 |             average_reward.append(performance[index]["average_reward"])
40 |             average_wrong_disease.append(performance[index]["average_wrong_disease"])
41 |             average_turn.append(performance[index]["average_turn"])
42 |         return success_rate, average_reward, average_wrong_disease,average_turn
43 | 
44 | 
45 |     def plot(self, save_name, label_list):
46 |         # epoch_index = [i for i in range(0, 500, 1)]
47 | 
48 |         for label in self.success_rate.keys():
49 |             epoch_index = [i for i in range(0, len(self.success_rate[label]), 1)]
50 | 
51 |             plt.plot(epoch_index,self.success_rate[label][0:max(epoch_index)+1], label=label, linewidth=1)
52 |             # plt.plot(epoch_index,self.average_turn[label][0:max(epoch_index)+1], label=label+"at", linewidth=1)
53 | 
54 |         # plt.hlines(0.11,0,epoch_index,label="Random Agent", linewidth=1, colors="r")
55 |         # plt.hlines(0.38,0,epoch_index,label="Rule Agent", linewidth=1, colors="purple")
56 | 
57 |         plt.xlabel("Simulation Epoch")
58 |         plt.ylabel("Success Rate")
59 |         plt.title("Learning Curve")
60 |         # if len(label_list) >= 2:
61 |         #     plt.legend()
62 |         # plt.legend(loc="lower right")
63 |         plt.grid(True)
64 |         plt.savefig(save_name,dpi=400)
65 | 
66 |         plt.show()
67 | 
68 |     @staticmethod
69 |     def get_dirlist(path, key_word_list=None, no_key_word_list=None):
70 |         file_name_list = os.listdir(path)  # get all file names in the directory that contains the raw json files.
71 |         if key_word_list == None and no_key_word_list == None:
72 |             temp_file_list = file_name_list
73 |         elif key_word_list != None and no_key_word_list == None:
74 |             temp_file_list = []
75 |             for file_name in file_name_list:
76 |                 have_key_words = True
77 |                 for key_word in key_word_list:
78 |                     if key_word not in file_name:
79 |                         have_key_words = False
80 |                         break
81 |                     else:
82 |                         pass
83 |                 if have_key_words == True:
84 |                     temp_file_list.append(file_name)
85 |         elif key_word_list == None and no_key_word_list != None:
86 |             temp_file_list = []
87 |             for file_name in file_name_list:
88 |                 have_no_key_word = False
89 |                 for no_key_word in no_key_word_list:
90 |                     if no_key_word in file_name:
91 |                         have_no_key_word = True
92 |                         break
93 |                 if have_no_key_word == False:
94 |                     temp_file_list.append(file_name)
95 |         elif key_word_list != None and no_key_word_list != None:
96 |             temp_file_list = []
97 |             for file_name in file_name_list:
98 |                 have_key_words = True
99 |                 for key_word in key_word_list:
100 |                     if key_word not in file_name:
101 |                         have_key_words = False
102 |                         break
103 |                     else:
104 |                         pass
105 |                 have_no_key_word = False
106 |                 for no_key_word in no_key_word_list:
107 |                     if no_key_word in file_name:
108 |                         have_no_key_word = True
109 |                         break
110 |                     else:
111 |                         pass
112 |                 if have_key_words == True and have_no_key_word == False:
113 |                     temp_file_list.append(file_name)
114 | 
115 |         return temp_file_list
116 | 
117 | 
118 | if __name__ == "__main__":
119 |     # file_name = "./../model/dqn/learning_rate/learning_rate_d4_e999_agent1_dqn1.p"
120 |     # file_name = "/Users/qianlong/Desktop/learning_rate_d4_e_agent1_dqn1_T22_lr0.001_SR44_mls0_gamma0.95_epsilon0.1_1499.p"
121 |     # save_name = file_name + ".png"
122 |     # ploter = Ploter(file_name)
123 |     # ploter.load_data(performance_file=file_name, label="DQN Agent")
124 |     # ploter.plot(save_name, label_list=["DQN Agent"])
125 | 
126 |     # ploter.load_data("./../model/dqn/learning_rate/learning_rate_d7_e999_agent1_dqn1.p",label="d7a1q1")
127 |     # ploter.load_data("./../model/dqn/learning_rate/learning_rate_d10_e999_agent1_dqn0.p",label="d10a1q0")
128 |     # ploter.load_data("./../model/dqn/learning_rate/learning_rate_d10_e999_agent1_dqn1.p",label="d10a1q1")
129 |     # ploter.plot(save_name, label_list=["d7a1q0", "d7a1q1", "d10a1q0", "d10a1q1"])
130 | 
131 | 
132 |     # Draw learning curve from directory.
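    # Scans `path` for performance pickles whose names contain "_1499" (the last
    # saved epoch in these runs) and saves one learning-curve figure per file
    # into `save_path`.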
133 |     path = "/Volumes/LIUQL/dataset/1100/learning_rate/"
134 |     save_path = "/Volumes/LIUQL/dataset/1100/learning_curve/"
135 |     no_key_word_list = ["_99.", "_199.", "_299.", "_399."]
136 |     performance_file_list = Ploter.get_dirlist(path=path,key_word_list=["_1499"],no_key_word_list=["_99.","_199.","_299.","_399."])
137 |     print("file_number:", len(performance_file_list))
138 |     time.sleep(8)
139 | 
140 |     for file_name in performance_file_list:
141 |         print(file_name)
142 |         performance_file = path + file_name
143 |         save_name = save_path + file_name + ".png"
144 |         ploter = Ploter(performance_file=performance_file)
145 |         ploter.load_data(performance_file=performance_file,label="DQN Agent")
146 |         ploter.plot(save_name=save_name,label_list=["DQN Agent"])
--------------------------------------------------------------------------------
/preprocess/kliao/svm_class.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Dec 12 16:47:05 2018
4 | 
5 | @author: DELL
6 | """
7 | 
8 | import numpy as np
9 | import os,sys
10 | import pickle
11 | import copy
12 | import pandas as pd
13 | sys.path.append(os.getcwd().replace('/resource/tagger2',''))
14 | from preprocess.label.preprocess_label import GoalDumper
15 | from sklearn.svm import SVC
16 | from sklearn import svm
17 | #from sklearn.metrics import accuracy_score
18 | #from sklearn.model_selection import train_test_split
19 | from sklearn.model_selection import train_test_split,cross_val_score,cross_validate
20 | 
21 | def disease_symptom_clip(disease_symptom, denominator):
22 |     """
23 |     Keep the top min(symptom_num, max_turn // denominator) symptoms for each disease; the related symptoms are
24 |     sorted in descending order of frequency.
25 | 
26 |     Args:
27 |         disease_symptom: a dict whose keys are the names of diseases; each value is a dict that contains
28 |             the index of the disease and its related symptoms.
29 |         denominator: int, the number of symptoms kept for each disease is max_turn // denominator.
30 | 
31 | 
32 |     Returns:
33 |         A dict whose keys are the names of diseases; the values are dicts with two keys: {'index', 'symptom'}.
34 |     """
35 |     max_turn = 22
36 |     temp_disease_symptom = copy.deepcopy(disease_symptom)
37 |     for key, value in disease_symptom.items():
38 |         symptom_list = sorted(value['symptom'].items(),key = lambda x:x[1],reverse = True)
39 |         symptom_list = [v[0] for v in symptom_list]
40 |         symptom_list = symptom_list[0:min(len(symptom_list), int(max_turn / float(denominator)))]
41 |         temp_disease_symptom[key]['symptom'] = symptom_list
42 |     #print('\n',disease_symptom)
43 |     #print('\n',temp_disease_symptom)
44 |     return temp_disease_symptom
45 | 
46 | def svm_model(dataset,min_count,target,svm_c):
47 |     index_len = dataset.shape[0]  # number of samples (a hard-coded 600 only worked for one particular dataset).
48 |     slots_x=np.zeros((index_len,sum(sum(abs(dataset),1)>min_count)))
49 |     count=0
50 |     for i,value in enumerate(sum(abs(dataset),1)):  # use abs() here as well, consistent with the column count above.
51 |         if value<=min_count:
52 |             pass
53 |         else:
54 |             slots_x[:,count]=dataset[:,i]
55 |             count+=1
56 | 
57 |     slots_input=pd.DataFrame(index=range(index_len))
58 |     for col in range(slots_x.shape[1]):
59 |         column=slots_x[:,col]
60 |         column_mod=[]
61 |         for j in column:
62 |             if j==1:
63 |                 column_mod.append('yes')
64 |             elif j==-1:
65 |                 column_mod.append('no')
66 |             else:
67 |                 column_mod.append('UNK')
68 |         slots_input[str(col)]=column_mod
69 | 
70 |     slots_input=pd.get_dummies(slots_input)
71 |     target=np.array(target)
72 |     #x_train, x_val, y_train, y_val = train_test_split(slots_input, disease_y, test_size=0.3, random_state=10)
73 |     clf = svm.SVC(kernel='linear', C=svm_c)
74 |     scores=[]
75 |     for i in range(10):
76 |         scores_exp = cross_validate(clf, slots_input, target, cv=5,scoring='accuracy',return_train_score=False)
77 |         scores_tot=sum(scores_exp['test_score'])/5
78 |         scores.append(scores_tot)
79 |     return np.mean(scores)
80 | 
81 | goal_file = "./goal_batch2.json"
82 | goal_dump_file = "./goal_set.p"
83 | slots_dump_file = "./slot_set.p"
84 | goal=GoalDumper(goal_file)
85 | goal.dump(goal_dump_file)
86 | goal_set=goal.goalset
87 | goal.dump_slot(slots_dump_file)
88 | #goal_set,slot_set=goal.set_return()
89 | slot_set=pickle.load(open(slots_dump_file,'rb'))
90 | slot_set.pop('disease')
91 | disease_symptom=goal.disease_symptom
92 | disease_symptom1=disease_symptom_clip(disease_symptom,2)
93 | #slot_set.pop('发热39度3')
94 | #slot_set.pop('发热37.7至38.4度')
95 | 
96 | disease_y=[]
97 | total_set=copy.deepcopy(goal_set['train'])
98 | total_set.extend(goal_set['test'])
99 | slots_exp=np.zeros((len(total_set),len(slot_set)))
100 | slots_all=np.zeros((len(total_set),len(slot_set)))
101 | #slots_exp=pd.DataFrame(slots_exp,columns=slot_set.keys())
102 | #slots_all=pd.DataFrame(slots_all,columns=slot_set.keys())
103 | for i,dialogue in enumerate(total_set):
104 |     tag=dialogue['disease_tag']
105 |     tag_group=disease_symptom1[tag]['symptom']
106 |     disease_y.append(tag)
107 |     goal=dialogue['goal']
108 |     explicit=goal['explicit_inform_slots']
109 |     implicit=goal['implicit_inform_slots']
110 |     for slot,value in implicit.items():
111 |         try:
112 |             slot_id=slot_set[slot]
113 |             if value==True:
114 |                 slots_all[i,slot_id]=1
115 |             if value==False:
116 |                 slots_all[i,slot_id]=-1
117 |         except KeyError:  # symptoms that are not in the slot set are skipped.
118 |             pass
119 |     for exp_slot,value in explicit.items():
120 |         try:
121 |             slot_id=slot_set[exp_slot]
122 |             if value==True:
123 |                 slots_exp[i,slot_id]=1
124 |                 slots_all[i,slot_id]=1
125 |             if value==False:
126 |                 slots_exp[i,slot_id]=-1
127 |                 slots_all[i,slot_id]=-1
128 |         except KeyError:
129 |             pass
130 | 
131 | 
132 | 
133 | 
134 | 
135 | 
136 | 
137 | 
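# A minimal sanity-check sketch (not part of the original pipeline): slots_exp and
# slots_all are (n_samples, n_slots) matrices with entries 1 / -1 / 0 standing for
# yes / no / unknown; svm_model() re-encodes them as 'yes'/'no'/'UNK' columns and
# one-hot expands them via pd.get_dummies before fitting the linear SVM.
# assert slots_exp.shape == (len(total_set), len(slot_set))
# assert set(np.unique(slots_all)).issubset({-1.0, 0.0, 1.0})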
138 | score_tot_exp=svm_model(dataset=slots_exp,min_count=0,target=disease_y,svm_c=10)
139 | score_tot_all=svm_model(dataset=slots_all,min_count=0,target=disease_y,svm_c=10)
140 | '''
141 | slots_x_all=np.zeros((len(total_set),sum(sum(abs(slots_all),1)>5))
142 | count=0
143 | for i,value in enumerate(sum(slots_all,1)):
144 |     if value<=5:
145 |         pass
146 |     else:
147 |         slots_x_all[:,count]=slots_all[:,i]
148 |         count+=1
149 | 
150 | slots_input_all=pd.DataFrame(index=range(len(total_set)))
151 | for col in range(slots_x_all.shape[1]):
152 |     column=slots_x_all[:,col]
153 |     column_mod=[]
154 |     for j in column:
155 |         if j==1:
156 |             column_mod.append('yes')
157 |         elif j==-1:
158 |             column_mod.append('no')
159 |         else:
160 |             column_mod.append('UNK')
161 |     slots_input_all[str(col)]=column_mod
162 | 
163 | slots_input_all=pd.get_dummies(slots_input_all)
164 | disease_y=np.array(disease_y)
165 | #x_train, x_val, y_train, y_val = train_test_split(slots_input, disease_y, test_size=0.3, random_state=10)
166 | clf = svm.SVC(kernel='linear', C=5)
167 | 
168 | scores_all = cross_validate(clf, slots_input_all, disease_y, cv=5,scoring='accuracy',return_train_score=False)
169 | scores_tot_all=sum(scores_all['test_score'])/5
170 | '''
171 | 
172 | 
173 | 
174 | 
175 | 
176 | 
--------------------------------------------------------------------------------
/preprocess/label/svm_class.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Dec 12 16:47:05 2018
4 | 
5 | @author: DELL
6 | """
7 | 
8 | import numpy as np
9 | import os,sys
10 | import pickle
11 | import copy
12 | import pandas as pd
13 | sys.path.append(os.getcwd().replace('/resource/tagger2',''))
14 | from preprocess.label.preprocess_label import GoalDumper
15 | from sklearn.svm import SVC
16 | from sklearn import svm
17 | #from sklearn.metrics import accuracy_score
18 | #from sklearn.model_selection import train_test_split
19 | from sklearn.model_selection import train_test_split,cross_val_score,cross_validate
20 | 
21 | def disease_symptom_clip(disease_symptom, denominator):
22 |     """
23 |     Keep the top min(symptom_num, max_turn // denominator) symptoms for each disease; the related symptoms are
24 |     sorted in descending order of frequency.
25 | 
26 |     Args:
27 |         disease_symptom: a dict whose keys are the names of diseases; each value is a dict that contains
28 |             the index of the disease and its related symptoms.
29 |         denominator: int, the number of symptoms kept for each disease is max_turn // denominator.
30 | 
31 | 
32 |     Returns:
33 |         A dict whose keys are the names of diseases; the values are dicts with two keys: {'index', 'symptom'}.
34 |     """
35 |     max_turn = 22
36 |     temp_disease_symptom = copy.deepcopy(disease_symptom)
37 |     for key, value in disease_symptom.items():
38 |         symptom_list = sorted(value['symptom'].items(),key = lambda x:x[1],reverse = True)
39 |         symptom_list = [v[0] for v in symptom_list]
40 |         symptom_list = symptom_list[0:min(len(symptom_list), int(max_turn / float(denominator)))]
41 |         temp_disease_symptom[key]['symptom'] = symptom_list
42 |     #print('\n',disease_symptom)
43 |     #print('\n',temp_disease_symptom)
44 |     return temp_disease_symptom
45 | 
46 | def svm_model(dataset,min_count,target,svm_c):
47 |     print(dataset.shape)
48 |     index_len=dataset.shape[0]
49 |     slots_x=np.zeros((index_len,sum(sum(abs(dataset),1)>min_count)))
50 |     count=0
51 |     for i,value in enumerate(sum(abs(dataset),1)):  # use abs() here as well, consistent with the column count above.
52 |         if value<=min_count:
53 |             pass
54 |         else:
55 |             slots_x[:,count]=dataset[:,i]
56 |             count+=1
57 | 
58 |     slots_input=pd.DataFrame(index=range(index_len))
59 |     for col in range(slots_x.shape[1]):
60 |         column=slots_x[:,col]
61 |         column_mod=[]
62 |         for j in column:
63 |             if j==1:
64 |                 column_mod.append('yes')
65 |             elif j==-1:
66 |                 column_mod.append('no')
67 |             else:
68 |                 column_mod.append('UNK')
69 |         slots_input[str(col)]=column_mod
70 | 
71 |     slots_input=pd.get_dummies(slots_input)
72 |     target=np.array(target)
73 |     #x_train, x_val, y_train, y_val = train_test_split(slots_input, disease_y, test_size=0.3, random_state=10)
74 |     clf = svm.SVC(kernel='linear', C=svm_c)
75 |     scores=[]
76 |     for i in range(10):
77 |         scores_exp = cross_validate(clf, slots_input, target, cv=5,scoring='accuracy',return_train_score=False)
78 |         scores_tot=sum(scores_exp['test_score'])/5
79 |         scores.append(scores_tot)
80 |     return np.mean(scores)
81 | 
82 | 
83 | slots_set_file = './../../src/dialogue_system/data/slot_set.p'
84 | goal_set_file = './../../src/dialogue_system/data/goal_set.p'
85 | disease_symptom_file = './../../src/dialogue_system/data/disease_symptom.p'
86 | goal_set = pickle.load(open(goal_set_file,'rb'))
87 | slot_set=pickle.load(open(slots_set_file,'rb'))
88 | slot_set.pop('disease')
89 | disease_symptom= pickle.load(open(disease_symptom_file, 'rb'))
90 | disease_symptom1=disease_symptom_clip(disease_symptom,2)
91 | #slot_set.pop('发热39度3')
92 | #slot_set.pop('发热37.7至38.4度')
93 | 
94 | disease_y=[]
95 | total_set=copy.deepcopy(goal_set['train'])
96 | total_set.extend(goal_set['test'])
97 | slots_exp=np.zeros((len(total_set),len(slot_set)))
98 | slots_all=np.zeros((len(total_set),len(slot_set)))
99 | #slots_exp=pd.DataFrame(slots_exp,columns=slot_set.keys())
100 | #slots_all=pd.DataFrame(slots_all,columns=slot_set.keys())
101 | for i,dialogue in enumerate(total_set):
102 |     tag=dialogue['disease_tag']
103 |     tag_group=disease_symptom1[tag]['symptom']
104 |     disease_y.append(tag)
105 |     goal=dialogue['goal']
106 |     explicit=goal['explicit_inform_slots']
107 |     implicit=goal['implicit_inform_slots']
108 |     for slot,value in implicit.items():
109 |         try:
110 |             slot_id=slot_set[slot]
111 |             if value==True:
112 |                 slots_all[i,slot_id]=1
113 |             if value==False:
114 |                 slots_all[i,slot_id]=-1
115 |         except KeyError:  # symptoms that are not in the slot set are skipped.
116 |             pass
117 |     for exp_slot,value in explicit.items():
118 |         try:
119 |             slot_id=slot_set[exp_slot]
120 |             if value==True:
121 |                 slots_exp[i,slot_id]=1
122 |                 slots_all[i,slot_id]=1
123 |             if value==False:
124 |                 slots_exp[i,slot_id]=-1
125 |                 slots_all[i,slot_id]=-1
126 |         except KeyError:
127 |             pass
128 | 
129 | 
130 | 
131 | 
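# svm_model() repeats 5-fold cross-validation 10 times and returns the mean test
# accuracy: score_tot_exp uses the explicit symptoms only, while score_tot_all
# uses explicit and implicit symptoms together.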
132 | score_tot_exp=svm_model(dataset=slots_exp,min_count=0,target=disease_y,svm_c=10)
133 | score_tot_all=svm_model(dataset=slots_all,min_count=0,target=disease_y,svm_c=10)
134 | print('exp', score_tot_exp)
135 | print('all', score_tot_all)
136 | '''
137 | slots_x_all=np.zeros((len(total_set),sum(sum(abs(slots_all),1)>5))
138 | count=0
139 | for i,value in enumerate(sum(slots_all,1)):
140 |     if value<=5:
141 |         pass
142 |     else:
143 |         slots_x_all[:,count]=slots_all[:,i]
144 |         count+=1
145 | 
146 | slots_input_all=pd.DataFrame(index=range(len(total_set)))
147 | for col in range(slots_x_all.shape[1]):
148 |     column=slots_x_all[:,col]
149 |     column_mod=[]
150 |     for j in column:
151 |         if j==1:
152 |             column_mod.append('yes')
153 |         elif j==-1:
154 |             column_mod.append('no')
155 |         else:
156 |             column_mod.append('UNK')
157 |     slots_input_all[str(col)]=column_mod
158 | 
159 | slots_input_all=pd.get_dummies(slots_input_all)
160 | disease_y=np.array(disease_y)
161 | #x_train, x_val, y_train, y_val = train_test_split(slots_input, disease_y, test_size=0.3, random_state=10)
162 | clf = svm.SVC(kernel='linear', C=5)
163 | 
164 | scores_all = cross_validate(clf, slots_input_all, disease_y, cv=5,scoring='accuracy',return_train_score=False)
165 | scores_tot_all=sum(scores_all['test_score'])/5
166 | '''
--------------------------------------------------------------------------------
/src/dialogue_system/memory/onpolicy.py:
--------------------------------------------------------------------------------
1 | # Modified by Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from convlab.agent.memory.base import Memory
5 | from convlab.lib import logger, util
6 | from convlab.lib.decorator import lab_api
7 | 
8 | logger = logger.get_logger(__name__)
9 | 
10 | 
11 | class OnPolicyReplay(Memory):
12 |     '''
13 |     Stores agent experiences and returns them in a batch for agent training.
14 | 
15 |     An experience consists of
16 |         - state: representation of a state
17 |         - action: action taken
18 |         - reward: scalar value
19 |         - next state: representation of next state (should be same as state)
20 |         - done: 0 / 1 representing if the current state is the last in an episode
21 | 
22 |     The memory does not have a fixed size. Instead the memory stores data from N episodes, where N is determined by the user. After N episodes, all of the examples are returned to the agent to learn from.
23 | 
24 |     When the examples are returned to the agent, the memory is cleared to prevent the agent from learning from off-policy experiences. This memory is intended for on-policy algorithms.
25 | 
26 |     Differences vs. Replay memory:
27 |         - Experiences are nested into episodes. In Replay experiences are flat, and episode is not tracked
28 |         - The entire memory constitutes a batch. In Replay batches are sampled from memory.
29 |         - The memory is cleared automatically when a batch is given to the agent.
30 | 
31 |     e.g. memory_spec
31 |     e.g. memory_spec
32 |     "memory": {
33 |         "name": "OnPolicyReplay"
34 |     }
35 |     '''
36 | 
37 |     def __init__(self, memory_spec, body):
38 |         super().__init__(memory_spec, body)
39 |         # NOTE for OnPolicy replay, frequency = episode; for other classes below frequency = frames
40 |         util.set_attr(self, self.body.agent.agent_spec['algorithm'], ['training_frequency'])
41 |         # Don't want total experiences reset when memory is cleared
42 |         self.is_episodic = True
43 |         self.size = 0  # total experiences stored
44 |         self.seen_size = 0  # total experiences seen cumulatively
45 |         # declare what data keys to store
46 |         self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones']
47 |         self.reset()
48 | 
49 |     @lab_api
50 |     def reset(self):
51 |         '''Resets the memory. Also used to initialize memory vars'''
52 |         for k in self.data_keys:
53 |             setattr(self, k, [])
54 |         self.cur_epi_data = {k: [] for k in self.data_keys}
55 |         self.most_recent = (None,) * len(self.data_keys)
56 |         self.size = 0
57 | 
58 |     @lab_api
59 |     def update(self, state, action, reward, next_state, done):
60 |         '''Interface method to update memory'''
61 |         self.add_experience(state, action, reward, next_state, done)
62 | 
63 |     def add_experience(self, state, action, reward, next_state, done):
64 |         '''Interface helper method for update() to add experience to memory'''
65 |         self.most_recent = (state, action, reward, next_state, done)
66 |         for idx, k in enumerate(self.data_keys):
67 |             self.cur_epi_data[k].append(self.most_recent[idx])
68 |         # If episode ended, add to memory and clear cur_epi_data
69 |         if util.epi_done(done):
70 |             for k in self.data_keys:
71 |                 getattr(self, k).append(self.cur_epi_data[k])
72 |             self.cur_epi_data = {k: [] for k in self.data_keys}
73 |             # If agent has collected the desired number of episodes, it is ready to train
74 |             # length is num of epis due to nested structure
75 |             # if len(self.states) == self.body.agent.algorithm.training_frequency:
76 |             if len(self.states) % self.body.agent.algorithm.training_frequency == 0:
77 |                 self.body.agent.algorithm.to_train = 1
78 |         # Track memory size and num experiences
79 |         self.size += 1
80 |         self.seen_size += 1
81 | 
82 |     def get_most_recent_experience(self):
83 |         '''Returns the most recent experience'''
84 |         return self.most_recent
85 | 
86 |     def sample(self):
87 |         '''
88 |         Returns all the examples from memory in a single batch. Batch is stored as a dict.
89 |         Keys are the names of the different elements of an experience. Values are nested lists of the corresponding sampled elements. Elements are nested into episodes
90 |         e.g.
91 |         batch = {
92 |             'states' : [[s_epi1], [s_epi2], ...],
93 |             'actions' : [[a_epi1], [a_epi2], ...],
94 |             'rewards' : [[r_epi1], [r_epi2], ...],
95 |             'next_states': [[ns_epi1], [ns_epi2], ...],
96 |             'dones' : [[d_epi1], [d_epi2], ...]}
97 |         '''
98 |         batch = {k: getattr(self, k) for k in self.data_keys}
99 |         self.reset()
100 |         return batch
101 | 
102 | 
103 | class OnPolicyBatchReplay(OnPolicyReplay):
104 |     '''
105 |     Same as OnPolicyReplay Memory with the following difference.
106 | 
107 |     The memory does not have a fixed size. Instead the memory stores data from N experiences, where N is determined by the user. After N experiences or if an episode has ended, all of the examples are returned to the agent to learn from.
108 | 
109 |     In contrast, OnPolicyReplay stores entire episodes and stores them in a nested structure. OnPolicyBatchReplay stores experiences in a flat structure.
110 | 
111 |         e.g.
memory_spec 112 | "memory": { 113 | "name": "OnPolicyBatchReplay" 114 | } 115 | * batch_size is training_frequency provided by algorithm_spec 116 | ''' 117 | 118 | def __init__(self, memory_spec, body): 119 | super().__init__(memory_spec, body) 120 | self.is_episodic = False 121 | 122 | def add_experience(self, state, action, reward, next_state, done): 123 | '''Interface helper method for update() to add experience to memory''' 124 | self.most_recent = [state, action, reward, next_state, done] 125 | for idx, k in enumerate(self.data_keys): 126 | getattr(self, k).append(self.most_recent[idx]) 127 | # Track memory size and num experiences 128 | self.size += 1 129 | self.seen_size += 1 130 | # Decide if agent is to train 131 | if len(self.states) == self.body.agent.algorithm.training_frequency: 132 | self.body.agent.algorithm.to_train = 1 133 | 134 | def sample(self): 135 | ''' 136 | Returns all the examples from memory in a single batch. Batch is stored as a dict. 137 | Keys are the names of the different elements of an experience. Values are a list of the corresponding sampled elements 138 | e.g. 139 | batch = { 140 | 'states' : states, 141 | 'actions' : actions, 142 | 'rewards' : rewards, 143 | 'next_states': next_states, 144 | 'dones' : dones} 145 | ''' 146 | return super().sample() 147 | -------------------------------------------------------------------------------- /src/dialogue_system/memory/replay.py: -------------------------------------------------------------------------------- 1 | # Modified by Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import numpy as np 5 | 6 | from src.dialogue_system.memory.base import Memory 7 | from src.dialogue_system.memory import util 8 | 9 | 10 | def sample_next_states(head, max_size, ns_idx_offset, batch_idxs, states, ns_buffer): 11 | '''Method to sample next_states from states, with proper guard for next_state idx being out of bound''' 12 | # idxs for next state is state idxs with offset, modded 13 | ns_batch_idxs = (batch_idxs + ns_idx_offset) % max_size 14 | # if head < ns_idx <= head + ns_idx_offset, ns is stored in ns_buffer 15 | ns_batch_idxs = ns_batch_idxs % max_size 16 | buffer_ns_locs = np.argwhere( 17 | (head < ns_batch_idxs) & (ns_batch_idxs <= head + ns_idx_offset)).flatten() 18 | # find if there is any idxs to get from buffer 19 | to_replace = buffer_ns_locs.size != 0 20 | if to_replace: 21 | # extract the buffer_idxs first for replacement later 22 | # given head < ns_idx <= head + offset, and valid buffer idx is [0, offset) 23 | # get 0 < ns_idx - head <= offset, or equiv. 24 | # get -1 < ns_idx - head - 1 <= offset - 1, i.e. 
25 |     # get 0 <= ns_idx - head - 1 < offset, hence:
26 |     buffer_idxs = ns_batch_idxs[buffer_ns_locs] - head - 1
27 |     # set them to 0 first to allow sampling, then replace later with buffer
28 |     ns_batch_idxs[buffer_ns_locs] = 0
29 |     # guard all against overrun idxs from offset
30 |     ns_batch_idxs = ns_batch_idxs % max_size
31 |     next_states = util.batch_get(states, ns_batch_idxs)
32 |     if to_replace:
33 |         # now replace using buffer_idxs and ns_buffer
34 |         buffer_ns = util.batch_get(ns_buffer, buffer_idxs)
35 |         next_states[buffer_ns_locs] = buffer_ns
36 |     return next_states
37 | 
38 | 
39 | class Replay(Memory):
40 |     '''
41 |     Stores agent experiences and samples from them for agent training
42 | 
43 |     An experience consists of
44 |         - state: representation of a state
45 |         - action: action taken
46 |         - reward: scalar value
47 |         - next state: representation of next state (should be same as state)
48 |         - done: 0 / 1 representing if the current state is the last in an episode
49 | 
50 |     The memory has a size of N. When capacity is reached, the oldest experience
51 |     is deleted to make space for the latest experience.
52 |         - This is implemented as a circular buffer so that inserting an experience is O(1)
53 |         - Each element of an experience is stored as a separate array of size N * element dim
54 | 
55 |     When a batch of experiences is requested, K experiences are sampled according to a random uniform distribution.
56 | 
57 |     If 'use_cer', sampling will add the latest experience.
58 | 
59 |     e.g. memory_spec
60 |     "memory": {
61 |         "name": "Replay",
62 |         "batch_size": 32,
63 |         "max_size": 10000,
64 |         "use_cer": true
65 |     }
66 |     '''
67 | 
68 |     def __init__(self, parameter):
69 |         super().__init__(parameter)
70 |         self.batch_size = parameter.get("batch_size")
71 |         self.max_size = parameter.get("experience_replay_pool_size")
72 |         self.use_cer = True
73 | 
74 | 
75 |         self.is_episodic = False
76 |         self.batch_idxs = None
77 |         self.size = 0  # total experiences stored
78 |         self.seen_size = 0  # total experiences seen cumulatively
79 |         self.head = -1  # index of most recent experience
80 |         # generic next_state buffer to store last next_states (allow for multiple for venv)
81 |         # self.ns_idx_offset = self.body.env.num_envs if body.env.is_venv else 1
82 |         # self.ns_buffer = deque(maxlen=self.ns_idx_offset)
83 |         # declare what data keys to store
84 |         self.data_keys = ['states', 'actions', 'rewards', 'next_states', 'dones']
85 |         self.reset()
86 | 
87 |     def reset(self):
88 |         '''Initializes the memory arrays, size and head pointer'''
89 |         # set self.states, self.actions, ...
90 |         for k in self.data_keys:
91 |             setattr(self, k, [None] * self.max_size)
92 |             # if k != 'next_states':  # reuse self.states
93 |             #     # list add/sample is over 10x faster than np, also simpler to handle
94 |             #     setattr(self, k, [None] * self.max_size)
95 |         self.size = 0
96 |         self.head = -1
97 |         # self.ns_buffer.clear()
98 | 
99 |     def update(self, state, action, reward, next_state, done):
100 |         '''Interface method to update memory'''
101 |         self.add_experience(state, action, reward, next_state, done)
102 | 
103 |     def add_experience(self, state, action, reward, next_state, done):
104 |         '''Implementation for update() to add experience to memory, expanding the memory size if necessary'''
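    #     A worked example of the circular write below, with max_size = 3
    #     (illustrative values only): head visits 0, 1, 2, 0, 1, ..., so the
    #     4th experience silently overwrites the 1st:
    #         head = (-1 + 1) % 3 = 0    # 1st write
    #         head = ( 0 + 1) % 3 = 1    # 2nd write
    #         head = ( 2 + 1) % 3 = 0    # 4th write reuses the oldest slot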
105 |         # Move head pointer. Wrap around if necessary
106 |         self.head = (self.head + 1) % self.max_size  # once the pool holds max_size experiences, the oldest one is overwritten automatically
107 |         self.states[self.head] = state.astype(np.float16)
108 |         self.actions[self.head] = action
109 |         self.rewards[self.head] = reward
110 |         self.next_states[self.head] = next_state
111 |         # self.ns_buffer.append(next_state.astype(np.float16))
112 |         self.dones[self.head] = done
113 | 
114 |         # Actually occupied size of memory
115 |         if self.size < self.max_size:
116 |             self.size += 1
117 |         self.seen_size += 1
118 |         # set to_train using memory counters head, seen_size instead of tick since clock will step by num_envs when on venv; to_train will be set to 0 after training step
119 | 
120 |     def sample(self):
121 |         '''
122 |         Returns a batch of batch_size samples. Batch is stored as a dict.
123 |         Keys are the names of the different elements of an experience. Values are an array of the corresponding sampled elements
124 |         e.g.
125 |         batch = {
126 |             'states' : states,
127 |             'actions' : actions,
128 |             'rewards' : rewards,
129 |             'next_states': next_states,
130 |             'dones' : dones}
131 |         '''
132 |         self.batch_idxs = self.sample_idxs(self.batch_size)
133 |         batch = {}
134 |         for k in self.data_keys:
135 |             batch[k] = util.batch_get(getattr(self, k), self.batch_idxs)
136 |             # if k == 'next_states':
137 |             #     batch[k] = sample_next_states(self.head, self.max_size, self.ns_idx_offset, self.batch_idxs, self.states, self.ns_buffer)
138 |             # else:
139 |             #     batch[k] = util.batch_get(getattr(self, k), self.batch_idxs)
140 |         return batch
141 | 
142 |     def sample_idxs(self, batch_size):
143 |         '''Batch indices are sampled uniformly at random'''
144 |         batch_idxs = np.random.randint(self.size, size=batch_size)
145 |         if self.use_cer:  # add the latest sample
146 |             batch_idxs[-1] = self.head
147 |         return batch_idxs
148 | 
--------------------------------------------------------------------------------
/src/classifier/symptom_as_feature/svm_kliao.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Dec 12 16:47:05 2018
4 | 
5 | @author: DELL
6 | """
7 | 
8 | import numpy as np
9 | import os, sys
10 | import pickle
11 | import copy
12 | import pandas as pd
13 | 
14 | sys.path.append(os.getcwd().replace('\\resource\\tagger2', ''))
15 | os.chdir(os.getcwd().replace('\\resource\\tagger2', ''))
16 | from preprocess import GoalDumper
17 | from sklearn.svm import SVC
18 | from sklearn import svm
19 | # from sklearn.metrics import accuracy_score
20 | # from sklearn.model_selection import train_test_split
21 | from sklearn.model_selection import train_test_split, cross_val_score, cross_validate
22 | 
23 | 
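# A hypothetical input/output pair for disease_symptom_clip() below (disease and
# symptom names are invented; the structure follows its docstring):
#
#     before = {'cold': {'index': 0, 'symptom': {'cough': 30, 'fever': 12, 'rash': 1}}}
#     after = disease_symptom_clip(before, denominator=2)
#     # with max_turn = 22, the top min(3, 22 // 2) = 3 symptoms survive, sorted by frequency:
#     # {'cold': {'index': 0, 'symptom': ['cough', 'fever', 'rash']}}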
24 | def disease_symptom_clip(disease_symptom, denominator):
25 |     """
26 |     Keep the top min(symptom_num, max_turn//denominator) symptoms for each disease; the related symptoms are sorted
27 |     in descending order of their frequencies.
28 | 
29 |     Args:
30 |         disease_symptom: a dict whose keys are the names of diseases, and the corresponding value is a dict too which
31 |             contains the index of this disease and the related symptoms.
32 |         denominator: int, the number of symptoms for each disease is max_turn // denominator.
33 | 
34 | 
35 |     Returns:
36 |         a dict, whose keys are the names of diseases, and the values are dicts too with two keys: {'index', 'symptom'}
37 |     """
38 |     max_turn = 22
39 |     temp_disease_symptom = copy.deepcopy(disease_symptom)
40 |     for key, value in disease_symptom.items():
41 |         symptom_list = sorted(value['symptom'].items(), key=lambda x: x[1], reverse=True)
42 |         symptom_list = [v[0] for v in symptom_list]
43 |         symptom_list = symptom_list[0:min(len(symptom_list), int(max_turn / float(denominator)))]
44 |         temp_disease_symptom[key]['symptom'] = symptom_list
45 |     # print('\n',disease_symptom)
46 |     # print('\n',temp_disease_symptom)
47 |     return temp_disease_symptom
48 | 
49 | 
50 | def svm_model(dataset, min_count, target, svm_c, epoch):
51 |     '''
52 |     min_count: only slots whose frequency is greater than min_count are kept
53 |     svm_c: the C regularization parameter of the SVM
54 |     epoch: how many times the k-fold cross-validation is repeated before averaging
55 |     '''
56 |     index_len = len(dataset)
57 |     slots_x = np.zeros((index_len, sum(sum(abs(dataset), 1) > min_count)))
58 |     count = 0
59 |     for i, value in enumerate(sum(dataset, 1)):
60 |         if value <= min_count:
61 |             pass
62 |         else:
63 |             slots_x[:, count] = dataset[:, i]
64 |             count += 1
65 | 
66 |     slots_input = pd.DataFrame(index=range(index_len))
67 |     for col in range(slots_x.shape[1]):
68 |         column = slots_x[:, col]
69 |         column_mod = []
70 |         for j in column:
71 |             if j == 1:
72 |                 column_mod.append('yes')
73 |             elif j == -1:
74 |                 column_mod.append('no')
75 |             else:
76 |                 column_mod.append('UNK')
77 |         slots_input[str(col)] = column_mod
78 | 
79 |     slots_input = pd.get_dummies(slots_input)
80 |     target = np.array(target)
81 |     # x_train, x_val, y_train, y_val = train_test_split(slots_input, disease_y, test_size=0.3, random_state=10)
82 |     clf = svm.SVC(kernel='linear', C=svm_c)
83 |     scores = []
84 |     for i in range(epoch):
85 |         scores_exp = cross_validate(clf, slots_input, target, cv=5, scoring='accuracy', return_train_score=False)
86 |         scores_tot = sum(scores_exp['test_score']) / 5
87 |         scores.append(scores_tot)
88 |     return np.mean(scores)
89 | 
90 | 
91 | # goal_file = "./goal_batch2.json"
92 | goal_file = "./../../data/goal_find.json"
93 | goal_dump_file = "./../../data/goal_set.p"
94 | slots_dump_file = "./../../data/slot_set.p"
95 | goal = GoalDumper(goal_file)
96 | goal.dump(goal_dump_file)
97 | goal_set = goal.goalset
98 | goal.dump_slot(slots_dump_file)
99 | # goal_set,slot_set=goal.set_return()
100 | slot_set = pickle.load(open(slots_dump_file, 'rb'))
101 | slot_set.pop('disease')
102 | disease_symptom = goal.disease_symptom
103 | disease_symptom1 = disease_symptom_clip(disease_symptom, 2)
104 | # slot_set.pop('发热39度3')
105 | # slot_set.pop('发热37.7至38.4度')
106 | 
107 | disease_y = []
108 | total_set = copy.deepcopy(goal_set['train'])
109 | total_set.extend(goal_set['test'])
110 | slots_exp = np.zeros((len(total_set), len(slot_set)))
111 | slots_all = np.zeros((len(total_set), len(slot_set)))
112 | # slots_exp=pd.DataFrame(slots_exp,columns=slot_set.keys())
113 | # slots_all=pd.DataFrame(slots_all,columns=slot_set.keys())
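# Encoding used by the loop below: for every dialogue, a confirmed symptom is written
# as 1, a denied symptom as -1, and an unmentioned symptom stays 0. slots_exp receives
# only the explicit (self-report) symptoms, while slots_all also merges the implicit
# ones, so the two scores printed afterwards isolate the contribution of the implicit
# symptoms.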
114 | for i, dialogue in enumerate(total_set):
115 |     tag = dialogue['disease_tag']
116 |     tag_group = disease_symptom1[tag]['symptom']
117 |     disease_y.append(tag)
118 |     goal = dialogue['goal']
119 |     explicit = goal['explicit_inform_slots']
120 |     implicit = goal['implicit_inform_slots']
121 |     for slot, value in implicit.items():
122 |         try:
123 |             slot_id = slot_set[slot]
124 |             if value == True:
125 |                 slots_all[i, slot_id] = 1
126 |             if value == False:
127 |                 slots_all[i, slot_id] = -1
128 |         except KeyError:  # the symptom is not in slot_set
129 |             pass
130 |     for exp_slot, value in explicit.items():
131 |         try:
132 |             slot_id = slot_set[exp_slot]
133 |             if value == True:
134 |                 slots_exp[i, slot_id] = 1
135 |                 slots_all[i, slot_id] = 1
136 |             if value == False:
137 |                 slots_exp[i, slot_id] = -1
138 |                 slots_all[i, slot_id] = -1
139 |         except KeyError:  # the symptom is not in slot_set
140 |             pass
141 | 
142 | score_tot_exp = svm_model(dataset=slots_exp, min_count=0, target=disease_y, svm_c=0.3, epoch=10)
143 | score_tot_all = svm_model(dataset=slots_all, min_count=0, target=disease_y, svm_c=0.3, epoch=10)
144 | print("score of explicit symptoms is %f" % score_tot_exp)
145 | print("score of explicit and implicit symptoms is %f" % score_tot_all)
146 | '''
147 | slots_x_all=np.zeros((len(total_set),sum(sum(abs(slots_all),1)>5)))
148 | count=0
149 | for i,value in enumerate(sum(slots_all,1)):
150 |     if value<=5:
151 |         pass
152 |     else:
153 |         slots_x_all[:,count]=slots_all[:,i]
154 |         count+=1
155 | 
156 | slots_input_all=pd.DataFrame(index=range(len(total_set)))
157 | for col in range(slots_x_all.shape[1]):
158 |     column=slots_x_all[:,col]
159 |     column_mod=[]
160 |     for j in column:
161 |         if j==1:
162 |             column_mod.append('yes')
163 |         elif j==-1:
164 |             column_mod.append('no')
165 |         else:
166 |             column_mod.append('UNK')
167 |     slots_input_all[str(col)]=column_mod
168 | 
169 | slots_input_all=pd.get_dummies(slots_input_all)
170 | disease_y=np.array(disease_y)
171 | #x_train, x_val, y_train, y_val = train_test_split(slots_input, disease_y, test_size=0.3, random_state=10)
172 | clf = svm.SVC(kernel='linear', C=5)
173 | 
174 | scores_all = cross_validate(clf, slots_input_all, disease_y, cv=5,scoring='accuracy',return_train_score=False)
175 | scores_tot_all=sum(scores_all['test_score'])/5
176 | '''
177 | 
--------------------------------------------------------------------------------
/preprocess/label/frequency.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import json
4 | import copy
5 | import csv
6 | 
7 | 
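# The goal files parsed by the classes below are expected to hold one JSON object per
# line; a sketch of the shape (field names taken from the parsing code, the values
# here are invented):
#
#     {"disease_tag": "上呼吸道感染",
#      "goal": {"explicit_inform_slots": {"咳嗽": true},
#               "implicit_inform_slots": {"发热": false}}}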
8 | class Frequency(object):
9 |     """
10 |     Count how many times each symptom appears, both in total and under each disease.
11 |     """
12 |     def __init__(self):
13 |         pass
14 | 
15 |     def load(self, goal_file, symptom_frequency_file, disease_symptom_frequency_file):
16 |         data_file = open(file=goal_file, mode='r')
17 |         symptom_frequency_dict = {}
18 |         disease_symptom_frequency_dict = {}
19 |         for line in data_file:
20 |             line = json.loads(line)
21 |             disease_symptom_frequency_dict.setdefault(line['disease_tag'], dict())
22 |             for symptom, value in line['goal']['implicit_inform_slots'].items():
23 |                 symptom_frequency_dict.setdefault(symptom, 0)
24 |                 symptom_frequency_dict[symptom] += 1
25 |                 disease_symptom_frequency_dict[line['disease_tag']].setdefault(symptom, 0)
26 |                 disease_symptom_frequency_dict[line['disease_tag']][symptom] += 1
27 | 
28 |             for symptom, value in line['goal']['explicit_inform_slots'].items():
29 |                 symptom_frequency_dict.setdefault(symptom, 0)
30 |                 symptom_frequency_dict[symptom] += 1
31 |                 disease_symptom_frequency_dict[line['disease_tag']].setdefault(symptom, 0)
32 |                 disease_symptom_frequency_dict[line['disease_tag']][symptom] += 1
33 |         symptom_file = open(file=symptom_frequency_file, mode='w', encoding='utf8')
34 |         disease_file = open(file=disease_symptom_frequency_file, mode='w', encoding='utf8')
35 |         symptom_writer = csv.writer(symptom_file)
36 |         disease_writer = csv.writer(disease_file)
37 | 
38 |         for symptom, count in symptom_frequency_dict.items():
39 |             symptom_writer.writerow([symptom, count])
40 |         for disease, symptom_count in disease_symptom_frequency_dict.items():
41 |             for symptom, count in symptom_count.items():
42 |                 disease_writer.writerow([disease, symptom, count])
43 |         symptom_file.close()
44 |         disease_file.close()
45 |         data_file.close()
46 | 
47 | 
48 | class Normalize(object):
49 |     """
50 |     Partially normalize the annotated symptoms with a manually built mapping.
51 |     """
52 |     def __init__(self, normalize_file):
53 |         self.spoken_normal = {}
54 |         data_file = open(normalize_file, mode='r', encoding='utf8')
55 |         reader = csv.reader(data_file)
56 |         for line in reader:
57 |             line = line[0].split('\t')
58 |             print(line)
59 |             self.spoken_normal[line[0]] = line[1]
60 |         data_file.close()
61 | 
62 |     def load(self, goal_file):
63 |         data_file = open(goal_file, 'r', encoding='utf8')
64 |         new_file = open(goal_file.split('.json')[0] + '_normal.json', 'w', encoding='utf8')
65 |         for line in data_file:
66 |             line = json.loads(line)
67 |             temp_line = copy.deepcopy(line)
68 |             for symptom, value in temp_line['goal']['implicit_inform_slots'].items():
69 |                 if symptom in self.spoken_normal.keys():
70 |                     line['goal']['implicit_inform_slots'][self.spoken_normal[symptom]] = value
71 |                     line['goal']['implicit_inform_slots'].pop(symptom)
72 |             for symptom, value in temp_line['goal']['explicit_inform_slots'].items():
73 |                 if symptom in self.spoken_normal.keys():
74 |                     line['goal']['explicit_inform_slots'][self.spoken_normal[symptom]] = value
75 |                     line['goal']['explicit_inform_slots'].pop(symptom)
76 |             new_file.write(json.dumps(line) + '\n')
77 |         data_file.close()
78 |         new_file.close()
79 | 
80 | 
81 | class FilterFrequency(object):
82 |     """
83 |     Filter symptoms by their frequency.
84 |     """
85 |     def __init__(self, threshold=1):
86 |         self.threshold = threshold
87 | 
88 |     def load(self, goal_file):
89 |         data_file = open(goal_file, 'r', encoding='utf8')
90 |         self.symptom_frequency = {}
91 |         for line in data_file:
92 |             line = json.loads(line)
93 |             for symptom in line['goal']['implicit_inform_slots'].keys():
94 |                 self.symptom_frequency.setdefault(symptom, 0)
95 |                 self.symptom_frequency[symptom] += 1
96 | 
97 |             for symptom in line['goal']['explicit_inform_slots'].keys():
98 |                 self.symptom_frequency.setdefault(symptom, 0)
99 |                 self.symptom_frequency[symptom] += 1
100 |         data_file.close()
101 |         print(self.symptom_frequency)
102 | 
103 |         data_file = open(goal_file, 'r', encoding='utf8')
104 |         new_file = open(goal_file.split('.json')[0] + '_filter_' + str(self.threshold) + '.json', 'w', encoding='utf8')
105 |         for line in data_file:
106 |             line = json.loads(line)
107 |             temp_line = copy.deepcopy(line)
108 |             for symptom, value in temp_line['goal']['implicit_inform_slots'].items():
109 |                 if self.symptom_frequency[symptom] < self.threshold:
110 |                     line['goal']['implicit_inform_slots'].pop(symptom)
111 |             for symptom, value in temp_line['goal']['explicit_inform_slots'].items():
112 |                 if self.symptom_frequency[symptom] < self.threshold:
113 |                     line['goal']['explicit_inform_slots'].pop(symptom)
114 |             new_file.write(json.dumps(line) + '\n')
115 |         data_file.close()
116 |         new_file.close()
117 | 
118 | 
119 | class FirstRun(object):
120 |     """
121 |     Remove the spaces inside disease names and symptom terms.
122 |     """
123 |     def read(self, goal_file):
124 |         data_file = open(file=goal_file, mode="r")
125 |         new_file = open(file=goal_file.split('.json')[0] + '_2.json', mode='w')
126 |         for line in data_file:
127 |             line = json.loads(line)
128 |             line['disease_tag'] = line['disease_tag'].replace(' ', '')
129 |             temp_line = copy.deepcopy(line)
130 |             for symptom, value in temp_line['goal']['implicit_inform_slots'].items():
131 |                 line['goal']['implicit_inform_slots'].pop(symptom)
132 |
line['goal']['implicit_inform_slots'][symptom.replace(' ', '')] = value 133 | 134 | for symptom, value in temp_line['goal']['explicit_inform_slots'].items(): 135 | line['goal']['explicit_inform_slots'].pop(symptom) 136 | line['goal']['explicit_inform_slots'][symptom.replace(' ', '')] = value 137 | new_file.write(json.dumps(line) + '\n') 138 | data_file.close() 139 | new_file.close() 140 | 141 | 142 | if __name__ == '__main__': 143 | goal_file = './../../resources/label/new/goal2.json' 144 | # goal_file = './../../resources/label/goal2_normal.json' 145 | # goal_file = './../../resources/label/goal2.json' 146 | symptom_file = './../../resources/label/symptom_frequency.csv' 147 | disease_file = './../../resources/label/disease_frequency.csv' 148 | 149 | goal_file = './../../resources/label/goal2_normal.json' 150 | # first = FirstRun() 151 | # first.read(goal_file) 152 | 153 | 154 | # normal_file = './../../resources/label/症状归一手动.csv' 155 | # normal = Normalize(normal_file) 156 | # normal.load(goal_file) 157 | 158 | # frequency = Frequency() 159 | # frequency.load(goal_file,symptom_file, disease_file) 160 | 161 | 162 | filter = FilterFrequency(threshold=10) 163 | filter.load(goal_file) 164 | -------------------------------------------------------------------------------- /src/dialogue_system/policy_learning/dqn_with_goal_joint.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import torch 4 | import random 5 | from collections import deque 6 | import sys, os 7 | sys.path.append(os.getcwd().replace("src/dialogue_system/policy_learning","")) 8 | from src.dialogue_system.policy_learning.dqn_torch import DQN 9 | from collections import namedtuple 10 | 11 | 12 | class DQNModelWithGoal(torch.nn.Module): 13 | """ 14 | The model in this file is reference to `Florensa, C., Duan, Y., & Abbeel, P. (2017). Stochastic neural networks for 15 | hierarchical reinforcement learning. arXiv preprint arXiv:1704.03012.` 16 | https://arxiv.org/abs/1704.03012 17 | """ 18 | def __init__(self, input_size, hidden_size, output_size, number_of_latent_variables, goal_embedding_value, parameter): 19 | super(DQNModelWithGoal, self).__init__() 20 | self.params = parameter 21 | self.number_of_latent_variables = number_of_latent_variables 22 | self.tau = self.params.get("temperature") 23 | 24 | # self.goal_embed_layer = torch.nn.Embedding.from_pretrained(torch.Tensor(goal_embedding_value), freeze=True) 25 | # self.goal_embed_layer.weight.requires_grad_(False) 26 | 27 | # different layers. Two layers. 28 | self.policy_layer = torch.nn.Sequential( 29 | torch.nn.Linear(input_size + number_of_latent_variables, hidden_size, bias=True), 30 | torch.nn.Dropout(0.5), 31 | torch.nn.LeakyReLU(), 32 | torch.nn.Linear(hidden_size, output_size, bias=True) 33 | ) 34 | 35 | self.goal_layer = torch.nn.Sequential( 36 | torch.nn.Linear(input_size, hidden_size, bias=True), 37 | torch.nn.Dropout(0.5), 38 | torch.nn.LeakyReLU(), 39 | torch.nn.Linear(hidden_size, number_of_latent_variables, bias=True) 40 | ) 41 | 42 | # one layer. 
43 | # self.policy_layer = torch.nn.Linear(input_size + number_of_latent_variables, output_size, bias=True) 44 | # self.goal_layer = torch.nn.Linear(input_size, number_of_latent_variables, bias=True) 45 | 46 | def forward(self, x): 47 | if torch.cuda.is_available(): 48 | x.cuda() 49 | goal = self.goal_generator(x) 50 | q_values = self.compute_q_value(x,goal) 51 | return q_values 52 | 53 | def goal_generator(self, x): 54 | logits = self.goal_layer(x) 55 | # print('logits', logits) 56 | goal_rep = torch.nn.functional.gumbel_softmax(logits=logits, tau=self.tau, hard=False) 57 | # print('goal', goal_rep) 58 | return goal_rep 59 | 60 | def compute_q_value(self, x, goal): 61 | temp = torch.cat((x, goal), dim=1) 62 | q_values = self.policy_layer(temp) 63 | return q_values 64 | 65 | 66 | class DQNModelWithGoal2(torch.nn.Module): 67 | """ 68 | Weighting sum the goal embedding. 69 | """ 70 | def __init__(self, input_size, hidden_size, output_size, number_of_latent_variables, goal_embedding_value, parameter): 71 | super(DQNModelWithGoal2, self).__init__() 72 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | self.params = parameter 74 | self.number_of_latent_variables = number_of_latent_variables 75 | self.tau = self.params.get("temperature") 76 | 77 | # self.goal_embed_layer = torch.nn.Embedding.from_pretrained(torch.Tensor(goal_embedding_value), freeze=True) 78 | # self.goal_embed_layer.weight.requires_grad_(False) 79 | self.goal_embed = torch.Tensor(goal_embedding_value).to(self.device) 80 | self.goal_embed.requires_grad_(False) 81 | 82 | # different layers. Two layers. 83 | self.policy_layer = torch.nn.Sequential( 84 | torch.nn.Linear(input_size + self.goal_embed.size()[1], hidden_size, bias=True), 85 | torch.nn.Dropout(0.5), 86 | torch.nn.LeakyReLU(), 87 | torch.nn.Linear(hidden_size, output_size, bias=True) 88 | ) 89 | 90 | self.goal_layer = torch.nn.Sequential( 91 | torch.nn.Linear(input_size, hidden_size, bias=True), 92 | torch.nn.Dropout(0.5), 93 | torch.nn.LeakyReLU(), 94 | torch.nn.Linear(hidden_size, number_of_latent_variables, bias=True) 95 | ) 96 | 97 | def forward(self, x): 98 | if torch.cuda.is_available(): 99 | x.cuda() 100 | goal = self.goal_generator(x) 101 | goal = goal.mm(self.goal_embed) 102 | 103 | q_values = self.compute_q_value(x,goal) 104 | return q_values 105 | 106 | def goal_generator(self, x): 107 | logits = self.goal_layer(x) 108 | # print('logits', logits) 109 | goal_rep = torch.nn.functional.gumbel_softmax(logits=logits, tau=self.tau, hard=False) 110 | # print('goal', goal_rep) 111 | return goal_rep 112 | 113 | def compute_q_value(self, x, goal): 114 | temp = torch.cat((x, goal), dim=1) 115 | q_values = self.policy_layer(temp) 116 | return q_values 117 | 118 | 119 | 120 | class DQNWithGoalJoint(DQN): 121 | def __init__(self, input_size, hidden_size, output_size, goal_embedding_value,parameter): 122 | super(DQNWithGoalJoint, self).__init__(input_size, hidden_size, output_size, parameter) 123 | del self.current_net 124 | del self.target_net 125 | 126 | self.params = parameter 127 | self.Transition = namedtuple('Transition', ('state', 'agent_action', 'reward', 'next_state', 'episode_over')) 128 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | 130 | self.current_net = DQNModelWithGoal(input_size, hidden_size, output_size, 4, goal_embedding_value, parameter).to(self.device) 131 | self.target_net = DQNModelWithGoal(input_size, hidden_size, output_size, 4,goal_embedding_value, parameter).to(self.device) 132 | 
print(self.current_net)
133 | 
134 |         if torch.cuda.is_available():
135 |             if parameter["multi_GPUs"] == True:  # multi GPUs
136 |                 self.current_net = torch.nn.DataParallel(self.current_net)
137 |                 self.target_net = torch.nn.DataParallel(self.target_net)
138 |             else:  # single GPU
139 |                 self.current_net.cuda(device=self.device)
140 |                 self.target_net.cuda(device=self.device)
141 | 
142 |         self.target_net.load_state_dict(self.current_net.state_dict())  # Copy parameters from the current network.
143 |         self.target_net.eval()  # set this model to evaluation mode; its parameters will not be updated.
144 | 
145 |         # Optimizer with L2 regularization
146 |         weight_p, bias_p = [], []
147 |         for name, p in self.current_net.named_parameters():
148 |             if 'bias' in name:
149 |                 bias_p.append(p)
150 |             else:
151 |                 weight_p.append(p)
152 | 
153 |         self.optimizer = torch.optim.Adam([
154 |             {'params': weight_p, 'weight_decay': 0.1},  # with L2 regularization
155 |             {'params': bias_p, 'weight_decay': 0}  # no L2 regularization.
156 |         ], lr=self.params.get("dqn_learning_rate", 0.001))
157 | 
158 |         if self.params.get("train_mode") is False and self.params.get('agent_id').lower() == 'agentwithgoaljoint':
159 |             self.restore_model(self.params.get("saved_model"))
160 |             self.current_net.eval()
161 |             self.target_net.eval()
--------------------------------------------------------------------------------
/src/classifier/self_report_as_feature/report_classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | Classify diseases using the self-report (chief complaint) text, with word embeddings of n-grams.
4 | """
5 | import copy
6 | import tensorflow as tf
7 | import time
8 | from sklearn import svm
9 | import os, sys
10 | import jieba
11 | import csv
12 | import numpy as np
13 | import random
14 | sys.path.append(os.getcwd().replace("src/classifier/self_report_as_feature",""))
15 | 
16 | 
17 | class ReportClassifier(object):
18 |     def __init__(self, stop_words, data_file):
19 |         self.corpus = Corpus(stop_words=stop_words)
20 |         self.corpus.load_data(data_file=data_file)
21 |         self.clf = svm.SVC(decision_function_shape='ovo')
22 |         self.__build_tf_model()
23 | 
24 |     def train_sklearn_svm(self):
25 |         print("fitting svm model...")
26 |         self.clf.fit(X=self.corpus.data_set["train"]["x"], y=self.corpus.data_set["train"]["y"])
27 | 
28 |     def evaluate_sklearn_svm(self):
29 |         predict = self.clf.predict(X=self.corpus.data_set["test"]["x"])
30 |         count = 0
31 |         for index in range(0, len(predict), 1):
32 |             if predict[index] == self.corpus.data_set["test"]["y"][index]:
33 |                 count += 1
34 |         print("accuracy of sklearn svm:", float(count) / len(predict))
35 | 
36 |     def train_tf(self):
37 |         data = self.corpus.data_set
38 |         train_input_fn = self.__get_input_fn(data["train"], batch_size=64)
39 |         self.estimator.fit(input_fn=train_input_fn, steps=2000)
40 | 
41 |     def evaluate_tf(self):
42 |         data = self.corpus.data_set
43 |         eval_input_fn = self.__get_input_fn(data["test"], batch_size=5000)
44 |         eval_metrics = self.estimator.evaluate(input_fn=eval_input_fn, steps=1)
45 |         print("result of tf:", eval_metrics)
46 | 
47 |     def __build_tf_model(self):
48 |         self.feature = tf.contrib.layers.real_valued_column('word_index', dimension=self.corpus.vocabulary_size)
49 |         self.optimizer = tf.train.FtrlOptimizer(learning_rate=50.0, l2_regularization_strength=0.001)
50 |         self.kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(input_dim=self.corpus.vocabulary_size, output_dim=1000, stddev=5.0, name='rffm')
51 |         kernel_mappers = {self.feature: [self.kernel_mapper]}
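        # The RandomFourierFeatureMapper above approximates an RBF kernel (stddev 5.0)
        # by projecting the vocabulary_size-dimensional bag-of-words input into a
        # 1000-dimensional random feature space, so the KernelLinearClassifier built
        # below behaves approximately like a kernel SVM while only training a linear model.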
52 |         self.estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
53 |             n_classes=len(self.corpus.disease_to_index), optimizer=self.optimizer, kernel_mappers=kernel_mappers)
54 | 
55 |     def __get_input_fn(self, dataset_split, batch_size, capacity=15000, min_after_dequeue=1000):
56 |         def _input_fn():
57 |             xs = np.array(dataset_split["x"]).astype(np.float32)
58 |             ys = np.array(dataset_split["y"])
59 |             report_batch, labels_batch = tf.train.shuffle_batch(
60 |                 # tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
61 |                 tensors=[xs, ys.astype(np.int32)],
62 |                 batch_size=batch_size,
63 |                 capacity=capacity,
64 |                 min_after_dequeue=min_after_dequeue,
65 |                 enqueue_many=True,
66 |                 num_threads=4)
67 |             features_map = {'word_index': report_batch}
68 |             return features_map, labels_batch
69 |         return _input_fn
70 | 
71 | 
72 | class Corpus(object):
73 |     def __init__(self, stop_words):
74 |         self.max_document_length = 0
75 |         self.data_set = {"train":{"x":[],"y":[],"word_index":[]}, "test":{"x":[],"y":[],"word_index":[]}}
76 |         self.stop_words = self._load_stop_words(stop_words_file=stop_words)
77 | 
78 |     def load_data(self, data_file, train=0.8, test=0.2):
79 |         """
80 |         Load the data and segment the words.
81 |         :param data_file: the file that contains the chief-complaint (self-report) text.
82 |         :return:
83 |         """
84 |         assert (train*100+test*100==100), "train + test does not sum to 1.0."
85 |         # Mapping disease to index and index to disease.
86 |         print("disease to index...")
87 |         data_reader = csv.reader(open(data_file, "r", encoding="utf-8"))
88 |         self.disease_to_index = {}
89 |         self.index_to_disease = {}
90 |         index = 0
91 |         for line in data_reader:
92 |             disease = line[5]
93 |             if disease not in self.disease_to_index.keys() and disease != "小儿发热":
94 |                 self.disease_to_index[disease] = index
95 |                 self.index_to_disease[index] = disease
96 |                 index += 1
97 | 
98 |         # Word segmentation.
99 |         print("word segmentation...")
100 |         data_set = []
101 |         data_reader = csv.reader(open(data_file, "r", encoding="utf-8"))
102 |         all_word_list = []
103 |         for line in data_reader:
104 |             if line[5] == "小儿发热":
105 |                 continue
106 |             disease_index = self.disease_to_index[line[5]]
107 |             # seg_list = jieba.cut(line[6], cut_all=False)
108 |             seg_list = jieba.cut_for_search(line[6])
109 | 
110 |             word_list = []
111 |             for word in seg_list:
112 |                 if word not in self.stop_words:
113 |                     word_list.append(word)
114 |                     all_word_list.append(word)
115 |             if len(word_list) > self.max_document_length:
116 |                 self.max_document_length = len(word_list)
117 |             data_set.append({"disease": disease_index, "text": " ".join(word_list)})
118 |             # print(" ".join(word_list))
119 |         # word to index.
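        # A toy illustration of the bag-of-words vector built further below (numbers
        # invented): a document whose words map to vocabulary ids [3, 5, 3] yields
        # text_rep[2] == 2.0 and text_rep[4] == 1.0; ids below 1 are skipped, so
        # position i stores the count of vocabulary id i + 1.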
120 | print("word to index...") 121 | self.vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(max_document_length=self.max_document_length,min_frequency=3) 122 | self.vocab_processor.fit(all_word_list) 123 | self.vocabulary_size = len(self.vocab_processor.vocabulary_) 124 | 125 | print("preparing dataset...") 126 | for data in data_set: 127 | data["text"] = next(self.vocab_processor.transform([data["text"]])).tolist() 128 | text_rep = np.zeros(self.vocabulary_size) 129 | for index in data["text"]: 130 | if index >= 1: 131 | text_rep[index-1] += 1.0 132 | 133 | random_float = random.random() 134 | if random_float <= train: 135 | self.data_set["train"]["x"].append(copy.deepcopy(text_rep)) 136 | self.data_set["train"]["word_index"].append(copy.deepcopy(data["text"])) 137 | self.data_set["train"]["y"].append(copy.deepcopy(data["disease"])) 138 | else: 139 | self.data_set["test"]["x"].append(copy.deepcopy(text_rep)) 140 | self.data_set["test"]["word_index"].append(copy.deepcopy(data["text"])) 141 | self.data_set["test"]["y"].append(copy.deepcopy(data["disease"])) 142 | 143 | def _load_stop_words(self,stop_words_file): 144 | """ 145 | Load stop words. 146 | :param stop_words_file: the path of file that contains stop words, on word for each line. 147 | :return: dictionary of stop words, key: word, value: word. 148 | """ 149 | stop_words = [line.strip() for line in open(stop_words_file, encoding="utf-8").readlines()] 150 | temp_dict = {} 151 | for word in stop_words: 152 | temp_dict[word] = word 153 | return temp_dict 154 | 155 | 156 | if __name__ == "__main__": 157 | data_file = "./../../../resources/top_self_report_extracted_symptom.csv" 158 | stop_words = "./data/stopwords.txt" 159 | classifier = ReportClassifier(stop_words=stop_words,data_file=data_file) 160 | -------------------------------------------------------------------------------- /src/dialogue_system/utils/plot_curve_each.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | For parameters. Drwaring the learning curve for each combination of parameters. 
5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | import pickle 9 | import os 10 | import time 11 | 12 | 13 | class Ploter(object): 14 | def __init__(self, performance_file): 15 | self.performance_file = performance_file 16 | self.epoch_size = 0 17 | self.success_rate = {} 18 | self.average_reward = {} 19 | self.average_wrong_disease = {} 20 | self.average_turn = {} 21 | 22 | def load_data(self, performance_file, label): 23 | performance = pickle.load(file=open(performance_file, "rb")) 24 | self.epoch_size = max(self.epoch_size, len(performance.keys())) 25 | sr, ar, awd,at = self.__load_data(performance=performance) 26 | self.success_rate[label] = sr 27 | self.average_reward[label] = ar 28 | self.average_wrong_disease[label] = awd 29 | self.average_turn[label] = at 30 | 31 | def __load_data(self, performance): 32 | success_rate = [] 33 | average_reward = [] 34 | average_wrong_disease = [] 35 | average_turn = [] 36 | for index in range(0, len(performance.keys()),1): 37 | print(performance[index].keys()) 38 | success_rate.append(performance[index]["success_rate"]) 39 | average_reward.append(performance[index]["average_reward"]) 40 | average_wrong_disease.append(performance[index]["average_wrong_disease"]) 41 | average_turn.append(performance[index]["average_turn"]) 42 | return success_rate, average_reward, average_wrong_disease,average_turn 43 | 44 | 45 | def plot(self, save_name, label_list): 46 | # epoch_index = [i for i in range(0, 500, 0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN)] 47 | 48 | for label in self.success_rate.keys(): 49 | epoch_index = [i for i in range(0, len(self.success_rate[label]), 1)] 50 | 51 | plt.plot(epoch_index,self.success_rate[label][0:max(epoch_index)+1], label=label, linewidth=1) 52 | # plt.plot(epoch_index,self.average_turn[label][0:max(epoch_index)+0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN], label=label+"at", linewidth=0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN) 53 | 54 | # plt.hlines(0.11,0,epoch_index,label="Random Agent", linewidth=0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN, colors="r") 55 | # plt.hlines(0.38,0,epoch_index,label="Rule Agent", linewidth=0220173244_AgentWithGoal_T22_lr0.0001_RFS44_RFF-22_RFNCY-1_RFIRS-1_mls0_gamma0.95_gammaW0.95_epsilon0.1_awd0_crs0_hwg0_wc0_var0_sdai0_wfrs0.0_dtft1_dataReal_World_RID3_DQN, colors="purple") 56 | 57 | plt.xlabel("Simulation Epoch") 58 | plt.ylabel("Success Rate") 59 | plt.title("Learning Curve") 60 | # if len(label_list) >= 2: 61 | # plt.legend() 62 | # plt.legend(loc="lower right") 63 | plt.grid(True) 64 | plt.savefig(save_name,dpi=400) 65 | 66 | plt.show() 67 | 68 | @staticmethod 69 | def get_dirlist(path, key_word_list=None, no_key_word_list=None): 70 | file_name_list = os.listdir(path) # 获得原始json文件所在目录里面的所有文件名称 71 | if key_word_list == None and no_key_word_list == None: 72 | temp_file_list = file_name_list 73 | elif key_word_list != None and no_key_word_list == None: 74 | temp_file_list = [] 75 | for file_name in file_name_list: 76 | have_key_words = True 77 | for key_word in key_word_list: 78 | if 
key_word not in file_name: 79 | have_key_words = False 80 | break 81 | else: 82 | pass 83 | if have_key_words == True: 84 | temp_file_list.append(file_name) 85 | elif key_word_list == None and no_key_word_list != None: 86 | temp_file_list = [] 87 | for file_name in file_name_list: 88 | have_no_key_word = False 89 | for no_key_word in no_key_word_list: 90 | if no_key_word in file_name: 91 | have_no_key_word = True 92 | break 93 | if have_no_key_word == False: 94 | temp_file_list.append(file_name) 95 | elif key_word_list != None and no_key_word_list != None: 96 | temp_file_list = [] 97 | for file_name in file_name_list: 98 | have_key_words = True 99 | for key_word in key_word_list: 100 | if key_word not in file_name: 101 | have_key_words = False 102 | break 103 | else: 104 | pass 105 | have_no_key_word = False 106 | for no_key_word in no_key_word_list: 107 | if no_key_word in file_name: 108 | have_no_key_word = True 109 | break 110 | else: 111 | pass 112 | if have_key_words == True and have_no_key_word == False: 113 | temp_file_list.append(file_name) 114 | 115 | return temp_file_list 116 | 117 | 118 | if __name__ == "__main__": 119 | # file_name = "./../model/dqn/learning_rate/learning_rate_d4_e999_agent1_dqn1.p" 120 | # file_name = "/Users/qianlong/Desktop/learning_rate_d4_e_agent1_dqn1_T22_lr0.001_SR44_mls0_gamma0.95_epsilon0.1_1499.p" 121 | # save_name = file_name + ".png" 122 | # ploter = Ploter(file_name) 123 | # ploter.load_data(performance_file=file_name, label="DQN Agent") 124 | # ploter.plot(save_name, label_list=["DQN Agent"]) 125 | 126 | # ploter.load_data("./../model/dqn/learning_rate/learning_rate_d7_e999_agent1_dqn1.p",label="d7a1q1") 127 | # ploter.load_data("./../model/dqn/learning_rate/learning_rate_d10_e999_agent1_dqn0.p",label="d10a1q0") 128 | # ploter.load_data("./../model/dqn/learning_rate/learning_rate_d10_e999_agent1_dqn1.p",label="d10a1q1") 129 | # ploter.plot(save_name, label_list=["d7a1q0", "d7a1q1", "d10a1q0", "d10a1q1"]) 130 | 131 | 132 | # Draw learning curve from directory. 133 | path = "/Volumes/LIUQL/dataset/1100/learning_rate/" 134 | save_path = "/Volumes/LIUQL/dataset/1100/learning_curve/" 135 | no_key_word_list = ["_99.", "_199.", "_299.", "_399."] 136 | performance_file_list = Ploter.get_dirlist(path=path,key_word_list=["_1499"],no_key_word_list=["_99.","_199.","_299.","_399."]) 137 | print("file_number:", len(performance_file_list)) 138 | time.sleep(8) 139 | 140 | for file_name in performance_file_list: 141 | print(file_name) 142 | performance_file = path + file_name 143 | save_name = save_path + file_name + ".png" 144 | ploter = Ploter(performance_file=performance_file) 145 | ploter.load_data(performance_file=performance_file,label="DQN Agent") 146 | ploter.plot(save_name=save_name,label_list=["DQN Agent"]) -------------------------------------------------------------------------------- /src/dialogue_system/policy_learning/dqn_with_goal.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import torch 4 | import random 5 | from collections import deque 6 | import sys, os 7 | sys.path.append(os.getcwd().replace("src/dialogue_system/policy_learning","")) 8 | from src.dialogue_system.policy_learning.dqn_torch import DQN 9 | from collections import namedtuple 10 | 11 | 12 | class DQNModelWithGoal(torch.nn.Module): 13 | """ 14 | The model in this file is reference to `Florensa, C., Duan, Y., & Abbeel, P. (2017). Stochastic neural networks for 15 | hierarchical reinforcement learning. 
arXiv preprint arXiv:1704.03012.` 16 | https://arxiv.org/abs/1704.03012 17 | """ 18 | def __init__(self, input_size, hidden_size, output_size, number_of_latent_variables, parameter): 19 | super(DQNModelWithGoal, self).__init__() 20 | self.params = parameter 21 | self.number_of_latent_variables = number_of_latent_variables 22 | self.tau = self.params.get("temperature") 23 | # different layers 24 | self.goal_layer1 = torch.nn.Linear(input_size, number_of_latent_variables, bias=True) 25 | self.policy_layer1 = torch.nn.Linear(input_size + number_of_latent_variables, output_size, bias=True) 26 | 27 | def forward(self, x): 28 | if torch.cuda.is_available(): 29 | x.cuda() 30 | goal = self.goal_generator(x) 31 | # print(goal) 32 | q_values = self.compute_q_value(x,goal) 33 | return q_values 34 | 35 | def goal_generator(self, x): 36 | logits = self.goal_layer1(x) 37 | goal_rep = torch.nn.functional.gumbel_softmax(logits=logits, tau=self.tau, hard=False) 38 | # goal_rep = torch.nn.functional.softmax(input=logits) 39 | return goal_rep 40 | 41 | def compute_q_value(self, x, goal): 42 | temp = torch.cat((x, goal), dim=1) 43 | q_values = self.policy_layer1(temp) 44 | return q_values 45 | 46 | 47 | class DQNModelWithGoal2(torch.nn.Module): 48 | """ 49 | The model in this file is reference to `Florensa, C., Duan, Y., & Abbeel, P. (2017). Stochastic neural networks for 50 | hierarchical reinforcement learning. arXiv preprint arXiv:1704.03012.` 51 | https://arxiv.org/abs/1704.03012 52 | """ 53 | def __init__(self, input_size, hidden_size, output_size, number_of_latent_variables, parameter): 54 | super(DQNModelWithGoal2, self).__init__() 55 | self.params = parameter 56 | self.number_of_latent_variables = number_of_latent_variables 57 | self.tau = self.params.get("temperature") 58 | # different layers 59 | self.goal_input_layer = torch.nn.Linear(input_size, hidden_size, bias=True) 60 | self.goal_state_abstract_layer = torch.nn.Linear(hidden_size, hidden_size, bias=True) 61 | self.goal_generate_layer = torch.nn.Linear(hidden_size, number_of_latent_variables, bias=True) 62 | 63 | self.policy_layer1 = torch.nn.Linear(hidden_size * number_of_latent_variables, output_size, bias=True) 64 | 65 | def forward(self, x): 66 | if torch.cuda.is_available(): 67 | x.cuda() 68 | batch_size = x.size()[0] 69 | goal, abstract_state = self.goal_generator(x) 70 | # print(goal) 71 | temp = torch.bmm(goal.unsqueeze(2), abstract_state.unsqueeze(1)) 72 | temp = temp.view(batch_size, -1) 73 | q_values = self.compute_q_value(temp) 74 | return q_values 75 | 76 | def goal_generator(self, x): 77 | h1 = self.goal_input_layer(x) 78 | h_state = self.goal_state_abstract_layer(torch.nn.functional.relu(h1)) 79 | goal_logits = self.goal_generate_layer(torch.nn.functional.relu(h1)) 80 | goal_rep = torch.nn.functional.gumbel_softmax(logits=goal_logits, tau=self.tau, hard=True) 81 | # goal_rep = torch.nn.functional.softmax(input=logits) 82 | return goal_rep, h_state 83 | 84 | def compute_q_value(self, x): 85 | q_values = self.policy_layer1(x) 86 | return q_values 87 | 88 | 89 | class DQNModelWithGoal3(torch.nn.Module): 90 | """ 91 | The model in this file is reference to `Florensa, C., Duan, Y., & Abbeel, P. (2017). Stochastic neural networks for 92 | hierarchical reinforcement learning. 
arXiv preprint arXiv:1704.03012.`
93 |     https://arxiv.org/abs/1704.03012
94 |     """
95 |     def __init__(self, input_size, hidden_size, output_size, number_of_latent_variables, parameter):
96 |         super(DQNModelWithGoal3, self).__init__()
97 |         self.params = parameter
98 |         self.number_of_latent_variables = number_of_latent_variables
99 |         self.tau = self.params.get("temperature")
100 |         # different layers
101 |         self.goal_input_layer = torch.nn.Linear(input_size, number_of_latent_variables, bias=True)
102 | 
103 |         self.policy_layer1 = torch.nn.Linear(input_size * number_of_latent_variables, output_size, bias=True)
104 | 
105 |     def forward(self, x):
106 |         if torch.cuda.is_available():
107 |             x.cuda()
108 |         batch_size = x.size()[0]
109 |         goal = self.goal_generator(x)
110 |         # print(goal)
111 |         temp = torch.bmm(goal.unsqueeze(2), x.unsqueeze(1))
112 |         temp = temp.view(batch_size, -1)
113 |         q_values = self.compute_q_value(temp)
114 |         return q_values
115 | 
116 |     def goal_generator(self, x):
117 |         logits = self.goal_input_layer(x)
118 |         goal_rep = torch.nn.functional.gumbel_softmax(logits=logits, tau=self.tau, hard=True)
119 |         # goal_rep = torch.nn.functional.softmax(input=logits)
120 |         return goal_rep
121 | 
122 |     def compute_q_value(self, x):
123 |         q_values = self.policy_layer1(x)
124 |         return q_values
125 | 
126 | 
127 | class DQNWithGoal(DQN):
128 |     def __init__(self, input_size, hidden_size, output_size, parameter):
129 |         super(DQNWithGoal, self).__init__(input_size, hidden_size, output_size, parameter)
130 |         del self.current_net
131 |         del self.target_net
132 | 
133 |         self.params = parameter
134 |         self.Transition = namedtuple('Transition', ('state', 'agent_action', 'reward', 'next_state', 'episode_over'))
135 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
136 | 
137 |         self.current_net = DQNModelWithGoal(input_size, hidden_size, output_size, 4, parameter).to(self.device)
138 |         self.target_net = DQNModelWithGoal(input_size, hidden_size, output_size, 4, parameter).to(self.device)
139 |         print(self.current_net)
140 | 
141 |         if torch.cuda.is_available():
142 |             if parameter["multi_GPUs"] == True:  # multi GPUs
143 |                 self.current_net = torch.nn.DataParallel(self.current_net)
144 |                 self.target_net = torch.nn.DataParallel(self.target_net)
145 |             else:  # single GPU
146 |                 self.current_net.cuda(device=self.device)
147 |                 self.target_net.cuda(device=self.device)
148 | 
149 |         self.target_net.load_state_dict(self.current_net.state_dict())  # Copy parameters from the current network.
150 |         self.target_net.eval()  # set this model to evaluation mode; its parameters will not be updated.
151 | 
152 |         # Optimizer with L2 regularization
153 |         weight_p, bias_p = [], []
154 |         for name, p in self.current_net.named_parameters():
155 |             if 'bias' in name:
156 |                 bias_p.append(p)
157 |             else:
158 |                 weight_p.append(p)
159 | 
160 |         self.optimizer = torch.optim.Adam([
161 |             {'params': weight_p, 'weight_decay': 0.1},  # with L2 regularization
162 |             {'params': bias_p, 'weight_decay': 0}  # no L2 regularization.
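            # The two groups above apply L2 regularization (weight_decay=0.1) to the
            # weight matrices only; biases stay unregularized, a common convention
            # since penalizing biases does not shrink model capacity.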
163 |         ], lr=self.params.get("dqn_learning_rate", 0.001))
164 | 
165 |         if self.params.get("train_mode") is False:
166 |             self.restore_model(self.params.get("saved_model"))
--------------------------------------------------------------------------------
/src/dialogue_system/utils/draw_curve_std.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 | """
3 | Draws learning-curve figures; only different agents are compared here, the greediness of the simulator is not considered.
4 | """
5 | 
6 | from __future__ import print_function
7 | import argparse, json
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 | import numpy as np
11 | import pickle
12 | import sys, os
13 | sys.path.append(os.getcwd().replace('/src/utils',''))
14 | 
15 | sns.set(style="darkgrid")
16 | sns.set(font_scale=1.4)
17 | 
18 | width = 8
19 | height = 5.8
20 | plt.figure(figsize=(width, height))
21 | 
22 | linewidth = 1.1
23 | 
24 | class DrawCurve(object):
25 |     def __init__(self, params):
26 |         self.params = params
27 | 
28 |     def read_performance_records(self, path):
29 |         """ load the performance record (.p pickle) file """
30 |         print(path)
31 |         performance = pickle.load(file=open(path, 'rb'))
32 | 
33 |         success_rate = []
34 |         average_reward = []
35 |         average_wrong_disease = []
36 |         average_turn = []
37 |         for index in range(0, len(performance.keys()), 1):
38 |             print(performance[index].keys())
39 |             success_rate.append(performance[index]["success_rate"])
40 |             average_reward.append(performance[index]["average_reward"])
41 |             average_wrong_disease.append(performance[index]["average_wrong_disease"])
42 |             average_turn.append(performance[index]["average_turn"])
43 | 
44 |         smooth_num = 1
45 |         d = [success_rate[i * smooth_num:i * smooth_num + smooth_num] for i in
46 |              range(int(len(success_rate) / smooth_num))]
47 | 
48 |         success_rate_new = []
49 |         cache = 0
50 |         alpha = 0.8
51 |         for i in d:
52 |             cur = sum(i) / float(smooth_num)
53 |             cache = cache * alpha + (1 - alpha) * cur
54 |             success_rate_new.append(cache)
55 |         return success_rate_new, success_rate[399]
56 | 
57 |     def get_mean_var(self, path, key_word_list=None, no_key_word_list=None):
58 |         file_list = DrawCurve.get_dir_list(path=path, key_word_list=key_word_list, no_key_word_list=no_key_word_list)
59 |         BBQ_datapoint = []
60 |         data_point = []
61 |         for file_name in file_list:
62 |             data_list, data_scalar = self.read_performance_records(os.path.join(path,file_name))
63 |             BBQ_datapoint.append(data_list)
64 |             data_point.append(data_scalar)
65 |             # BBQ_datapoint.append(self.read_performance_records(file_name,role,metric))
66 |         min_len = min(len(i) for i in BBQ_datapoint)
67 |         print([len(i) for i in BBQ_datapoint])
68 |         data = np.asarray([i[0:min_len] for i in BBQ_datapoint])
69 |         mean = np.mean(data, axis=0)
70 |         var = np.std(data, axis=0)
71 |         mean_data_point = np.mean(data_point)
72 |         return mean, var, min_len, mean_data_point
73 | 
74 |     def plot(self):
75 |         colors = ['#2f79c0', '#278b18', '#ff5186', '#8660a4', '#D49E0F', '#a8d40f']
76 |         global_idx = 1500
77 |         min_len_list = []
78 |         ave_result = {}
79 | 
80 | 
81 |         no_key_word_list = ['.DS_Store','.pdf']
82 | 
83 |         key_word_list = ['dqn']
84 |         mean, var, min_len, mean_point = self.get_mean_var(path=self.params['result_path'],
85 |                                                            key_word_list=key_word_list,
86 |                                                            no_key_word_list=no_key_word_list)
87 |         min_len_list.append(min_len)
88 |         l1, = plt.plot(range(mean.shape[0]), mean, colors[0], label='RL-agent', linewidth=linewidth)
89 |         plt.fill_between(range(mean.shape[0]), mean + var / 2, mean - var / 2, facecolor=colors[0], alpha=0.2)
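        # get_mean_var aggregates every run matched by key_word_list: the curves are
        # truncated to the shortest run, and the band drawn above spans the across-run
        # mean +/- std/2 (the variable named `var` actually holds np.std).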
90 |         ave_result['RL-agent'] = mean_point
91 | 
92 | 
93 |         # key_word_list = ['sim1', 'issdecay1', 'rac0']
94 |         # mean, var, min_len,mean_point = self.get_mean_var(path=self.params['result_path'],
95 |         #                                                   key_word_list=key_word_list,
96 |         #                                                   no_key_word_list=no_key_word_list)
97 |         # min_len_list.append(min_len)
98 |         # l2, = plt.plot(range(mean.shape[0]), mean, colors[1], label='DQN-Sim(Decay=1)', linewidth=linewidth)
99 |         # plt.fill_between(range(mean.shape[0]), mean + var / 2, mean - var / 2, facecolor=colors[1], alpha=0.2)
100 |         # ave_result['DQN(Decay=1,RC=0)'] = mean_point
101 | 
102 |         min_len = min(min_len_list)
103 |         plt.grid(True)
104 |         plt.ylabel('Success Rate')
105 |         plt.xlabel('Simulation Epoch')
106 |         plt.xlim([0, min_len])
107 |         plt.legend(loc=4)
108 |         # plt.savefig('learning_curve.png')
109 |         # plt.show()
110 |         plt.savefig(os.path.join(self.params['result_path'] + '_sr_' + str(min_len) + '.pdf'))
111 |         print(ave_result)
112 | 
113 |     @staticmethod
114 |     def get_dir_list(path, key_word_list=None, no_key_word_list=None):
115 |         file_name_list = os.listdir(path)  # get the names of all files in the directory of the original json files
116 |         if key_word_list == None and no_key_word_list == None:
117 |             temp_file_list = file_name_list
118 |         elif key_word_list != None and no_key_word_list == None:
119 |             temp_file_list = []
120 |             for file_name in file_name_list:
121 |                 have_key_words = True
122 |                 for key_word in key_word_list:
123 |                     if key_word not in file_name:
124 |                         have_key_words = False
125 |                         break
126 |                     else:
127 |                         pass
128 |                 if have_key_words == True:
129 |                     temp_file_list.append(file_name)
130 |         elif key_word_list == None and no_key_word_list != None:
131 |             temp_file_list = []
132 |             for file_name in file_name_list:
133 |                 have_no_key_word = False
134 |                 for no_key_word in no_key_word_list:
135 |                     if no_key_word in file_name:
136 |                         have_no_key_word = True
137 |                         break
138 |                 if have_no_key_word == False:
139 |                     temp_file_list.append(file_name)
140 |         elif key_word_list != None and no_key_word_list != None:
141 |             temp_file_list = []
142 |             for file_name in file_name_list:
143 |                 have_key_words = True
144 |                 for key_word in key_word_list:
145 |                     if key_word not in file_name:
146 |                         have_key_words = False
147 |                         break
148 |                     else:
149 |                         pass
150 |                 have_no_key_word = False
151 |                 for no_key_word in no_key_word_list:
152 |                     if no_key_word in file_name:
153 |                         have_no_key_word = True
154 |                         break
155 |                     else:
156 |                         pass
157 |                 if have_key_words == True and have_no_key_word == False:
158 |                     temp_file_list.append(file_name)
159 |         print(key_word_list, len(temp_file_list))
160 |         # time.sleep(2)
161 |         return temp_file_list
162 | 
163 | 
164 | if __name__ == "__main__":
165 |     parser = argparse.ArgumentParser()
166 | 
167 |     parser.add_argument('--result_path', dest='result_path', type=str, default='/Users/qianlong/Desktop/flat_dqn/', help='the directory of the results.')
168 | 
169 |     parser.add_argument('--metric', dest='metric', type=str, default='recall', help='the metric to show')
170 | 
171 |     args = parser.parse_args()
172 |     params = vars(args)
173 |     drawer = DrawCurve(params)
174 |     drawer.plot()
175 | 
--------------------------------------------------------------------------------
/preprocess/label/preprocess_label.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | Convert the actions in the txt file and the actions/symptoms in the json file into plain lists of actions and symptoms,
4 | and pickle them so that later code can simply load the persisted files. Each symptom is treated as a separate slot here.
5 | """
6 | import pickle
7 | import json
8 | import random
9 | import copy
10 | 
11 | 
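# A sketch of the pickled artifacts produced below (the keys shown are invented, the
# structure follows the dump() methods): ActionDumper writes an action->index dict
# such as {'request': 0, 'inform': 1}; SlotDumper writes a slot->index dict plus a
# disease_symptom dict of the form {disease_name: {'index': 0, 'symptom': [...]}}.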
/preprocess/label/preprocess_label.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | Convert the actions in the txt file and the actions/symptoms in the json file into lists containing only actions and symptoms, and pickle them so that later stages can load the files directly.
4 | Every symptom here is treated as a separate slot.
5 | """
6 | import pickle
7 | import json
8 | import random
9 | import copy
10 | 
11 | 
12 | class ActionDumper(object):
13 |     """
14 |     Read the action file, store the actions in a list and pickle the result.
15 |     """
16 |     def __init__(self, action_set_file):
17 |         self.file_name = action_set_file
18 | 
19 |     def dump(self, dump_file_name):
20 |         data_file = open(self.file_name, "r")
21 |         action_set = []
22 |         for line in data_file:
23 |             action_set.append(line.replace("\n", ""))
24 |         data_file.close()
25 |         action_set_dict = {}
26 |         for index in range(len(action_set)):
27 |             action_set_dict[action_set[index]] = index
28 |         pickle.dump(file=open(dump_file_name, "wb"), obj=action_set_dict)
29 | 
30 | 
31 | class SlotDumper(object):
32 |     """
33 |     Read the disease_symptom file, treat every symptom in it as a slot and pickle the result.
34 |     """
35 |     def __init__(self, slots_file, hand_crafted_symptom=True):
36 |         self.file_name = slots_file
37 |         self.hand_crafted_symptom = hand_crafted_symptom
38 | 
39 |     def dump(self, slot_dump_file_name, disease_dump_file_name):
40 |         self._load_slot()
41 |         self.slot_set.add("disease")
42 |         # self.slot_set.add("taskcomplete")
43 | 
44 |         slot_set = list(self.slot_set)
45 |         slot_set_dict = {}
46 |         for index in range(len(slot_set)):
47 |             slot_set_dict[slot_set[index]] = index
48 |         pickle.dump(file=open(slot_dump_file_name, "wb"), obj=slot_set_dict)
49 |         pickle.dump(file=open(disease_dump_file_name, "wb"), obj=self.disease_symptom)
50 | 
51 |     def _load_slot(self):
52 |         self.slot_set = set()
53 |         self.disease_symptom = {}
54 |         data_file = open(file=self.file_name, mode="r", encoding="utf-8")
55 |         if self.hand_crafted_symptom:
56 |             index = 0
57 |             for line in data_file:
58 |                 line = json.loads(line)
59 |                 self.disease_symptom[line["name"]] = {}
60 |                 self.disease_symptom[line["name"]]["index"] = index
61 |                 self.disease_symptom[line["name"]]["symptom"] = list(line["symptom"].keys())
62 |                 for key in line["symptom"].keys():
63 |                     self.slot_set.add(key)
64 |                 index += 1
65 |         else:
66 |             index = 0
67 |             for line in data_file:
68 |                 line = json.loads(line)
69 |                 self.disease_symptom[line["name"]] = {}
70 |                 self.disease_symptom[line["name"]]["index"] = index
71 |                 self.disease_symptom[line["name"]]["symptom"] = line["symptom"]
72 |                 for symptom in line["symptom"]:
73 |                     self.slot_set.add(symptom)
74 |                 index += 1
75 | 
76 |         data_file.close()
77 | 
78 | 
79 | class GoalDumper(object):
80 |     def __init__(self, goal_file):
81 |         self.file_name = goal_file
82 |         self.slot_set = set()
83 |         self.disease_symptom = {}
84 | 
85 |     def dump(self, dump_file_name, train=0.8, test=0.2, validate=0.0):
86 |         assert (train * 100 + test * 100 + validate * 100 == 100), "train + test + validate does not equal 1.0."  # compare in percent to reduce floating-point error
87 |         self.goal_set = []
88 |         data_file = open(file=self.file_name, mode="r")
89 |         for line in data_file:
90 |             line = json.loads(line)
91 |             line['disease_tag'] = line['disease_tag'].replace(' ', '')
92 |             temp_line = copy.deepcopy(line)
93 |             for symptom, value in temp_line['goal']['implicit_inform_slots'].items():
94 |                 line['goal']['implicit_inform_slots'].pop(symptom)
95 |                 line['goal']['implicit_inform_slots'][symptom.replace(' ', '')] = value
96 | 
97 |             for symptom, value in temp_line['goal']['explicit_inform_slots'].items():
98 |                 line['goal']['explicit_inform_slots'].pop(symptom)
99 |                 line['goal']['explicit_inform_slots'][symptom.replace(' ', '')] = value
100 | 
101 |             self.goal_set.append(line)
102 |         data_file.close()
103 |         goal_number = len(self.goal_set)
104 |         data_set = {
105 |             "train": [],
106 |             "test": [],
107 |             "validate": []
108 |         }
109 | 
110 |         # Group the goals by disease so that the split preserves the class distribution.
111 |         goal_disease = {}
112 |         for goal in self.goal_set:
113 |             goal_disease.setdefault(goal['disease_tag'], list())
114 |             goal_disease[goal['disease_tag']].append(goal)
115 | 
116 |         for disease, v in goal_disease.items():
117 |             random.shuffle(v)
118 |             number = len(v)
119 |             train_n = int(number * train)
120 |             test_n = int(number * test)
121 |             validate_n = int(number * validate)
122 |             data_set['train'] = data_set['train'] + v[0:train_n]
123 |             data_set['test'] = data_set['test'] + v[train_n:train_n + test_n]
124 |             data_set['validate'] = data_set['validate'] + v[number - validate_n:number]
125 | 
126 | 
127 |         for goal in self.goal_set:
128 |             # random_float = random.random()
129 |             # if random_float <= train:
130 |             #     data_set["train"].append(goal)
131 |             # elif train < random_float and random_float <= train+test:
132 |             #     data_set["test"].append(goal)
133 |             # else:
134 |             #     data_set["validate"].append(goal)
135 | 
136 |             for slot, value in goal["goal"]["explicit_inform_slots"].items():
137 |                 if value is False: print(goal)  # sanity check: explicit symptoms in a goal should not be False
138 |                 # (no early exit: every slot is checked)
139 |             for slot, value in goal["goal"]["implicit_inform_slots"].items():
140 |                 if value is False: print(goal)  # sanity check: implicit symptoms in a goal should not be False
141 |                 # (no early exit: every slot is checked)
142 | 
143 |             # Collect the slot (symptom) set.
144 |             for symptom in goal["goal"]["explicit_inform_slots"].keys(): self.slot_set.add(symptom)
145 |             for symptom in goal["goal"]["implicit_inform_slots"].keys(): self.slot_set.add(symptom)
146 | 
147 |             # Count how often each symptom occurs with each disease.
148 |             key_num = len(self.disease_symptom.keys())
149 |             disease = goal['disease_tag']
150 |             self.disease_symptom.setdefault(disease, {'index': key_num, 'symptom': dict()})
151 |             for symptom in goal["goal"]["explicit_inform_slots"].keys():
152 |                 self.disease_symptom[disease]['symptom'].setdefault(symptom, 0)
153 |                 self.disease_symptom[disease]['symptom'][symptom] += 1
154 |             for symptom in goal["goal"]["implicit_inform_slots"].keys():
155 |                 self.disease_symptom[disease]['symptom'].setdefault(symptom, 0)
156 |                 self.disease_symptom[disease]['symptom'][symptom] += 1
157 | 
158 |         pickle.dump(file=open(dump_file_name, "wb"), obj=data_set)
159 | 
160 |     def dump_slot(self, slot_file):
161 |         slot_set_dict = {}
162 |         slot_set = list(self.slot_set)
163 |         for index in range(len(slot_set)):
164 |             slot_set_dict[slot_set[index]] = index
165 |         slot_set_dict['disease'] = len(slot_set)  # 'disease' takes the next free index
166 |         print(slot_set_dict)
167 |         pickle.dump(file=open(slot_file, "wb"), obj=slot_set_dict)
168 | 
169 |     def dump_disease_symptom(self, disease_symptom_file):
170 |         print(self.disease_symptom)
171 |         pickle.dump(file=open(disease_symptom_file, 'wb'), obj=self.disease_symptom)
172 | 
173 | 
174 | 
175 | 
176 | 
177 | if __name__ == "__main__":
178 |     # Action
179 |     # action_file = "./../../../resources/action_set.txt"
180 |     # action_dump_file = "./../data/action_set.p"
181 |     #
182 |     # action_dumper = ActionDumper(action_set_file=action_file)
183 |     # action_dumper.dump(dump_file_name=action_dump_file)
184 | 
185 |     # Slots.
186 |     # slots_file = "./../../../resources/top_disease_symptom_aligned.json"
187 |     # slots_dump_file = "./../data/slot_set.p"
188 |     # disease_dump_file = "./../data/disease_symptom.p"
189 |     # slots_dumper = SlotDumper(slots_file=slots_file)
190 |     # slots_dumper.dump(slot_dump_file_name=slots_dump_file, disease_dump_file_name=disease_dump_file)
191 | 
192 |     # Goal
193 |     goal_file = "./../../resources/label/used/goal_find.json"
194 |     goal_dump_file = "./../../resources/label/used/goal_set.p"
195 |     slots_dump_file = "./../../resources/label/used/slot_set.p"
196 |     disease_dump_file = "./../../resources/label/used/disease_symptom.p"
197 |     goal_dumper = GoalDumper(goal_file=goal_file)
198 |     goal_dumper.dump(dump_file_name=goal_dump_file)
199 |     goal_dumper.dump_slot(slots_dump_file)
200 |     goal_dumper.dump_disease_symptom(disease_dump_file)
201 | 
202 |     slot = pickle.load(open(slots_dump_file, 'rb'))
203 | 
204 |     print(len(slot))
--------------------------------------------------------------------------------
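
GoalDumper above pickles three artifacts: the goal set split into train/test/validate lists, a slot dictionary mapping every symptom (plus 'disease') to an integer index, and per-disease symptom frequencies. A minimal sketch of how downstream code loads them (the paths are taken from the __main__ block above; the printed fields follow the structures built in dump(), dump_slot() and dump_disease_symptom()):

    import pickle

    goal_set = pickle.load(open("./../../resources/label/used/goal_set.p", "rb"))
    slot_set = pickle.load(open("./../../resources/label/used/slot_set.p", "rb"))
    disease_symptom = pickle.load(open("./../../resources/label/used/disease_symptom.p", "rb"))

    print(len(goal_set["train"]), len(goal_set["test"]))     # sizes of the 0.8/0.2 split
    print(slot_set["disease"])                               # index reserved for the 'disease' slot
    for disease, info in disease_symptom.items():
        print(disease, info["index"], len(info["symptom"]))  # disease id and its symptom vocabulary size
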
/preprocess/aligned_symptoms_extracting.py:
--------------------------------------------------------------------------------
1 | # -*-coding: utf-8 -*-
2 | """
3 | Normalize the symptoms extracted from the self-reports and from the Q&A conversations.
4 | Uses simple string-similarity matching.
5 | """
6 | import csv
7 | import pandas
8 | import difflib
9 | import time
10 | import json
11 | import Levenshtein
12 | 
13 | 
14 | class SymptomAligner(object):
15 |     """
16 |     Align spoken symptom expressions to their normalized (written) forms.
17 |     """
18 |     def __init__(self, aligned_symptom_file, threshold, hand_crafted_symptom=True):
19 |         self.hand_crafted_symptom = hand_crafted_symptom
20 |         self.threshold = threshold
21 |         self.aligned_symptom = dict()
22 |         data_file = open(aligned_symptom_file, "r", encoding="utf-8")
23 |         for line in data_file:
24 |             line = json.loads(line)
25 |             self.aligned_symptom.setdefault(line["name"], dict())
26 |             self.aligned_symptom[line["name"]]["symptom"] = line["symptom"]
27 |             self.aligned_symptom[line["name"]]["src_symptom"] = line["src_symptom"]
28 |         data_file.close()
29 | 
30 |     def align(self, spoken_symptom):
31 |         """
32 |         Return the written symptom for a given spoken symptom, chosen by string-similarity score.
33 |         :param spoken_symptom: the spoken symptom expression.
34 |         :return: the written symptom aligned with the spoken symptom, or None if no score reaches the threshold.
35 |         """
36 |         similarity_score = {}
37 |         if self.hand_crafted_symptom:
38 |             for disease in self.aligned_symptom.keys():
39 |                 for key, value in self.aligned_symptom[disease]["symptom"].items():
40 |                     similarity_score[key] = Levenshtein.ratio(spoken_symptom.replace("小儿", ""), key.replace("小儿", ""))
41 |                     for symptom in value:
42 |                         score = Levenshtein.ratio(spoken_symptom.replace("小儿", ""), symptom.replace("小儿", ""))
43 |                         if score > similarity_score[key]:
44 |                             similarity_score[key] = score
45 |         else:
46 |             for disease in self.aligned_symptom.keys():
47 |                 for symptom in self.aligned_symptom[disease]["symptom"]:
48 |                     similarity_score[symptom] = Levenshtein.ratio(spoken_symptom.replace("小儿", ""), symptom.replace("小儿", ""))
49 | 
50 |                 # for key, value in self.aligned_symptom[disease]["src_symptom"].items():
51 |                 #     score = Levenshtein.ratio(spoken_symptom.replace("小儿", ""), key.replace("小儿", ""))
52 | 
53 |         writing_symptom = max(similarity_score, key=similarity_score.get)  # best-scoring written symptom
54 |         score = similarity_score[writing_symptom]
55 |         if score >= self.threshold:
56 |             # print("writing_symptom:", writing_symptom, "score:", score, "spoken_symptom:", spoken_symptom)
57 |             return writing_symptom
58 |         else:
59 |             return None
60 | 
61 | 
62 | class DataLoader(object):
63 |     def __init__(self, threshold, disease_symptom_aligned_file, hand_crafted_symptom=True, top_disease_list=None):
64 |         self.top_disease_list = top_disease_list if top_disease_list is not None else []  # an empty list means no disease filtering
65 |         self.symptom_aligner = SymptomAligner(disease_symptom_aligned_file, threshold=threshold, hand_crafted_symptom=hand_crafted_symptom)
66 |         self.sample = {}
67 |         self.symptom_slots = set()
68 |         self.deny_list = ["不", "否", "没有", "没"]  # negation words used to decide whether a symptom is denied
69 | 
70 |     def load_self_report(self, self_report_file):
71 |         """
72 |         Normalize the symptoms found in the self-report (chief complaint) text.
73 |         :param self_report_file:
74 |         :return:
75 |         """
76 |         data_reader = csv.reader(open(self_report_file, "r", encoding="utf-8"))
77 |         for line in data_reader:
78 |             # print(line)
79 |             if len(self.top_disease_list) > 0 and line[5] not in self.top_disease_list: continue  # keep only the selected top diseases when a list is given
80 |             self.sample.setdefault(line[4], dict())
81 |             self.sample[line[4]]["request_slots"] = {"disease": line[5]}
82 |             self.sample[line[4]].setdefault("explicit_inform_slots", dict())
83 |             self.sample[line[4]].setdefault("implicit_inform_slots", dict())
84 |             try:
85 |                 index = line.index("")
86 |             except ValueError:  # no empty field in this row
87 |                 index = len(line)
88 |             symptom_list = line[8:index]
89 |             for symptom in symptom_list:
90 |                 spoken_symptom = symptom.replace("\n", "")
91 |                 writing_symptom = self.symptom_aligner.align(spoken_symptom)
92 |                 if writing_symptom is not None:
93 |                     self.sample[line[4]]["explicit_inform_slots"][spoken_symptom] = writing_symptom
94 | 
95 |     def load_conversation(self, conversation_file):
96 |         """
97 |         Normalize the symptoms found in the conversation data.
98 |         :param conversation_file:
99 |         :return:
100 |         """
101 |         data_file = open(conversation_file, mode="r", encoding="utf-8")
102 |         for line in data_file:
103 |             line = line.split("\t")
104 |             temp_line = []
105 |             for index in range(0, len(line)):
106 |                 if len(line[index]) != 0: temp_line.append(line[index])
107 |             line = temp_line
108 | 
109 |             # Check whether the conversation belongs to one of the target diseases, then normalize its symptoms.
110 |             # if line[0] in self.sample.keys() and str(line[1]) == str(2):  # extract only the patient's symptoms. 2: patient utterance, 3: doctor utterance.
111 |             if line[0] in self.sample.keys():  # extract the symptoms mentioned by both patient and doctor.
112 |                 for index in range(3, len(line)):
113 |                     spoken_symptom = line[index].replace("\n", "")
114 |                     writing_symptom = self.symptom_aligner.align(spoken_symptom)
115 |                     if writing_symptom is not None and (spoken_symptom not in self.sample[line[0]]["explicit_inform_slots"].keys()):
116 |                         self.sample[line[0]]["implicit_inform_slots"][spoken_symptom] = writing_symptom
117 |         data_file.close()
118 | 
119 |     def write_slot_value(self, file_name):
120 |         data_file = open(file_name, mode="w")
121 |         for key, value in self.sample.items():
122 |             line = {}
123 |             line["consult_id"] = key
124 |             line["disease_tag"] = self.sample[key]["request_slots"]["disease"]
125 |             line["goal"] = {}
126 |             line["goal"].setdefault("request_slots", dict())
127 |             line["goal"].setdefault("explicit_inform_slots", dict())
128 |             line["goal"].setdefault("implicit_inform_slots", dict())
129 |             line["goal"]["request_slots"]["disease"] = "UNK"
130 | 
131 |             for spoken_symptom, writing_symptom in value["explicit_inform_slots"].items():
132 |                 print("spoken:", spoken_symptom, writing_symptom)
133 |                 line["goal"]["explicit_inform_slots"][writing_symptom] = self._true_or_false(spoken_symptom, writing_symptom)
134 |                 self.symptom_slots.add(writing_symptom)
135 |             for spoken_symptom, writing_symptom in value["implicit_inform_slots"].items():
136 |                 print("spoken:", spoken_symptom, writing_symptom)
137 |                 if writing_symptom in line["goal"]["explicit_inform_slots"].keys(): continue
138 |                 line["goal"]["implicit_inform_slots"][writing_symptom] = self._true_or_false(spoken_symptom, writing_symptom)
139 |                 self.symptom_slots.add(writing_symptom)
140 | 
141 |             data_file.write(json.dumps(line) + "\n")
142 |         data_file.close()
143 | 
144 |     def write(self, file_name):
145 |         data_file = open(file_name, mode="w")
146 |         data_file.write(json.dumps(self.sample) + "\n")
147 |         data_file.close()
148 | 
149 |     def write_slots(self, file_name):
150 |         data_file = open(file_name, mode="w", encoding="utf-8")
151 |         for symptom in self.symptom_slots:
152 |             data_file.write(symptom + "\n")
153 |         data_file.close()
154 | 
155 |     def _true_or_false(self, spoken_symptom, writing_symptom):
156 |         exception_symptom_list = ["烦躁不安", "呼吸不畅"]  # written symptoms whose names themselves contain a negation character
157 |         exception_list = ["不舒服"]  # spoken expressions in which the negation word does not deny the symptom
158 |         return_value = True
159 |         for s_ in self.deny_list:
160 |             if s_ in spoken_symptom and writing_symptom not in exception_symptom_list:
161 |                 return_value = False
162 | 
163 |         for e in exception_list:
164 |             if e in spoken_symptom:
165 |                 return_value = True
166 |         return return_value
167 | 
168 | 
169 | 
170 | if __name__ == "__main__":
171 |     threshold = 0.2
172 |     disease_symptom_aligned_file = "./../resources/top_disease_symptom_aligned.json"
173 |     report_loader = DataLoader(threshold=threshold, disease_symptom_aligned_file=disease_symptom_aligned_file)  # defaults: hand-crafted symptoms, no disease filtering
174 |     report_loader.load_self_report("./../resources/top_self_report_extracted_symptom.csv")
175 |     print("Conversation:")
176 |     time.sleep(5)
177 |     report_loader.load_conversation("/Users/qianlong/Documents/Qianlong/Research/MedicalChatbot/origin_file/conversation_symptom.txt")
178 | 
179 |     report_loader.write("./../resources/goal_spoken_writing_" + str(threshold) + ".json")
180 |     report_loader.write_slot_value("./../resources/goal_slot_value_" + str(threshold) + ".json")
181 | 
--------------------------------------------------------------------------------
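
The alignment step above boils down to "pick the known symptom with the highest Levenshtein ratio, and reject it if the best score falls under the threshold". A toy illustration (the symptom strings here are made up for the example; Levenshtein.ratio comes from the python-Levenshtein package imported above):

    import Levenshtein

    known_symptoms = ["咳嗽", "发热", "腹泻"]   # hypothetical normalized symptoms
    spoken = "有点咳嗽"                         # hypothetical spoken expression
    threshold = 0.2                             # same default as in __main__ above

    scores = {s: Levenshtein.ratio(spoken, s) for s in known_symptoms}
    best = max(scores, key=scores.get)
    print(best if scores[best] >= threshold else None)  # -> 咳嗽
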
/src/dialogue_system/disease_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional
3 | import os
4 | import numpy as np
5 | from collections import namedtuple
6 | import pickle
7 | import copy
8 | import random
9 | 
10 | class Model(torch.nn.Module):
11 |     """
12 |     Disease classifier with one hidden fully connected layer, written in PyTorch.
13 |     """
14 |     def __init__(self, input_size, hidden_size, output_size):
15 |         super(Model, self).__init__()
16 |         # Two-layer network: linear -> dropout -> LeakyReLU -> linear.
17 |         self.policy_layer = torch.nn.Sequential(
18 |             torch.nn.Linear(input_size, hidden_size, bias=True),
19 |             torch.nn.Dropout(0.3),
20 |             torch.nn.LeakyReLU(),
21 |             # torch.nn.Linear(hidden_size, hidden_size),
22 |             # torch.nn.Dropout(0.5),
23 |             # torch.nn.LeakyReLU(),
24 |             torch.nn.Linear(hidden_size, output_size, bias=True)
25 |         )
26 | 
27 |         # one layer.
28 |         # self.policy_layer = torch.nn.Linear(input_size, output_size, bias=True)
29 | 
30 |     def forward(self, x):
31 |         if torch.cuda.is_available():
32 |             x = x.cuda()  # .cuda() is not in-place; keep the returned tensor
33 |         q_values = self.policy_layer(x.float())
34 |         return q_values
35 | 
36 | class dl_classifier(object):
37 |     def __init__(self, input_size, hidden_size, output_size, parameter):
38 |         self.parameter = parameter
39 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 |         self.model = Model(input_size=input_size, hidden_size=hidden_size, output_size=output_size).to(self.device)
41 | 
42 |         weight_p, bias_p = [], []  # separate weights from biases so L2 regularization applies to weights only
43 |         for name, p in self.model.named_parameters():
44 |             if 'bias' in name:
45 |                 bias_p.append(p)
46 |             else:
47 |                 weight_p.append(p)
48 | 
49 |         self.optimizer = torch.optim.Adam([
50 |             {'params': weight_p, 'weight_decay': 0.001},  # with L2 regularization
51 |             {'params': bias_p, 'weight_decay': 0}  # no L2 regularization
52 |         ], lr=0.0004)
53 |         # ], lr=parameter.get("dqn_learning_rate"))
54 | 
55 |         self.criterion = torch.nn.CrossEntropyLoss()
56 |         named_tuple = ("slot", "disease")
57 |         self.Transition = namedtuple('Transition', named_tuple)
58 |         # self.test_batch = self.create_data(train_mode=False)
59 | 
60 |         # if self.params.get("train_mode") is False and self.params.get("agent_id").lower() == 'agentdqn':
61 |         #     self.restore_model(self.params.get("saved_model"))
62 | 
63 |     def train(self, batch):
64 |         batch = self.Transition(*zip(*batch))
65 |         # print(batch.slot.shape)
66 |         slot = torch.LongTensor(batch.slot).to(self.device)
67 |         disease = torch.LongTensor(batch.disease).to(self.device)
68 |         out = self.model.forward(slot)
69 |         # print(disease.shape)
70 |         # print(out.shape)
71 |         # print(out.shape, disease)
72 |         loss = self.criterion(out, disease)
73 | 
74 |         self.optimizer.zero_grad()
75 |         loss.backward()
76 |         self.optimizer.step()
77 |         return {"loss": loss.item()}
78 | 
79 |     def predict(self, slots):
80 |         self.model.eval()
81 |         # print(batch.slot.shape)
82 |         slots = torch.LongTensor(slots).to(self.device)
83 |         Ys = self.model.forward(slots)
84 |         max_index = np.argmax(Ys.detach().cpu().numpy(), axis=1)
85 |         self.model.train()
86 |         return Ys, max_index
87 | 
88 | 
89 |     def train_dl_classifier(self, epochs):
90 |         batch_size = self.parameter.get("batch_size")
91 |         # print(batch_size)
92 |         # print(self.total_batch[0])
93 |         total_batch = self.create_data(train_mode=True)
94 |         for epoch in range(epochs):
95 |             batch = random.sample(total_batch, batch_size)
96 |             # print(batch[0][0].shape)
97 |             loss = self.train(batch)
98 |             if epoch % 100 == 0:
99 |                 print('epoch:{}, loss:{:.4f}'.format(epoch, loss["loss"]))
100 | 
101 |     def test_dl_classifier(self):
102 |         self.model.eval()
103 |         self.test_batch = self.create_data(train_mode=False)
104 |         batch = self.Transition(*zip(*self.test_batch))
105 |         slot = torch.LongTensor(batch.slot).to(self.device)
106 |         # disease = torch.LongTensor(batch.disease).to(self.device)
107 |         disease = batch.disease
108 |         Ys, pred = self.predict(slot)
109 |         # print(pred)
110 |         num_correct = len([1 for i in range(len(disease)) if disease[i] == pred[i]])
111 |         print("the test accuracy is %f" % (num_correct / len(self.test_batch)))
112 |         self.model.train()
113 | 
114 |     def test(self, test_batch):
115 |         # self.model.eval()
116 | 
117 |         batch = self.Transition(*zip(*test_batch))
118 |         slot = torch.LongTensor(batch.slot).to(self.device)
119 |         # disease = torch.LongTensor(batch.disease).to(self.device)
120 |         disease = batch.disease
121 |         Ys, pred = self.predict(slot.cpu())
122 |         # print(pred)
123 |         num_correct = len([1 for i in range(len(disease)) if disease[i] == pred[i]])
124 |         # print("the test accuracy is %f" % (num_correct / len(test_batch)))
125 |         test_acc = num_correct / len(test_batch)
126 |         # self.model.train()
127 |         return test_acc
128 | 
129 | 
130 | 
131 |     def create_data(self, train_mode):
132 |         goal_set = pickle.load(open(self.parameter.get("goal_set"), 'rb'))
133 |         self.slot_set = pickle.load(open(self.parameter.get("slot_set"), 'rb'))
134 |         disease_symptom = pickle.load(open(self.parameter.get("disease_symptom"), 'rb'))
135 | 
136 |         self.disease2id = {}
137 |         for disease, v in disease_symptom.items():
138 |             self.disease2id[disease] = v['index']
139 |         self.slot_set.pop('disease')  # 'disease' is not a symptom feature
140 |         disease_y = []
141 |         # total_set = random.sample(goal_set['train'], 10000)
142 |         if train_mode:
143 |             total_set = copy.deepcopy(goal_set["train"])
144 |         else:
145 |             total_set = copy.deepcopy(goal_set["test"])
146 |         total_batch = []
147 | 
148 |         # Encode each goal as a multi-hot vector over the symptom slots, paired with its disease id.
149 |         for i, dialogue in enumerate(total_set):
150 |             slots_exp = [0] * len(self.slot_set)  # multi-hot vector over the symptom slots
151 |             tag = dialogue['disease_tag']
152 |             # tag_group = disease_symptom1[tag]['symptom']
153 |             disease_y.append(tag)
154 |             goal = dialogue['goal']
155 |             explicit = goal['explicit_inform_slots']
156 |             for exp_slot, value in explicit.items():
157 |                 # try:
158 |                 slot_id = self.slot_set[exp_slot]
159 |                 if value is True:
160 |                     slots_exp[slot_id] = 1
161 |                 # except:
162 |                 #     pass
163 |             if sum(slots_exp) == 0:
164 |                 print("############################")  # warn: goal without any usable explicit symptom
165 |             total_batch.append((slots_exp, self.disease2id[tag]))
166 |         # print("the disease data creation is over")
167 |         return total_batch
168 | 
169 |     def save_model(self, model_performance, episodes_index, checkpoint_path):
170 |         if not os.path.isdir(checkpoint_path):
171 |             os.makedirs(checkpoint_path)
172 |         agent_id = self.parameter.get("agent_id").lower()
173 |         disease_number = self.parameter.get("disease_number")
174 |         success_rate = model_performance["success_rate"]
175 |         average_reward = model_performance["average_reward"]
176 |         average_turn = model_performance["average_turn"]
177 |         average_match_rate = model_performance["average_match_rate"]
178 |         average_match_rate2 = model_performance["average_match_rate2"]
179 |         model_file_name = os.path.join(checkpoint_path, "model_d" + str(disease_number) + str(agent_id) + "_s" + str(
180 |             success_rate) + "_r" + str(average_reward) + "_t" + str(average_turn) \
181 |                           + "_mr" + str(average_match_rate) + "_mr2-" + str(
182 |             average_match_rate2) + "_e-" + str(episodes_index) + ".pkl")
183 | 
184 |         torch.save(self.model.state_dict(), model_file_name)
185 | 
186 |     def restore_model(self, saved_model):
187 |         """
188 |         Restore the trained parameters of the model from a checkpoint file.
189 | 
190 |         Args:
191 |             saved_model (str): the file name of the trained model.
192 |         """
193 |         print("loading trained model", saved_model)
194 |         if not torch.cuda.is_available():
195 |             map_location = 'cpu'  # remap GPU-trained weights onto the CPU
196 |         else:
197 |             map_location = None
198 |         self.model.load_state_dict(torch.load(saved_model, map_location=map_location))
199 | 
200 |     def eval_mode(self):
201 |         self.model.eval()
202 | 
203 | 
204 | 
--------------------------------------------------------------------------------
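
dl_classifier is configured through a plain parameter dictionary, and create_data turns every pickled user goal into a (multi-hot symptom vector, disease id) pair. A minimal standalone training sketch (the pickle paths and the hidden size are assumptions for illustration; input_size must equal the number of symptom slots excluding 'disease', and output_size the number of diseases):

    import pickle

    params = {
        "goal_set": "./data/goal_set.p",               # assumed paths to the pickled artifacts
        "slot_set": "./data/slot_set.p",
        "disease_symptom": "./data/disease_symptom.p",
        "batch_size": 32,
    }
    slot_set = pickle.load(open(params["slot_set"], "rb"))
    disease_symptom = pickle.load(open(params["disease_symptom"], "rb"))

    classifier = dl_classifier(input_size=len(slot_set) - 1,    # minus the 'disease' slot
                               hidden_size=256,                 # assumed hidden size
                               output_size=len(disease_symptom),
                               parameter=params)
    classifier.train_dl_classifier(epochs=1000)
    classifier.test_dl_classifier()
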
/src/dialogue_system/state_tracker/state_tracker.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | State tracker of the dialogue system, which tracks the state of the dialogue during interaction.
4 | """
5 | 
6 | import sys, os
7 | import copy
8 | import json
9 | sys.path.append(os.getcwd().replace("src/dialogue_system/state_tracker", ""))
10 | 
11 | from src.dialogue_system import dialogue_configuration
12 | 
13 | 
14 | class StateTracker(object):
15 |     def __init__(self, user, agent, parameter):
16 |         self.user = user
17 |         self.agent = agent
18 |         self._init()
19 | 
20 |     def get_state(self):
21 |         return copy.deepcopy(self.state)
22 |         # return self.state
23 | 
24 |     def state_updater(self, user_action=None, agent_action=None):
25 |         assert (user_action is None or agent_action is None), "user action and agent action cannot both be given in one update."
26 |         self.state["turn"] = self.turn
27 |         if user_action is not None:
28 |             self._state_update_with_user_action(user_action=user_action)
29 |         elif agent_action is not None:
30 |             self._state_update_with_agent_action(agent_action=agent_action)
31 |         self.turn += 1
32 | 
33 |     def initialize(self):
34 |         self._init()
35 | 
36 |     def _init(self):
37 |         self.turn = 0
38 |         self.state = {
39 |             "agent_action": None,
40 |             "user_action": None,
41 |             "turn": self.turn,
42 |             "current_slots": {
43 |                 "user_request_slots": {},
44 |                 "agent_request_slots": {},
45 |                 "inform_slots": {},
46 |                 "explicit_inform_slots": {},
47 |                 "implicit_inform_slots": {},
48 |                 "proposed_slots": {},
49 |                 "wrong_diseases": []
50 |             },
51 |             "history": []
52 |         }
53 | 
54 |     def set_agent(self, agent):
55 |         self.agent = agent
56 | 
57 |     def _state_update_with_user_action(self, user_action):
58 |         # Updating the dialogue state with a user action.
59 |         self.state["user_action"] = user_action
60 |         temp_action = copy.deepcopy(user_action)
61 |         temp_action["current_slots"] = copy.deepcopy(self.state["current_slots"])  # Save current_slots for every turn.
62 |         self.state["history"].append(temp_action)
63 |         for slot in user_action["request_slots"].keys():
64 |             self.state["current_slots"]["user_request_slots"][slot] = user_action["request_slots"][slot]
65 | 
66 |         # Inform_slots.
67 |         inform_slots = list(user_action["inform_slots"].keys())
68 |         if "disease" in inform_slots and user_action["action"] == "deny":
69 |             if user_action["inform_slots"]["disease"] not in self.state["current_slots"]["wrong_diseases"]:
70 |                 self.state["current_slots"]["wrong_diseases"].append(user_action["inform_slots"]["disease"])
71 |         if "disease" in inform_slots: inform_slots.remove("disease")
72 |         for slot in inform_slots:
73 |             if slot in self.user.goal["goal"]["request_slots"].keys():
74 |                 self.state["current_slots"]["proposed_slots"][slot] = user_action["inform_slots"][slot]
75 |             else:
76 |                 self.state["current_slots"]['inform_slots'][slot] = user_action["inform_slots"][slot]
77 |             if slot in self.state["current_slots"]["agent_request_slots"].keys():
78 |                 self.state["current_slots"]["agent_request_slots"].pop(slot)
79 | 
80 |         # TODO (Qianlong): explicit_inform_slots and implicit_inform_slots are handled differently.
81 |         # Explicit_inform_slots.
82 |         explicit_inform_slots = list(user_action["explicit_inform_slots"].keys())
83 |         if "disease" in explicit_inform_slots and user_action["action"] == "deny":
84 |             if user_action["explicit_inform_slots"]["disease"] not in self.state["current_slots"]["wrong_diseases"]:
85 |                 self.state["current_slots"]["wrong_diseases"].append(user_action["explicit_inform_slots"]["disease"])
86 |         if "disease" in explicit_inform_slots: explicit_inform_slots.remove("disease")
87 |         for slot in explicit_inform_slots:
88 |             if slot in self.user.goal["goal"]["request_slots"].keys():
89 |                 self.state["current_slots"]["proposed_slots"][slot] = user_action["explicit_inform_slots"][slot]
90 |             else:
91 |                 self.state["current_slots"]["explicit_inform_slots"][slot] = user_action["explicit_inform_slots"][slot]
92 |             if slot in self.state["current_slots"]["agent_request_slots"].keys():
93 |                 self.state["current_slots"]["agent_request_slots"].pop(slot)
94 |         # Implicit_inform_slots.
95 |         implicit_inform_slots = list(user_action["implicit_inform_slots"].keys())
96 |         if "disease" in implicit_inform_slots and user_action["action"] == "deny":
97 |             if user_action["implicit_inform_slots"]["disease"] not in self.state["current_slots"]["wrong_diseases"]:
98 |                 self.state["current_slots"]["wrong_diseases"].append(user_action["implicit_inform_slots"]["disease"])
99 |         if "disease" in implicit_inform_slots: implicit_inform_slots.remove("disease")
100 |         for slot in implicit_inform_slots:
101 |             if slot in self.user.goal["goal"]["request_slots"].keys():
102 |                 self.state["current_slots"]["proposed_slots"][slot] = user_action["implicit_inform_slots"][slot]
103 |             else:
104 |                 self.state["current_slots"]["implicit_inform_slots"][slot] = user_action["implicit_inform_slots"][slot]
105 |             if slot in self.state["current_slots"]["agent_request_slots"].keys():
106 |                 self.state["current_slots"]["agent_request_slots"].pop(slot)
107 | 
108 |     def _state_update_with_agent_action(self, agent_action):
109 |         # Updating the dialogue state with an agent action.
110 |         explicit_implicit_slot_value = copy.deepcopy(self.user.goal["goal"]["explicit_inform_slots"])
111 |         explicit_implicit_slot_value.update(self.user.goal["goal"]["implicit_inform_slots"])
112 | 
113 |         self.state["agent_action"] = agent_action
114 |         temp_action = copy.deepcopy(agent_action)
115 |         temp_action["current_slots"] = copy.deepcopy(self.state["current_slots"])  # Save current_slots for every turn.
116 |         self.state["history"].append(temp_action)
117 |         # import json
118 |         # print(json.dumps(agent_action, indent=2))
119 |         for slot in agent_action["request_slots"].keys():
120 |             self.state["current_slots"]["agent_request_slots"][slot] = agent_action["request_slots"][slot]
121 | 
122 |         # Inform slots.
123 |         for slot in agent_action["inform_slots"].keys():
124 |             # The slot comes from the user's goal["request_slots"].
125 |             slot_value = agent_action["inform_slots"][slot]
126 |             if slot in self.user.goal["goal"]["request_slots"].keys() and slot_value == self.user.goal["disease_tag"]:
127 |                 self.state["current_slots"]["proposed_slots"][slot] = agent_action["inform_slots"][slot]
128 |             elif slot in explicit_implicit_slot_value.keys() and slot_value == explicit_implicit_slot_value[slot]:
129 |                 self.state["current_slots"]["inform_slots"][slot] = agent_action["inform_slots"][slot]
130 |             # Remove the slot if it is in current_slots["user_request_slots"].
131 |             if slot in self.state["current_slots"]["user_request_slots"].keys():
132 |                 self.state["current_slots"]["user_request_slots"].pop(slot)
133 | 
134 |         # TODO (Qianlong): explicit_inform_slots and implicit_inform_slots are handled differently.
135 |         # Explicit_inform_slots.
136 |         for slot in agent_action["explicit_inform_slots"].keys():
137 |             # The slot comes from the user's goal["request_slots"].
138 |             slot_value = agent_action["explicit_inform_slots"][slot]
139 |             if slot in self.user.goal["goal"]["request_slots"].keys() and slot_value == self.user.goal["disease_tag"]:
140 |                 self.state["current_slots"]["proposed_slots"][slot] = agent_action["explicit_inform_slots"][slot]
141 |             elif slot in explicit_implicit_slot_value.keys() and slot_value == explicit_implicit_slot_value[slot]:
142 |                 self.state["current_slots"]["explicit_inform_slots"][slot] = agent_action["explicit_inform_slots"][slot]
143 |             # Remove the slot if it is in current_slots["user_request_slots"].
144 |             if slot in self.state["current_slots"]["user_request_slots"].keys():
145 |                 self.state["current_slots"]["user_request_slots"].pop(slot)
146 | 
147 |         # Implicit_inform_slots.
148 |         for slot in agent_action["implicit_inform_slots"].keys():
149 |             # The slot comes from the user's goal["request_slots"].
150 |             slot_value = agent_action["implicit_inform_slots"][slot]
151 |             if slot in self.user.goal["goal"]["request_slots"].keys() and slot_value == self.user.goal["disease_tag"]:
152 |                 self.state["current_slots"]["proposed_slots"][slot] = agent_action["implicit_inform_slots"][slot]
153 |             elif slot in explicit_implicit_slot_value.keys() and slot_value == explicit_implicit_slot_value[slot]:
154 |                 self.state["current_slots"]["implicit_inform_slots"][slot] = agent_action["implicit_inform_slots"][slot]
155 |             # Remove the slot if it is in current_slots["user_request_slots"].
156 |             if slot in self.state["current_slots"]["user_request_slots"].keys():
157 |                 self.state["current_slots"]["user_request_slots"].pop(slot)
--------------------------------------------------------------------------------
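
For orientation, this is roughly how StateTracker is driven during one simulated dialogue (the dialogue manager does the real wiring; the stub user and the action dictionary below are simplified assumptions, with English stand-ins for the Chinese symptom names in the actual data):

    class _StubUser(object):
        # Hypothetical stand-in for the rule-based user simulator; only .goal is needed here.
        def __init__(self):
            self.goal = {"disease_tag": "flu",
                         "goal": {"request_slots": {"disease": "UNK"},
                                  "explicit_inform_slots": {"cough": True},
                                  "implicit_inform_slots": {"fever": True}}}

    tracker = StateTracker(user=_StubUser(), agent=None, parameter={})
    tracker.initialize()
    user_action = {"action": "request",
                   "request_slots": {"disease": "UNK"},
                   "inform_slots": {},
                   "explicit_inform_slots": {"cough": True},
                   "implicit_inform_slots": {}}
    tracker.state_updater(user_action=user_action)   # the user speaks first
    print(tracker.get_state()["current_slots"])      # 'cough' now sits in explicit_inform_slots
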