├── .gitignore ├── README.md ├── core ├── evaluator.py ├── game.py ├── gen_models.py ├── helpers.py ├── mcts.py └── players.py ├── data └── p4g │ └── 300_dialog_turn_based.pkl ├── interactive.py ├── outputs ├── chatgpt_raw_prompt.pkl ├── gdpzero_10sims_3rlz_0.25Q0_20dialogs.pkl ├── gdpzero_10sims_v_chatgpt.pkl ├── gdpzero_20sims_3rlz_0.0Q0_20dialogs.pkl ├── gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl └── gdpzero_5sims_3rlz_0.0Q0_20dialogs.pkl ├── runners ├── gdpzero.py ├── gdpzero_noRS.py ├── gdpzero_noopenloop.py └── raw_prompting.py ├── test.py └── utils ├── prompt_examples.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | 4 | *.ipynb 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GDP-Zero 2 | 3 | This repository contains code for the EMNLP'2023 paper "[Prompt-Based Monte-Carlo Tree Search for Goal-Oriented Dialogue Policy Planning](https://arxiv.org/abs/2305.13660)". 4 | 5 | ## Prerequisites 6 | 7 | 1. **OPENAI API KEYS**: this project relies on prompting an LLM to perform dialogue simulations 8 | ```bash 9 | # for OpenAI users 10 | export OPENAI_API_KEY=sk-xxxx 11 | # for MS Azure users 12 | export MS_OPENAI_API_KEY=xxxx 13 | export MS_OPENAI_API_BASE="https://xxx.com" 14 | export MS_OPENAI_API_VERSION="xxx" 15 | export MS_OPENAI_API_CHAT_VERSION="xxx" 16 | ``` 17 | 2. Before executing any of the scripts, make sure to **add the project to the PYTHONPATH environment variable**: 18 | ```bash 19 | > ~/GDPZero$ export PYTHONPATH=$(pwd) 20 | ``` 21 | 22 | ## Interactive Demo 23 | You can converse with both GDP-Zero planning and raw-prompting based planning using the `interactive.py` script.
**We note that its simulation speed depends heavily on the speed of the OpenAI API.** 24 | 25 | The default option is to use GDP-Zero as the planning algorithm: 26 | ```bash 27 | ~/GDPZero$ python interactive.py 28 | using GDPZero as planning algorithm 29 | You are now the Persuadee. Type 'q' to quit, and 'r' to restart. 30 | Persuader: Hello. How are you? 31 | You: Hi, I am good. What about you? 32 | 100%|██████████████████| 10/10 [00:32<00:00, 3.17s/it] 33 | Persuader: I'm doing well, thank you. I was just wondering if you've heard of the charity called Save the Children? 34 | You: No I have not. What does this charity do? 35 | 100%|██████████████████| 10/10 [00:37<00:00, 3.69s/it] 36 | Persuader: Save the Children is an organization that helps children in developing countries by providing relief and promoting children's rights. It's a great charity that makes a positive impact on so many children's lives. They help with things like education, health care, and safety. 37 | You: 38 | ``` 39 | In the above example, GDP-Zero performs a tree search with `n=10` simulations and `k=3` realizations per state. You can change these parameters using the `--num_mcts_sims` and `--max_realizations` flags, respectively. See `interactive.py -h` and the [Experiments](#experiments) section for more details. 40 | ```bash 41 | ~/GDPZero$ python interactive.py -h 42 | optional arguments: 43 | -h, --help show this help message and exit 44 | --log {20,10,30} logging mode 45 | --algo {gdpzero,raw-prompt} 46 | planning algorithm 47 | --llm {code-davinci-002,gpt-3.5-turbo,text-davinci-002,chatgpt} 48 | OpenAI model name 49 | --gen_sentences GEN_SENTENCES 50 | number of sentences to generate from the llm. Longer ones will be truncated by nltk. 51 | --num_mcts_sims NUM_MCTS_SIMS 52 | number of mcts simulations 53 | --max_realizations MAX_REALIZATIONS 54 | number of realizations per mcts state 55 | --Q_0 Q_0 initial Q value for unitialized states.
to control exploration 56 | ``` 57 | 58 | ## Experiments 59 | 60 | We mainly test GDP-Zero on the [PersuasionForGood](https://arxiv.org/abs/1906.06725) dataset. The scripts below take the first 20 dialogues from the dataset and perform planning/response generation for each turn. The output is a pickle file containing the generated responses and the corresponding contexts. This output pickle file is then used for evaluation (see [Static Evaluation](#static-evaluation)). 61 | 62 | *GDP-Zero*: 63 | ```bash 64 | > ~/GDPZero$ python runners/gdpzero.py -h 65 | optional arguments: 66 | -h, --help show this help message and exit 67 | --output OUTPUT output file 68 | --llm {code-davinci-002,chatgpt,gpt-3.5-turbo} 69 | OpenAI model name 70 | --gen_sentences GEN_SENTENCES 71 | number of sentences to generate from the llm. Longer ones will be truncated by nltk. 72 | --num_mcts_sims NUM_MCTS_SIMS 73 | number of mcts simulations 74 | --max_realizations MAX_REALIZATIONS 75 | number of realizations per mcts state 76 | --Q_0 Q_0 initial Q value for unitialized states. to control exploration 77 | --num_dialogs NUM_DIALOGS 78 | number of dialogs to test MCTS on 79 | --debug debug mode 80 | ``` 81 | For example, to use `gpt-3.5-turbo` as the backbone with `n=10` simulations, `k=3` realizations per state, and `Q_0=0.25` for exploration, run: 82 | ```bash 83 | > python runners/gdpzero.py --output outputs/gdpzero.pkl --llm gpt-3.5-turbo --num_mcts_sims 10 --max_realizations 3 --Q_0 0.25 84 | ``` 85 | 86 | *Baseline*: 87 | ```bash 88 | > ~/GDPZero$ python runners/raw_prompting.py -h 89 | optional arguments: 90 | -h, --help show this help message and exit 91 | --llm {code-davinci-002,gpt-3.5-turbo,chatgpt} 92 | OpenAI model name 93 | --gen_sentences GEN_SENTENCES 94 | max number of sentences to generate.
-1 for no limit 95 | --output OUTPUT output file 96 | ``` 97 | For example, to use `gpt-3.5-turbo` as the backbone, run: 98 | ```bash 99 | > python runners/raw_prompting.py --output outputs/chatgpt_raw_prompt.pkl --llm gpt-3.5-turbo 100 | ``` 101 | 102 | *Ablations*: 103 | ```bash 104 | # without OpenLoop 105 | ~/GDPZero$ python runners/gdpzero_noopenloop.py -h 106 | optional arguments: 107 | -h, --help show this help message and exit 108 | --output OUTPUT output file 109 | --llm {code-davinci-002,gpt-3.5-turbo,chatgpt} 110 | OpenAI model name 111 | --gen_sentences GEN_SENTENCES 112 | max number of sentences to generate 113 | --num_mcts_sims NUM_MCTS_SIMS 114 | number of mcts simulations 115 | ``` 116 | ```bash 117 | # without response selection 118 | ~/GDPZero$ python runners/gdpzero_noRS.py -h 119 | optional arguments: 120 | -h, --help show this help message and exit 121 | --output OUTPUT output file 122 | --llm {code-davinci-002,gpt-3.5-turbo,chatgpt} 123 | OpenAI model name 124 | --gen_sentences GEN_SENTENCES 125 | max number of sentences to generate 126 | --num_mcts_sims NUM_MCTS_SIMS 127 | number of mcts simulations 128 | --max_realizations MAX_REALIZATIONS 129 | number of realizations per mcts state 130 | --Q_0 Q_0 initial Q value for unitialized states. to control exploration 131 | ``` 132 | where most of the arguments are the same as those in `gdpzero.py`. 133 | 134 | ## Static Evaluation 135 | We mainly use `gpt-3.5-turbo` as the judge for static evaluation. To evaluate the planned dialogues from the [Experiments](#experiments) section, use the `test.py` script, which prompts ChatGPT to compare the planned responses against either the human demonstrations in P4G or another set of generated responses: 136 | ```bash 137 | > ~/GDPZero$ python test.py -h 138 | optional arguments: 139 | -h, --help show this help message and exit 140 | -f F path to the data file for comparing against human in p4g. See P4GEvaluator documentation to see the format of the file.
141 | --judge {gpt-3.5-turbo,chatgpt} 142 | which judge to use. 143 | --h2h H2H path to the data file for head to head comparison. If empty compare against human in p4g. 144 | --output OUTPUT output file 145 | --debug debug mode 146 | ``` 147 | For example, to compare `outputs/gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl`: 148 | - against human demonstrations 149 | ```bash 150 | > ~/GDPZero$ python test.py -f outputs/gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl --output eval.pkl --judge gpt-3.5-turbo 151 | evaluating: 100%|███████████████| 154/154 [03:49<00:00, 1.49s/it] 152 | win rate: 93.51% 153 | stats: {'win': 144, 'draw': 0, 'lose': 10} 154 | ``` 155 | - head-to-head comparison against ChatGPT generated responses (e.g. `outputs/chatgpt_raw_prompt.pkl`, see [Experiments](#experiments) section for more details) 156 | ```bash 157 | > ~/GDPZero$ python test.py -f outputs/gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl --h2h outputs/chatgpt_raw_prompt.pkl --output eval.pkl --judge gpt-3.5-turbo 158 | evaluating: 100%|███████████████| 154/154 [03:29<00:00, 1.36s/it] 159 | win rate: 59.09% 160 | stats: {'win': 91, 'draw': 2, 'lose': 61} 161 | ``` 162 | 163 | ## Examples 164 | 165 | We provide some example generations in the `outputs` directory.
For instance: 166 | ```bash 167 | outputs 168 | ├── chatgpt_raw_prompt.pkl # chatgpt baseline 169 | ├── gdpzero_10sims_3rlz_0.25Q0_20dialogs.pkl # gdp-zero with n=10, k=3, Q_0=0.25 170 | ├── gdpzero_10sims_v_chatgpt.pkl # evaluation result of gdp-zero against chatgpt 171 | ├── gdpzero_20sims_3rlz_0.0Q0_20dialogs.pkl 172 | ├── gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl 173 | └── gdpzero_5sims_3rlz_0.0Q0_20dialogs.pkl 174 | ``` -------------------------------------------------------------------------------- /core/evaluator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import List 6 | from core.gen_models import GenerationModel 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class RespRanker(ABC): 13 | @abstractmethod 14 | def evaluate(self, context, resp_a, resp_b): 15 | """ 16 | Compare two responses and return the preference. 17 | """ 18 | raise NotImplementedError 19 | 20 | 21 | class P4GEvaluator(RespRanker): 22 | def __init__(self, gen_model: GenerationModel): 23 | super().__init__() 24 | self.gen_model = gen_model 25 | self.inference_args = { 26 | "max_tokens": 2, 27 | "temperature": 0.7, 28 | "echo": False, 29 | "n": 5, 30 | "stop": "" 31 | } 32 | 33 | def evaluate(self, context, resp_a, resp_b): 34 | do_swap = False 35 | if random.random() < 0.5: 36 | do_swap = True 37 | resp_a, resp_b = resp_b, resp_a 38 | prompt = f""" 39 | The following is background information about Save the Children. 40 | Save the Children is head-quartered in London, and they work to help fight poverty around the world. Children need help in developing countries and war zones. Small donations like $1 or $2 go a long way to help. 41 | The following is a conversation between a Persuader and a Persuadee about a charity called Save the Children. The Persuader is trying to persuade the Persuadee to donate to Save the Children.
42 | {context} 43 | Which of the following responses can better help the Persuader convince the Persuadee to donate to Save the Children? 44 | A. Persuader: {resp_a} 45 | B. Persuader: {resp_b} 46 | C. Can't tell. 47 | Your can choose from either A, B, or C. 48 | Your choice: 49 | """.replace('\t', '').strip() 50 | logger.debug(f"prompt: {prompt}") 51 | resps = self.gen_model.generate(prompt, **self.inference_args) 52 | choices, rationales = self._process_resps(resps) 53 | preference = self._majority_vote(choices, do_swap) 54 | return preference, {'choices': choices, 'rationales': rationales, 'do_swap': do_swap} 55 | 56 | def _process_resps(self, resps:List[dict]): 57 | choices = [] 58 | rationales = [] 59 | for resp in resps: 60 | gen = resp['generated_text'].strip() 61 | 62 | if len(gen) == 0: 63 | print("Empty response") 64 | choice = 'c' 65 | else: 66 | choice = gen[0].lower() 67 | 68 | if choice not in ['a', 'b', 'c']: 69 | print(f"Invalid choice: {choice}") 70 | choice = 'c' 71 | choices.append(choice) 72 | # see if there is a rationale # just dump the entire response 73 | rationale = gen 74 | rationales.append(rationale) 75 | return choices, rationales 76 | 77 | def _majority_vote(self, resps:List[str], do_swap=False): 78 | # if there is a majority vote between A=0 and B=1, return the majority vote 79 | # otherwise, return C=2 80 | a_cnt = 0 81 | b_cnt = 0 82 | for resp in resps: 83 | if resp == 'a': 84 | a_cnt += 1 85 | elif resp == 'b': 86 | b_cnt += 1 87 | if a_cnt > b_cnt: 88 | return 0 if not do_swap else 1 89 | elif b_cnt > a_cnt: 90 | return 1 if not do_swap else 0 91 | return 2 -------------------------------------------------------------------------------- /core/game.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | from core.gen_models import DialogModel 5 | from core.helpers import DialogSession 6 | from abc import ABC, abstractmethod 7 | from typing import List 8 
| 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class DialogGame(ABC): 14 | def __init__(self, 15 | system_name:str, system_agent:DialogModel, 16 | user_name: str, user_agent:DialogModel): 17 | self.SYS = system_name 18 | self.system_agent = system_agent 19 | self.USR = user_name 20 | self.user_agent = user_agent 21 | return 22 | 23 | @staticmethod 24 | @abstractmethod 25 | def get_game_ontology() -> dict: 26 | """returns game related information such as dialog acts, slots, etc. 27 | """ 28 | raise NotImplementedError 29 | 30 | def init_dialog(self) -> DialogSession: 31 | # [(sys_act, sys_utt, user_act, user_utt), ...] 32 | return DialogSession(self.SYS, self.USR) 33 | 34 | def get_next_state(self, state:DialogSession, action) -> DialogSession: 35 | next_state = state.copy() 36 | 37 | sys_utt = self.system_agent.get_utterance(next_state, action) # action is DA 38 | sys_da = self.system_agent.dialog_acts[action] 39 | next_state.add_single(state.SYS, sys_da, sys_utt) 40 | 41 | # state in user's perspective 42 | user_da, user_resp = self.user_agent.get_utterance_w_da(next_state, None) # user just reply 43 | next_state.add_single(state.USR, user_da, user_resp) 44 | return next_state 45 | 46 | def get_next_state_batched(self, state:DialogSession, action, batch=3) -> List[DialogSession]: 47 | all_next_states = [state.copy() for _ in range(batch)] 48 | 49 | sys_utts = self.system_agent.get_utterance_batched(state.copy(), action, batch) # action is DA 50 | sys_da = self.system_agent.dialog_acts[action] 51 | for i in range(batch): 52 | all_next_states[i].add_single(state.SYS, sys_da, sys_utts[i]) 53 | 54 | # state in user's perspective 55 | user_das, user_resps = self.user_agent.get_utterance_w_da_from_batched_states(all_next_states, None) # user just reply 56 | for i in range(batch): 57 | all_next_states[i].add_single(state.USR, user_das[i], user_resps[i]) 58 | return all_next_states 59 | 60 | def display(self, state:DialogSession): 61 | string_rep = 
state.to_string_rep(keep_sys_da=True, keep_user_da=True) 62 | print(string_rep) 63 | return 64 | 65 | @abstractmethod 66 | def get_dialog_ended(self, state) -> float: 67 | """returns 0 if not ended, then (in general) 1 if system success, -1 if failure 68 | """ 69 | raise NotImplementedError 70 | 71 | 72 | class PersuasionGame(DialogGame): 73 | SYS = "Persuader" 74 | USR = "Persuadee" 75 | 76 | S_PersonalStory = "personal story" 77 | S_CredibilityAppeal = "credibility appeal" 78 | S_EmotionAppeal = "emotion appeal" 79 | S_PropositionOfDonation = "proposition of donation" 80 | S_FootInTheDoor = "foot in the door" 81 | S_LogicalAppeal = "logical appeal" 82 | S_SelfModeling = "self modeling" 83 | S_TaskRelatedInquiry = "task related inquiry" 84 | S_SourceRelatedInquiry = "source related inquiry" 85 | S_PersonalRelatedInquiry = "personal related inquiry" 86 | S_NeutralToInquiry = "neutral to inquiry" 87 | S_Greeting = "greeting" 88 | S_Other = "other" 89 | 90 | U_NoDonation = "no donation" 91 | U_NegativeReaction = "negative reaction" 92 | U_Neutral = "neutral" 93 | U_PositiveReaction = "positive reaction" 94 | U_Donate = "donate" 95 | 96 | def __init__(self, system_agent:DialogModel, user_agent:DialogModel, 97 | max_conv_turns=15): 98 | super().__init__(PersuasionGame.SYS, system_agent, PersuasionGame.USR, user_agent) 99 | self.max_conv_turns = max_conv_turns 100 | return 101 | 102 | @staticmethod 103 | def get_game_ontology() -> dict: 104 | return { 105 | "system": { 106 | "dialog_acts": [ 107 | PersuasionGame.S_PersonalStory, PersuasionGame.S_CredibilityAppeal, PersuasionGame.S_EmotionAppeal, 108 | PersuasionGame.S_PropositionOfDonation, PersuasionGame.S_FootInTheDoor, PersuasionGame.S_LogicalAppeal, 109 | PersuasionGame.S_SelfModeling, PersuasionGame.S_TaskRelatedInquiry, PersuasionGame.S_SourceRelatedInquiry, 110 | PersuasionGame.S_PersonalRelatedInquiry, PersuasionGame.S_NeutralToInquiry, PersuasionGame.S_Greeting, 111 | PersuasionGame.S_Other 112 | ], 113 | }, 
114 | "user": { 115 | "dialog_acts": [ 116 | PersuasionGame.U_NoDonation, PersuasionGame.U_NegativeReaction, PersuasionGame.U_Neutral, 117 | PersuasionGame.U_PositiveReaction, PersuasionGame.U_Donate 118 | ] 119 | } 120 | } 121 | 122 | def get_dialog_ended(self, state) -> float: 123 | # terminate once a Persuadee dialog act is "donate" or "no donation" 124 | # allow at most max_conv_turns turns 125 | if len(state) >= self.max_conv_turns: 126 | logger.info("Dialog ended with persuasion failure") 127 | return -1.0 128 | for (_, da, _) in state: 129 | if da == PersuasionGame.U_Donate: 130 | logger.info("Dialog ended with donate") 131 | return 1.0 132 | if da == PersuasionGame.U_NoDonation: 133 | logger.info("Dialog ended with no-donation") 134 | return -1.0 135 | return 0.0 136 | 137 | 138 | class EmotionalSupportGame(PersuasionGame): 139 | pass -------------------------------------------------------------------------------- /core/gen_models.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import logging 3 | import torch 4 | import openai 5 | import os 6 | import multiprocessing as mp 7 | import nltk 8 | 9 | from abc import ABC, abstractmethod 10 | from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed 11 | from typing import List, Tuple, Dict 12 | from core.helpers import DialogSession 13 | from functools import lru_cache 14 | from tenacity import retry, stop_after_attempt, wait_exponential, wait_fixed # for exponential backoff 15 | from utils.utils import hashabledict 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class GenerationModel(ABC): 22 | # used to generate text in general. e.g. could be using API, or local model 23 | @abstractmethod 24 | def generate(self, input_text, **gen_args): 25 | """ 26 | Generate text from the model. 27 | """ 28 | raise NotImplementedError 29 | 30 | def chat_generate(self, messages, **gen_args): 31 | """ 32 | Generate text from the model. Used for chatbot.
33 | """ 34 | raise NotImplementedError 35 | 36 | def chat_generate_batched(self, messages_list, **gen_args): 37 | """ 38 | Generate text from the model when you have multiple message histories 39 | """ 40 | raise NotImplementedError 41 | 42 | def _cleaned_resp(self, data, prompt) -> "List[str]": 43 | # default helper function to extract and clean the generated text from the returned json 44 | logger.debug("prompt:") 45 | logger.debug(prompt) 46 | cleaned_resps = [] 47 | for gen_resp in data: 48 | logger.debug("raw response:") 49 | logger.debug(gen_resp['generated_text']) 50 | cleaned_resp = gen_resp['generated_text'].strip() 51 | if "\n" in cleaned_resp: 52 | cleaned_resp = cleaned_resp[:cleaned_resp.index("\n")] 53 | logger.debug(f"cleaned response: {cleaned_resp}") 54 | cleaned_resps.append(cleaned_resp) 55 | return cleaned_resps 56 | 57 | def _cleaned_chat_resp(self, data, assistant_role="Persuader:", user_role="Persuadee:") -> "List[str]": 58 | # remove the user_role and keep the assistant_role 59 | # default helper function to extract and clean the generated text from the returned json 60 | cleaned_resps = [] 61 | for gen_resp in data: 62 | logger.debug("raw response:") 63 | logger.debug(gen_resp['generated_text']) 64 | cleaned_resp = gen_resp['generated_text'].strip() 65 | if "\n" in cleaned_resp: 66 | cleaned_resp = cleaned_resp[:cleaned_resp.index("\n")] 67 | if assistant_role in cleaned_resp: 68 | cleaned_resp = cleaned_resp[cleaned_resp.index(assistant_role) + len(assistant_role):].strip() 69 | if user_role in cleaned_resp: 70 | cleaned_resp = cleaned_resp[:cleaned_resp.index(user_role)].strip() 71 | logger.debug(f"cleaned response: {cleaned_resp}") 72 | cleaned_resps.append(cleaned_resp) 73 | return cleaned_resps 74 | 75 | 76 | class DialogModel(ABC): 77 | # used to play DialogGame 78 | def __init__(self): 79 | self.dialog_acts = [] 80 | return 81 | 82 | @abstractmethod 83 | def get_utterance(self, state:DialogSession, action) -> str: 84 | raise
NotImplementedError 85 | 86 | def get_utterance_batched(self, state:DialogSession, action:int, batch:int) -> List[str]: 87 | raise NotImplementedError 88 | 89 | @abstractmethod 90 | def get_utterance_w_da(self, state:DialogSession, action) -> Tuple[str, str]: 91 | # this is used for user agent. should not be used for system agent 92 | raise NotImplementedError 93 | 94 | def get_utterance_w_da_from_batched_states(self, states:List[DialogSession], action=None): 95 | # this is used for user agent. should not be used for system agent 96 | raise NotImplementedError 97 | 98 | 99 | 100 | class APIModel(GenerationModel): 101 | API_TOKEN = os.environ.get("HF_API_KEY") 102 | 103 | def __init__(self): 104 | # self.API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B" 105 | self.API_URL = "https://api-inference.huggingface.co/models/gpt2-large" 106 | self.headers: dict[str, str] = {"Authorization": f"Bearer {APIModel.API_TOKEN}"} 107 | self.inference_args = { 108 | "max_new_tokens": 100, 109 | "temperature": 0.7, 110 | "repetition_penalty": 1.2, 111 | "return_full_text": False 112 | } 113 | return 114 | 115 | def generate(self, input_text, **_args): 116 | data = { 117 | "inputs": input_text, 118 | "parameters": _args or self.inference_args 119 | } 120 | response = requests.post(self.API_URL, headers=self.headers, json=data) 121 | return response.json() 122 | 123 | 124 | class OpenAIModel(GenerationModel): 125 | API_TOKEN = os.environ.get("OPENAI_API_KEY") 126 | 127 | def __init__(self, model_name="text-curie-001"): 128 | # check if model exists 129 | openai.api_key = OpenAIModel.API_TOKEN 130 | models = openai.Engine.list() 131 | if model_name not in [model.id for model in models.data]: 132 | raise ValueError(f"model {model_name} not found") 133 | 134 | self.inference_args = { 135 | "model": model_name, 136 | "max_tokens": 64, 137 | "temperature": 0.7, 138 | "echo": False, 139 | "n": 1, 140 | "stop": "\n" 141 | } 142 | return 143 | 144 | def 
_update_args(self, new_args): 145 | args = {**self.inference_args} 146 | from_cache = False 147 | if "max_new_tokens" in new_args: 148 | new_args["max_tokens"] = new_args.pop("max_new_tokens") 149 | if "return_full_text" in new_args: 150 | new_args["echo"] = new_args.pop("return_full_text") 151 | if "do_sample" in new_args: 152 | from_cache = not new_args.pop("do_sample") # rely on caching 153 | if "num_return_sequences" in new_args: 154 | new_args["n"] = new_args.pop("num_return_sequences") 155 | if "repetition_penalty" in new_args: 156 | new_args["frequency_penalty"] = new_args.pop("repetition_penalty") 157 | return from_cache, {**args, **new_args} 158 | 159 | @lru_cache(maxsize=None) 160 | def _cached_generate(**parameters): 161 | response = openai.Completion.create(**parameters) 162 | return response 163 | 164 | # tried custom implementation of waiting before request, but I think openai is lying about how it calculates the rate limit 165 | # takes 3 trials to reach 2^3=8. Then 7 * 8 = 56 sec max. 
Just to be safe, we retry a bit more than 10 times 166 | @retry(wait=wait_exponential(multiplier=2, min=2, max=8), stop=stop_after_attempt(15)) 167 | def generate(self, input_text, **_args): 168 | from_cache, parameters = self._update_args(_args) 169 | parameters["prompt"] = input_text 170 | if from_cache: 171 | response = OpenAIModel._cached_generate(**parameters) 172 | else: 173 | response = openai.Completion.create(**parameters) 174 | 175 | # format to a common format 176 | gen_output = [] 177 | for resp in response.choices: 178 | text = resp.text 179 | gen_output.append({"generated_text": text}) 180 | return gen_output 181 | 182 | 183 | class OpenAIChatModel(OpenAIModel): 184 | def __init__(self, model_name="gpt-3.5-turbo", gen_sentences=-1): 185 | # set the API key (the model name is not validated here) 186 | openai.api_key = self.API_TOKEN 187 | 188 | self.inference_args = { 189 | "model": model_name, 190 | "max_tokens": 64, 191 | "temperature": 0.7, 192 | "n": 1, 193 | # "stop": "\n" # no longer needed since we are using chat 194 | # "echo": False, 195 | } 196 | self.gen_sentences = None if gen_sentences < 0 else gen_sentences 197 | return 198 | 199 | def _update_args(self, new_args): 200 | if "stop" in new_args: 201 | new_args.pop("stop") 202 | if "echo" in new_args: 203 | new_args.pop("echo") 204 | if "return_full_text" in new_args: 205 | new_args.pop("return_full_text") 206 | return super()._update_args(new_args) 207 | 208 | def generate(self, input_text, **_args): 209 | logger.info("It is recommended to use chat_generate instead of generate for OpenAIChatModel") 210 | messages = [{ 211 | "role": "user", 212 | "content": input_text 213 | }] 214 | return self.chat_generate(messages, **_args) 215 | 216 | @lru_cache(maxsize=None) 217 | def _cached_generate(**parameters): 218 | parameters["messages"] = list(parameters["messages"]) 219 | response = openai.ChatCompletion.create(**parameters) 220 | return response 221 | 222 | @retry(wait=wait_exponential(multiplier=2, min=2, max=8),
stop=stop_after_attempt(15)) 223 | def chat_generate(self, messages: List[Dict], **gen_args): 224 | # generate in a chat format 225 | from_cache, parameters = self._update_args(gen_args) 226 | hashable_messages = [hashabledict(m) for m in messages] 227 | parameters["messages"] = hashable_messages 228 | if from_cache: 229 | parameters["messages"] = tuple(hashable_messages) # list cannot be hashed, so cannot do **parameters 230 | response = OpenAIChatModel._cached_generate(**parameters) 231 | else: 232 | response = openai.ChatCompletion.create(**parameters) 233 | 234 | # format to a common format 235 | gen_output = [] 236 | for resp in response.choices: 237 | text = resp['message']['content'] 238 | if self.gen_sentences is not None: 239 | sentences = nltk.sent_tokenize(text) 240 | if len(sentences) > self.gen_sentences: 241 | text = " ".join(sentences[:self.gen_sentences]) 242 | gen_output.append({"generated_text": text}) 243 | return gen_output 244 | 245 | def chat_generate_batched(self, messages_list: List[List[Dict]], **gen_args): 246 | pool = mp.Pool(processes=len(messages_list)) 247 | results = [] 248 | for messages in messages_list: 249 | results.append(pool.apply_async(self.chat_generate, args=(messages,), kwds=gen_args)) 250 | pool.close() 251 | pool.join() 252 | return [r.get() for r in results] 253 | 254 | 255 | class AzureOpenAIModel(OpenAIModel): 256 | API_TOKEN = os.environ.get("MS_OPENAI_API_KEY") 257 | API_BASE = os.environ.get("MS_OPENAI_API_BASE") 258 | API_TYPE = "azure" 259 | API_VERSION = "2022-12-01" 260 | 261 | def __init__(self, model_name="chatgpt-turbo"): 262 | # check if model exists 263 | openai.api_key = AzureOpenAIModel.API_TOKEN 264 | openai.api_base = AzureOpenAIModel.API_BASE 265 | openai.api_type = AzureOpenAIModel.API_TYPE 266 | openai.api_version = AzureOpenAIModel.API_VERSION 267 | 268 | self.inference_args = { 269 | "engine": model_name, 270 | "max_tokens": 64, 271 | "temperature": 0.7, 272 | "echo": False, 273 | "n": 1, 274 | 
"stop": "\n" 275 | } 276 | return 277 | 278 | 279 | class AzureOpenAIChatModel(AzureOpenAIModel): 280 | def __init__(self, model_name="chatgpt", gen_sentences=-1): 281 | # check if model exists 282 | openai.api_key = self.API_TOKEN 283 | openai.api_base = self.API_BASE 284 | openai.api_type = self.API_TYPE 285 | openai.api_version = "2023-03-15-preview" 286 | 287 | self.inference_args = { 288 | "engine": model_name, 289 | "max_tokens": 64, 290 | "temperature": 0.7, 291 | "n": 1, 292 | # "stop": "\n" # no longer need since we are using chat 293 | # "echo": False, 294 | } 295 | self.gen_sentences = None if gen_sentences < 0 else gen_sentences 296 | return 297 | 298 | def _update_args(self, new_args): 299 | if "stop" in new_args: 300 | new_args.pop("stop") 301 | if "echo" in new_args: 302 | new_args.pop("echo") 303 | if "return_full_text" in new_args: 304 | new_args.pop("return_full_text") 305 | return super()._update_args(new_args) 306 | 307 | @lru_cache(maxsize=None) 308 | def _cached_generate(**parameters): 309 | parameters["messages"] = list(parameters["messages"]) 310 | response = openai.ChatCompletion.create(**parameters) 311 | return response 312 | 313 | @retry(wait=wait_exponential(multiplier=2, min=2, max=8), stop=stop_after_attempt(15)) 314 | def chat_generate(self, messages: List[Dict], **gen_args): 315 | # generate in a chat format 316 | from_cache, parameters = self._update_args(gen_args) 317 | hashable_messages = [hashabledict(m) for m in messages] 318 | parameters["messages"] = hashable_messages 319 | if from_cache: 320 | parameters["messages"] = tuple(hashable_messages) # list cannot be hashed, so cannot do **parameters 321 | response = AzureOpenAIChatModel._cached_generate(**parameters) 322 | else: 323 | response = openai.ChatCompletion.create(**parameters) 324 | 325 | # format to a common format 326 | gen_output = [] 327 | for resp in response.choices: 328 | text = resp['message']['content'] 329 | if self.gen_sentences is not None: 330 | sentences = 
nltk.sent_tokenize(text) 331 | if len(sentences) > self.gen_sentences: 332 | text = " ".join(sentences[:self.gen_sentences]) 333 | gen_output.append({"generated_text": text}) 334 | return gen_output 335 | 336 | def chat_generate_batched(self, messages_list: List[List[Dict]], **gen_args): 337 | pool = mp.Pool(processes=len(messages_list)) 338 | results = [] 339 | for messages in messages_list: 340 | results.append(pool.apply_async(self.chat_generate, args=(messages,), kwds=gen_args)) 341 | pool.close() 342 | pool.join() 343 | return [r.get() for r in results] 344 | 345 | def generate(self, input_text, **_args): 346 | messages = [{ 347 | "role": "user", 348 | "content": input_text 349 | }] 350 | return self.chat_generate(messages, **_args) 351 | 352 | 353 | class LocalModel(GenerationModel): 354 | def __init__(self, model_name="EleutherAI/gpt-neo-2.7B", input_max_len=512, stop_symbol="\n", cuda=True): 355 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left") 356 | self.model = AutoModelForCausalLM.from_pretrained(model_name) 357 | stop_token_ids = self.tokenizer.encode(stop_symbol)[0] 358 | set_seed(42) 359 | if cuda and torch.cuda.is_available(): 360 | self.cuda = True 361 | self.model = self.model.cuda() 362 | else: 363 | self.cuda = False 364 | 365 | self.input_max_len = input_max_len 366 | self.inference_args = { 367 | "max_new_tokens": 128, 368 | "temperature": 0.7, 369 | "repetition_penalty": 1.0, 370 | "eos_token_id": stop_token_ids, 371 | "pad_token_id": self.tokenizer.eos_token_id 372 | # "return_full_text": False # not available for manual generation 373 | } 374 | 375 | def generate(self, input_text:str, **gen_args): 376 | # override if gen_args specified 377 | gen_params = {**self.inference_args, **gen_args} 378 | inputs = self.tokenizer([input_text], return_tensors='pt', truncation=True, max_length=self.input_max_len) 379 | if self.cuda: 380 | inputs = {k: v.cuda() for k, v in inputs.items()} 381 | 382 | outputs = 
self.model.generate(**inputs, **gen_params) 383 | gen_only_outputs = outputs[:, len(inputs['input_ids'][0]):] 384 | gen_resps = self.tokenizer.batch_decode(gen_only_outputs, skip_special_tokens=True) 385 | 386 | # format output 387 | gen_output = [] 388 | for resp in gen_resps: 389 | gen_output.append({"generated_text": resp}) 390 | return gen_output -------------------------------------------------------------------------------- /core/helpers.py: -------------------------------------------------------------------------------- 1 | class DialogSession(): 2 | def __init__(self, sys_name, user_name) -> None: 3 | self.SYS = sys_name 4 | self.USR = user_name 5 | self.history: list = [] # [(role, da, utt), ....] 6 | return 7 | 8 | def from_history(self, history): 9 | self.history = history 10 | return self 11 | 12 | def to_string_rep(self, keep_sys_da=False, keep_user_da=False, max_turn_to_display=-1): 13 | history = "" 14 | num_turns_to_truncate = 0 15 | if max_turn_to_display > 0: 16 | num_turns_to_truncate = max(0, len(self.history) // 2 - max_turn_to_display) 17 | 18 | for i, (role, da, utt) in enumerate(self.history): 19 | if (i // 2) < num_turns_to_truncate: 20 | continue 21 | if i % 2 == 0: 22 | assert(role == self.SYS) 23 | if keep_sys_da: 24 | history += f"{role}: [{da}] {utt}\n" 25 | else: 26 | history += f"{role}: {utt}\n" 27 | else: 28 | assert(role == self.USR) 29 | if keep_user_da: 30 | history += f"{role}: [{da}] {utt}\n" 31 | else: 32 | history += f"{role}: {utt}\n" 33 | return history.strip() 34 | 35 | def copy(self): 36 | new_session = DialogSession(self.SYS, self.USR) 37 | new_session.from_history(self.history.copy()) 38 | return new_session 39 | 40 | def add_single(self, role, da, utt): 41 | if len(self.history) % 2 == 0: 42 | assert(role == self.SYS) 43 | else: 44 | assert(role == self.USR) 45 | self.history.append((role, da, utt)) 46 | return 47 | 48 | def get_turn_utt(self, turn, role): 49 | if role == self.SYS: 50 | return self.history[turn * 
2][-1] 51 | else: 52 | return self.history[turn * 2 + 1][-1] 53 | 54 | def __iter__(self): 55 | return iter(self.history) 56 | 57 | def __len__(self): 58 | return len(self.history) // 2 # number of turns 59 | 60 | def __getitem__(self, index): 61 | return self.history[index] 62 | 63 | def __eq__(self, __o: object) -> bool: 64 | if not isinstance(__o, DialogSession): 65 | return False 66 | return self.history == __o.history -------------------------------------------------------------------------------- /core/mcts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import math 4 | 5 | from core.helpers import DialogSession 6 | from core.game import DialogGame 7 | from core.players import DialogPlanner 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class MCTS(): 14 | def __init__(self, game:DialogGame, player:DialogPlanner, configs) -> None: 15 | self.game = game 16 | self.player = player 17 | self.configs = configs 18 | # U(s,a) = Q(s,a) + c * P(s,a) * (\sqrt{ \sum_{a'} N(s,a')}) / (1+N(s,a)) 19 | self.Ns: dict = {} # saves compute 20 | self.Nsa: dict = {} 21 | self.Q: dict = {} 22 | self.P: dict = {} 23 | # utility 24 | self.valid_moves: dict = {} 25 | self.terminals: dict = {} 26 | # debugging / more information 27 | self.Vs: dict = {} 28 | return 29 | 30 | def _to_string_rep(self, state:DialogSession): 31 | # for tree search, keep all dialog turns 32 | return state.to_string_rep(keep_sys_da=True, keep_user_da=True, max_turn_to_display=-1) 33 | 34 | def _init_node(self, state:DialogSession): 35 | hashable_state = self._to_string_rep(state) 36 | allowed_actions = self.player.get_valid_moves(state) 37 | self.valid_moves[hashable_state] = allowed_actions.nonzero()[0] 38 | 39 | self.Ns[hashable_state] = 0 40 | self.Nsa[hashable_state] = {action: 0 for action in self.valid_moves[hashable_state]} 41 | self.Q[hashable_state] = {action: self.configs.Q_0 for action in 
self.valid_moves[hashable_state]} 42 | 43 | prior, v = self.player.predict(state) 44 | self.Vs[state.to_string_rep(keep_sys_da=True, keep_user_da=True)] = v # for debugging 45 | self.P[hashable_state] = prior * allowed_actions 46 | # renormalize 47 | if np.sum(self.P[hashable_state]) == 0: 48 | self.P[hashable_state] = allowed_actions / np.sum(allowed_actions) 49 | logger.warning("This should never happen") 50 | else: 51 | self.P[hashable_state] /= np.sum(self.P[hashable_state]) 52 | return v 53 | 54 | def search(self, state:DialogSession): 55 | hashable_state = self._to_string_rep(state) 56 | 57 | is_leaf_node = False 58 | v = 0.0 59 | if hashable_state not in self.terminals: 60 | # selected leaf node, expand 61 | self.terminals[hashable_state] = self.game.get_dialog_ended(state) 62 | v = self._init_node(state) 63 | is_leaf_node = True 64 | # if this leaf node is terminal, return the value 65 | if self.terminals[hashable_state] > 0: 66 | # terminal node 67 | logger.debug("ended") 68 | return self.terminals[hashable_state] 69 | # otherwise, return v 70 | if is_leaf_node: 71 | return v 72 | 73 | # existing, continue selection 74 | # go next state by picking best according to U(s,a) 75 | best_uct = -float('inf') 76 | best_action = -1 77 | for a in self.valid_moves[hashable_state]: 78 | Ns = self.Ns[hashable_state] 79 | if Ns == 0: 80 | Ns = 1e-8 81 | uct = self.Q[hashable_state][a] + self.configs.cpuct * self.P[hashable_state][a] * math.sqrt(Ns) / (1 + self.Nsa[hashable_state][a]) 82 | if uct > best_uct: 83 | best_uct = uct 84 | best_action = a 85 | # transition 86 | next_state = self.game.get_next_state(state, best_action) 87 | 88 | # 1. if not leaf, continue traversing, and state=s will get the value from the leaf node 89 | # 2. 
if leaf, we will expand it and return the value for backpropagation 90 | v = self.search(next_state) 91 | 92 | # update stats 93 | # add in new estimate and average 94 | self.Q[hashable_state][best_action] = (self.Nsa[hashable_state][best_action] * self.Q[hashable_state][best_action] + v) / (self.Nsa[hashable_state][best_action] + 1) 95 | self.Ns[hashable_state] += 1 96 | self.Nsa[hashable_state][best_action] += 1 97 | 98 | # now we are single player, hence just v instead of -v 99 | return v 100 | 101 | def get_action_prob(self, state:DialogSession): 102 | hashable_state = self._to_string_rep(state) 103 | if hashable_state not in self.Ns: 104 | # state was never visited during search; initialize it so the lookups below work 105 | logger.warning("querying a state that has not been visited") 106 | self._init_node(state) 107 | # get the visit counts for all valid moves 108 | # and normalize them into a probability distribution 109 | prob = np.zeros(self.player.get_valid_moves(state).shape) 110 | for a in self.valid_moves[hashable_state]: 111 | prob[a] = self.Nsa[hashable_state][a] 112 | prob = prob / prob.sum() if prob.sum() > 0 else self.P[hashable_state].copy() # no visits yet: fall back to the prior instead of dividing by zero 113 | return prob 114 | 115 | 116 | class OpenLoopMCTS(MCTS): 117 | def __init__(self, game, player, configs) -> None: 118 | super().__init__(game, player, configs) 119 | self.realizations: dict = {} # state -> list of real DialogSessions 120 | self.realizations_Vs: dict = {} # state -> {realization: V(realization)} 121 | self.realizations_Ns: dict = {} # state -> {realization: N(realization)} 122 | self.max_realizations = configs.max_realizations 123 | return 124 | 125 | def _to_string_rep(self, state:DialogSession): 126 | # open loop: a node is keyed only by the sequence of system dialog acts, not by the utterances 127 | das = [] 128 | for (speaker, da, _) in state: 129 | if speaker == state.SYS: 130 | das.append(da) 131 | return "__".join(das) 132 | 133 | def _init_node(self, state:DialogSession): 134 | hashable_state = self._to_string_rep(state) 135 | allowed_actions = self.player.get_valid_moves(state) 136 | self.valid_moves[hashable_state] = allowed_actions.nonzero()[0] 137 | 138 | self.Ns[hashable_state] = 0 139
| self.Nsa[hashable_state] = {action: 0 for action in self.valid_moves[hashable_state]} 140 | self.Q[hashable_state] = {action: self.configs.Q_0 for action in self.valid_moves[hashable_state]} 141 | self.realizations[hashable_state] = [state.copy()] 142 | 143 | prior, v = self.player.predict(state) 144 | self.Vs[state.to_string_rep(keep_sys_da=True, keep_user_da=True)] = v # for debugging 145 | self.P[hashable_state] = prior * allowed_actions 146 | # renormalize 147 | if np.sum(self.P[hashable_state]) == 0: 148 | self.P[hashable_state] = allowed_actions / np.sum(allowed_actions) 149 | logger.warning("This should never happen") 150 | else: 151 | self.P[hashable_state] /= np.sum(self.P[hashable_state]) 152 | return v 153 | 154 | def _sample_realization(self, hashable_state): 155 | rand_i = np.random.randint(len(self.realizations[hashable_state])) 156 | return self.realizations[hashable_state][rand_i] 157 | 158 | def _add_new_realizations(self, state): 159 | hashable_state = self._to_string_rep(state) 160 | if hashable_state not in self.realizations: 161 | self.realizations[hashable_state] = [] 162 | if state in self.realizations[hashable_state]: 163 | return 164 | 165 | self.realizations[hashable_state].append(state.copy()) 166 | if len(self.realizations[hashable_state]) > self.max_realizations: 167 | # should never happen 168 | logger.warning(f"len(self.realizations[hashable_state])={len(self.realizations[hashable_state])}") 169 | self.realizations[hashable_state].pop(0) 170 | return 171 | 172 | def _get_next_state(self, state, best_action): 173 | prefetch_state = self._to_string_rep(state) + "__" + self.player.dialog_acts[best_action] 174 | if prefetch_state in self.realizations and len(self.realizations[prefetch_state]) == self.max_realizations: 175 | # use the cached realization 176 | return self._sample_realization(prefetch_state) 177 | 178 | # otherwise, generate a new realization 179 | next_state = self.game.get_next_state(state, best_action) 180 | return 
next_state 181 | 182 | def _update_realizations_Vs(self, state: DialogSession, v: float): 183 | hashable_state = self._to_string_rep(state) 184 | if hashable_state not in self.realizations_Vs: 185 | self.realizations_Vs[hashable_state] = {} 186 | self.realizations_Ns[hashable_state] = {} 187 | sys_utt = state.get_turn_utt( 188 | turn=-1, 189 | role=state.SYS, 190 | ) 191 | if sys_utt not in self.realizations_Vs[hashable_state]: 192 | self.realizations_Vs[hashable_state][sys_utt] = 0 193 | self.realizations_Ns[hashable_state][sys_utt] = 0 194 | # update the running mean of v for this realization (last system utterance) 195 | self.realizations_Ns[hashable_state][sys_utt] += 1 196 | self.realizations_Vs[hashable_state][sys_utt] += (v - self.realizations_Vs[hashable_state][sys_utt]) / self.realizations_Ns[hashable_state][sys_utt] 197 | return 198 | 199 | def search(self, state:DialogSession): 200 | hashable_state = self._to_string_rep(state) 201 | 202 | # check every time: the state is stochastic, so one hashable_state can cover several concrete dialogs 203 | terminated_v = self.game.get_dialog_ended(state) 204 | # check if it is a terminal node 205 | if terminated_v == 1.0: 206 | logger.debug("ended") 207 | return terminated_v 208 | 209 | # otherwise, if it is a non-terminal leaf node, we initialize it and return v 210 | if hashable_state not in self.P: 211 | # selected leaf node, expand it 212 | # v is evaluated only once per hashable_state, on its first visit 213 | v = self._init_node(state) 214 | return v 215 | else: 216 | # add only when it is new 217 | self._add_new_realizations(state) 218 | 219 | # existing node, continue selection 220 | # move to the next state by picking the best action according to U(s,a) 221 | best_uct = -float('inf') 222 | best_action = -1 223 | for a in self.valid_moves[hashable_state]: 224 | Ns = self.Ns[hashable_state] 225 | if Ns == 0: 226 | Ns = 1e-8 227 | # a variant of PUCT 228 | uct = self.Q[hashable_state][a] + self.configs.cpuct * self.P[hashable_state][a] * math.sqrt(Ns) / (1 + self.Nsa[hashable_state][a]) 229 | if uct > best_uct: 230 | best_uct = uct
231 | best_action = a 232 | # transition. For open loop, first sample from an existing realization 233 | state = self._sample_realization(hashable_state) 234 | next_state = self._get_next_state(state, best_action) 235 | 236 | # 1. if not leaf, continue traversing, and state=s will get the value from the leaf node 237 | # 2. if leaf, we will expand it and return the value for backpropagation 238 | v = self.search(next_state) 239 | 240 | # update stats 241 | # add in new estimate and average 242 | self.Q[hashable_state][best_action] = (self.Nsa[hashable_state][best_action] * self.Q[hashable_state][best_action] + v) / (self.Nsa[hashable_state][best_action] + 1) 243 | self.Ns[hashable_state] += 1 244 | self.Nsa[hashable_state][best_action] += 1 245 | 246 | # record v for the sampled realization; used to pick the response (NLG) at inference 247 | self._update_realizations_Vs(next_state, v) 248 | # now we are single player, hence just v instead of -v 249 | return v 250 | 251 | def get_best_realization(self, state:DialogSession, action: int): 252 | prefetch_state = self._to_string_rep(state) + "__" + self.player.dialog_acts[action] 253 | if prefetch_state not in self.realizations_Vs: 254 | raise Exception("querying a state that has no realizations sampled before") 255 | # scan the cached realizations of the chosen action 256 | # and return the system utterance with the highest average value 257 | curr_best_v = -float('inf') 258 | curr_best_realization = None 259 | for sys_utt, v in self.realizations_Vs[prefetch_state].items(): 260 | if v > curr_best_v: 261 | curr_best_v = v 262 | curr_best_realization = sys_utt 263 | return curr_best_realization 264 | 265 | 266 | class OpenLoopMCTSParallel(OpenLoopMCTS): 267 | def __init__(self, game, player, configs) -> None: 268 | super().__init__(game, player, configs) 269 | 270 | def _populate_next_realizations(self, state, next_action, num_to_add): 271 | next_states = self.game.get_next_state_batched(state, next_action, batch=num_to_add) 272 | for next_state in next_states: 273 | self._add_new_realizations(next_state) 274 | return 275
276 | def _get_next_state(self, state, best_action): 277 | prefetch_state = self._to_string_rep(state) + "__" + self.player.dialog_acts[best_action] 278 | if prefetch_state in self.realizations and len(self.realizations[prefetch_state]) == self.max_realizations: 279 | # use the cached realization 280 | return self._sample_realization(prefetch_state) 281 | 282 | self._populate_next_realizations(state, best_action, self.max_realizations) 283 | return self._sample_realization(prefetch_state) 284 | 285 | def _init_node(self, state:DialogSession): 286 | hashable_state = self._to_string_rep(state) 287 | allowed_actions = self.player.get_valid_moves(state) 288 | self.valid_moves[hashable_state] = allowed_actions.nonzero()[0] 289 | 290 | self.Ns[hashable_state] = 0 291 | self.Nsa[hashable_state] = {action: 0 for action in self.valid_moves[hashable_state]} 292 | self.Q[hashable_state] = {action: self.configs.Q_0 for action in self.valid_moves[hashable_state]} 293 | # should have been initialized during _get_next_state, except for the root node 294 | if hashable_state not in self.realizations: 295 | self.realizations[hashable_state] = [state.copy()] 296 | 297 | # TODO: batch predict value function 298 | prior, v = self.player.predict(state) 299 | self.Vs[state.to_string_rep(keep_sys_da=True, keep_user_da=True)] = v # for debugging 300 | self.P[hashable_state] = prior * allowed_actions 301 | # renormalize 302 | if np.sum(self.P[hashable_state]) == 0: 303 | self.P[hashable_state] = allowed_actions / np.sum(allowed_actions) 304 | logger.warning("This should never happen") 305 | else: 306 | self.P[hashable_state] /= np.sum(self.P[hashable_state]) 307 | return v -------------------------------------------------------------------------------- /core/players.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | from typing import List, Tuple 5 | from core.helpers import DialogSession 6 | from 
core.gen_models import GenerationModel, DialogModel 7 | from core.game import PersuasionGame 8 | from abc import ABC, abstractmethod 9 | from collections import Counter 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class DialogPlanner(ABC): 16 | @abstractmethod 17 | def get_valid_moves(self, state): 18 | # 1 if the i-th dialog act is valid, 0 otherwise 19 | pass 20 | 21 | @abstractmethod 22 | def predict(self, state) -> "Tuple[np.ndarray, float]": 23 | # returns a prob and value 24 | pass 25 | 26 | 27 | class P4GSystemPlanner(DialogPlanner): 28 | def __init__(self, 29 | dialog_acts, max_hist_num_turns, 30 | user_dialog_acts, user_max_hist_num_turns, 31 | generation_model:GenerationModel, 32 | conv_examples: List[DialogSession] = []) -> None: 33 | super().__init__() 34 | self.dialog_acts = dialog_acts 35 | self.max_hist_num_turns = max_hist_num_turns # used in prompting next da 36 | self.user_dialog_acts = user_dialog_acts 37 | self.user_max_hist_num_turns = user_max_hist_num_turns # used in heuristic function 38 | self.conv_examples = conv_examples 39 | self.generation_model = generation_model 40 | self.smoothing = 1.0 41 | self.task_prompt = f""" 42 | The following is background information about Save the Children. 43 | Save the Children is head-quartered in London, and they work to help fight poverty around the world. Children need help in developing countries and war zones. Small donations like $1 or $2 go a long way to help. 44 | The Persuader can choose amongst the following actions during a conversation: 45 | {" ".join([f"[{da}]" for da in dialog_acts])} 46 | The following is an example conversation between a Persuader and a Persuadee about a charity called Save the Children. The Persuader is trying to persuade the Persuadee to donate to Save the Children. 47 | {self.process_exp()} 48 | The following is a new conversation between another Persuader and Persuadee. 
49 | """ 50 | self.task_prompt = self.task_prompt.replace("\t", "").strip() 51 | 52 | self.inf_args = { 53 | "max_new_tokens": 8, 54 | "temperature": 1.0, 55 | "return_full_text": False, 56 | "do_sample": True, 57 | "num_return_sequences": 15, 58 | } 59 | return 60 | 61 | def process_exp(self, keep_sys_da=True, keep_user_da=False): 62 | prompt_exps = "" 63 | for exp in self.conv_examples: 64 | prompt_exps += exp.to_string_rep(keep_sys_da=keep_sys_da, keep_user_da=keep_user_da) + "\n" 65 | return prompt_exps.strip() 66 | 67 | def get_valid_moves(self, state): 68 | # 1 if the i-th dialog act is valid, 0 otherwise 69 | turn = len(state) 70 | if turn < 1: 71 | return np.array([1 if da == PersuasionGame.S_Greeting else 0 for da in self.dialog_acts]) 72 | return np.array([1 for _ in self.dialog_acts]) 73 | 74 | def get_utterance(self, state, action) -> str: 75 | return "" # should not be called 76 | 77 | def _get_generated_da(self, data) -> list: 78 | # convert generated responses to DA 79 | pred_da = [] 80 | for resp in data: 81 | resp = resp['generated_text'].strip() 82 | start_idx = resp.find("[") 83 | end_idx = resp.find("]") 84 | if start_idx == -1 or end_idx == -1: 85 | continue 86 | found_da = resp[start_idx + 1: end_idx].strip() 87 | if found_da in self.dialog_acts: 88 | pred_da.append(found_da) 89 | return pred_da 90 | 91 | def predict(self, state:DialogSession) -> "Tuple[np.ndarray, float]": 92 | # test k times and compute prob. 
See num_return_sequences in the API 93 | # the value would be our objective function 94 | if len(state) == 0: 95 | prompt = f""" 96 | {self.task_prompt} 97 | Persuader: 98 | """ 99 | else: 100 | prompt = f""" 101 | {self.task_prompt} 102 | {state.to_string_rep(keep_sys_da=True)} 103 | Persuader: 104 | """ 105 | prompt = prompt.replace("\t", "").strip() 106 | logger.debug(prompt) 107 | data = self.generation_model.generate(prompt, **self.inf_args) 108 | sampled_das = self._get_generated_da(data) 109 | logger.debug(f"sampled das: {sampled_das}") 110 | # convert to prob distribution 111 | prob = np.zeros(len(self.dialog_acts)) 112 | prob += self.smoothing 113 | for da in sampled_das: 114 | prob[self.dialog_acts.index(da)] += 1 115 | prob /= prob.sum() 116 | v = self.heuristic(state) 117 | return prob, v 118 | 119 | def _get_user_generated_da(self, data) -> list: 120 | # convert generated responses to DA 121 | pred_da = [] 122 | for resp in data: 123 | resp = resp['generated_text'].strip() 124 | start_idx = resp.find("[") 125 | end_idx = resp.find("]") 126 | if start_idx == -1 or end_idx == -1: 127 | continue 128 | found_da = resp[start_idx + 1: end_idx].strip() 129 | if found_da in self.user_dialog_acts: 130 | pred_da.append(found_da) 131 | return pred_da 132 | 133 | def heuristic(self, state:DialogSession) -> float: 134 | # insert prop to donate, and compute the likelihood of user simulator agreeing to donate 135 | assert(state[-1][0] == PersuasionGame.USR) 136 | prompt = f""" 137 | The following is background information about task. 138 | The Persuader is trying to persuade the Persuadee to donate to Save the Children. 139 | The Persuadee can choose amongst the following actions during a conversation to respond to the Persuader: 140 | {" ".join([f"[{da}]" for da in self.user_dialog_acts])} 141 | The following is a conversation between a Persuader and a Persuadee about a charity called Save the Children. 
142 | {self.process_exp(keep_sys_da=False, keep_user_da=True)} 143 | The following is a new conversation between another Persuader and Persuadee. 144 | {state.to_string_rep(keep_user_da=True, max_turn_to_display=self.user_max_hist_num_turns)} 145 | Persuader: Would you be interested in donating to Save the Children? 146 | Persuadee: 147 | """ 148 | prompt = prompt.replace("\t", "").strip() 149 | 150 | inf_args = { 151 | "max_new_tokens": 8, 152 | "temperature": 1.1, 153 | "return_full_text": False, 154 | "do_sample": True, 155 | "num_return_sequences": 10, 156 | } 157 | data = self.generation_model.generate(prompt, **inf_args) 158 | sampled_das = self._get_user_generated_da(data) 159 | 160 | logger.debug(f"persuadee prompt: {prompt}") 161 | logger.debug(f"sampled das: {sampled_das}") 162 | 163 | # heuristic score 164 | score = [] 165 | for da in sampled_das: 166 | if da == PersuasionGame.U_NoDonation: 167 | score.append(-1.0) 168 | elif da == PersuasionGame.U_NegativeReaction: 169 | score.append(-0.5) 170 | elif da == PersuasionGame.U_Neutral: 171 | score.append(0.0) 172 | elif da == PersuasionGame.U_PositiveReaction: 173 | score.append(0.5) 174 | elif da == PersuasionGame.U_Donate: 175 | score.append(1.0) 176 | v = 0.0 if len(score) == 0 else np.mean(score) 177 | logger.debug(f"sampled das to v: {v}") 178 | return float(v) 179 | 180 | 181 | class P4GChatSystemPlanner(P4GSystemPlanner): 182 | def __init__(self, 183 | dialog_acts, max_hist_num_turns, 184 | user_dialog_acts, user_max_hist_num_turns, 185 | generation_model:GenerationModel, 186 | conv_examples: List[DialogSession] = []) -> None: 187 | super().__init__( 188 | dialog_acts, max_hist_num_turns, 189 | user_dialog_acts, user_max_hist_num_turns, 190 | generation_model, conv_examples 191 | ) 192 | self.task_prompt = f""" 193 | Save the Children is head-quartered in London, and they work to help fight poverty around the world. Children need help in developing countries and war zones. 
Small donations like $1 or $2 go a long way to help. 194 | You are Persuader who is trying to persuade the Persuadee to donate to a charity called Save the Children. You can choose amongst the following actions during a conversation: 195 | {" ".join([f"[{da}]" for da in dialog_acts])} 196 | The following is an example conversation between a Persuader and a Persuadee about Save the Children. 197 | """.replace("\t", "").strip() 198 | self.new_task_prompt = "The following is a new conversation between Persuader (you) and a Persuadee." 199 | self.prompt_examples = self.process_chat_exp(new_task_prompt=self.new_task_prompt) 200 | 201 | self.inf_args = { 202 | "max_new_tokens": 12, 203 | "temperature": 1.0, 204 | "return_full_text": False, 205 | "do_sample": True, 206 | "num_return_sequences": 15, 207 | } 208 | return 209 | 210 | def process_chat_exp(self, 211 | new_task_prompt, 212 | assistant_role=PersuasionGame.SYS, 213 | keep_sys_da=True, keep_user_da=False): 214 | prompt_exps = [] 215 | for exp in self.conv_examples: 216 | prompt_exps += self.__proccess_chat_exp(exp, keep_sys_da, keep_user_da, assistant_role) 217 | prompt_exps.append({ 218 | "role":"system", "content": new_task_prompt 219 | }) 220 | return prompt_exps[:-1] 221 | 222 | def __proccess_chat_exp(self, 223 | exp:DialogSession, 224 | keep_sys_da, keep_user_da, 225 | assistant_role=PersuasionGame.SYS, 226 | max_hist_num_turns: int = -1): 227 | if len(exp) == 0: 228 | return [] 229 | # P4G dataset starts with the system/Persuader 230 | assert(exp[0][0] == PersuasionGame.SYS) 231 | 232 | prompt_messages = [] 233 | num_turns_to_truncate = 0 234 | if max_hist_num_turns > 0: 235 | num_turns_to_truncate = max(0, len(exp) // 2 - max_hist_num_turns) 236 | 237 | # init with user 238 | # if assistant_role == PersuasionGame.SYS: 239 | # if keep_user_da: 240 | # prompt_messages.append({ 241 | # "role": "user", 242 | # "content": f"{PersuasionGame.USR}: [{PersuasionGame.U_Neutral}] Hello.".strip() 243 | # }) 244 | # 
else: 245 | # prompt_messages.append({ 246 | # "role": "user", 247 | # "content": f"{PersuasionGame.USR}: Hello.".strip() 248 | # }) 249 | # all the rest 250 | for i, (role, da, utt) in enumerate(exp): 251 | # truncate to reduce the size of the prompt 252 | if (i // 2) < num_turns_to_truncate: 253 | continue 254 | # if assistant is the Persuader, then current data is also Persuader -> then it is of role "system" 255 | if role == PersuasionGame.SYS: 256 | if keep_sys_da: 257 | content = f"{role}: [{da}] {utt}".strip() 258 | else: 259 | content = f"{role}: {utt}".strip() 260 | if assistant_role == PersuasionGame.SYS: 261 | prompt_role = "assistant" 262 | else: 263 | prompt_role = "user" 264 | else: 265 | if keep_user_da: 266 | content = f"{role}: [{da}] {utt}".strip() 267 | else: 268 | content = f"{role}: {utt}".strip() 269 | if assistant_role == PersuasionGame.USR: 270 | prompt_role = "assistant" 271 | else: 272 | prompt_role = "user" 273 | 274 | prompt_messages.append({ 275 | "role": prompt_role, 276 | "content": content 277 | }) 278 | return prompt_messages 279 | 280 | def get_valid_moves(self, state): 281 | # 1 if the i-th dialog act is valid, 0 otherwise 282 | turn = len(state) 283 | if turn < 1: 284 | return np.array([1 if da == PersuasionGame.S_Greeting else 0 for da in self.dialog_acts]) 285 | return np.array([1 for _ in self.dialog_acts]) 286 | 287 | def get_utterance(self, state, action) -> str: 288 | return "" # should not be called 289 | 290 | def _get_generated_da(self, data) -> list: 291 | # convert generated responses to DA 292 | pred_da = [] 293 | for resp in data: 294 | resp = resp['generated_text'].strip() 295 | start_idx = resp.find("[") 296 | end_idx = resp.find("]") 297 | if start_idx == -1 or end_idx == -1: 298 | continue 299 | found_da = resp[start_idx + 1: end_idx].strip() 300 | if found_da in self.dialog_acts: 301 | pred_da.append(found_da) 302 | return pred_da 303 | 304 | def predict(self, state:DialogSession) -> "Tuple[np.ndarray, float]": 
305 | # test k times and compute prob. See num_return_sequences in the API 306 | # the value would be our objective function 307 | messages = [ 308 | {'role': 'system', 'content': self.task_prompt}, 309 | *self.prompt_examples, 310 | {'role': 'system', 'content': self.new_task_prompt} 311 | ] 312 | if len(state) == 0: 313 | messages.append({'role': 'user', 'content': f'{PersuasionGame.USR}: Hello.'}) 314 | else: 315 | assert(state[-1][0] == PersuasionGame.USR) 316 | messages += self.__proccess_chat_exp(state, keep_sys_da=True, keep_user_da=False) 317 | # produce a response 318 | data = self.generation_model.chat_generate(messages, **self.inf_args) 319 | 320 | sampled_das = self._get_generated_da(data) 321 | logger.debug(f"sampled das: {sampled_das}") 322 | # convert to prob distribution 323 | prob = np.zeros(len(self.dialog_acts)) 324 | prob += self.smoothing 325 | for da in sampled_das: 326 | prob[self.dialog_acts.index(da)] += 1 327 | prob /= prob.sum() 328 | v = self.heuristic(state) 329 | return prob, v 330 | 331 | def _get_user_generated_da(self, data) -> list: 332 | # convert generated responses to DA 333 | pred_da = [] 334 | for resp in data: 335 | resp = resp['generated_text'].strip() 336 | start_idx = resp.find("[") 337 | end_idx = resp.find("]") 338 | if start_idx == -1 or end_idx == -1: 339 | continue 340 | found_da = resp[start_idx + 1: end_idx].strip() 341 | if found_da in self.user_dialog_acts: 342 | pred_da.append(found_da) 343 | return pred_da 344 | 345 | def heuristic(self, state:DialogSession) -> float: 346 | # insert prop to donate, and compute the likelihood of user simulator agreeing to donate 347 | assert(state[-1][0] == PersuasionGame.USR) 348 | user_task_prompt = f""" 349 | You are a persuadee. A Persuader is trying to persuade you to donate to a charity called Save the Children. 
350 | You can choose amongst the following actions during a conversation to respond to the Persuader: 351 | {" ".join([f"[{da}]" for da in self.user_dialog_acts])} 352 | The following is a new conversation between a Persuader and a Persuadee (you). 353 | """.replace("\t", "").strip() 354 | user_new_task_prompt = "The following is a new conversation between a Persuader and a Persuadee (you)." 355 | 356 | messages = [ 357 | {'role': 'system', 'content': user_task_prompt}, 358 | *self.process_chat_exp(new_task_prompt=user_new_task_prompt, assistant_role=PersuasionGame.USR, keep_sys_da=False, keep_user_da=True), 359 | {'role': 'system', 'content': user_new_task_prompt} 360 | ] 361 | messages += self.__proccess_chat_exp(state, assistant_role=PersuasionGame.USR, keep_sys_da=False, keep_user_da=True) 362 | messages.append({ 363 | 'role': 'user', 'content': f'{PersuasionGame.SYS}: Would you be interested in donating to Save the Children?' 364 | }) 365 | 366 | inf_args = { 367 | "max_new_tokens": 12, 368 | "temperature": 1.1, 369 | "return_full_text": False, 370 | "do_sample": True, 371 | "num_return_sequences": 10, 372 | } 373 | data = self.generation_model.chat_generate(messages, **inf_args) 374 | sampled_das = self._get_user_generated_da(data) 375 | 376 | logger.debug(f"persuadee prompt: {messages}") 377 | logger.debug(f"sampled das: {sampled_das}") 378 | 379 | # heuristic score 380 | score = [] 381 | for da in sampled_das: 382 | if da == PersuasionGame.U_NoDonation: 383 | score.append(-1.0) 384 | elif da == PersuasionGame.U_NegativeReaction: 385 | score.append(-0.5) 386 | elif da == PersuasionGame.U_Neutral: 387 | score.append(0.0) 388 | elif da == PersuasionGame.U_PositiveReaction: 389 | score.append(0.5) 390 | elif da == PersuasionGame.U_Donate: 391 | score.append(1.0) 392 | v = 0.0 if len(score) == 0 else np.mean(score) 393 | logger.debug(f"sampled das to v: {v}") 394 | return float(v) 395 | 396 | 397 | class PersuaderModel(DialogModel): 398 | def __init__(self, 399 
| dialog_acts:List[str], 400 | backbone_model:GenerationModel, 401 | max_hist_num_turns: int = 5, 402 | conv_examples: List[DialogSession] = [], 403 | inference_args: dict = {}): 404 | super().__init__() 405 | self.conv_examples = conv_examples 406 | self.backbone_model = backbone_model 407 | self.max_hist_num_turns = max_hist_num_turns 408 | # prompts and DAs 409 | self.da_prompts_mapping = { 410 | PersuasionGame.S_Greeting: "The Persuader greets the Persuadee.", 411 | # start of persuasion strategies 412 | PersuasionGame.S_CredibilityAppeal: "The Persuader establishes credibility of Save the Children by citing its impact.", 413 | PersuasionGame.S_EmotionAppeal: "The Persuader uses an emotion appeal to convince the Persuadee.", 414 | PersuasionGame.S_LogicalAppeal: "The Persuader use of reasoning and evidence to convince the Persuadee.", 415 | PersuasionGame.S_TaskRelatedInquiry: "The Persuader asks about the Persuadee's knowledge or opinion related to Save the Children.", 416 | PersuasionGame.S_PropositionOfDonation: "The Persuader asks if the Persuadee would like to make a small donation.", 417 | # end of persuasion strategies 418 | PersuasionGame.S_Other: "The Persuader responds to the Persuadee without using any persuaive strategy.", 419 | } 420 | # only allow da that has the mapping 421 | self.dialog_acts = [da for da in dialog_acts if da in self.da_prompts_mapping] 422 | 423 | logger.debug(self.dialog_acts) 424 | self.task_prompt = f""" 425 | The following is background information about Save the Children. 426 | Save the Children is head-quartered in London, and they work to help fight poverty around the world. Children need help in developing countries and war zones. Small donations like $1 or $2 go a long way to help. 427 | The following is an example conversation between a Persuader and a Persuadee about a charity called Save the Children. The Persuader is trying to persuade the Persuadee to donate to Save the Children. 
428 | {self.process_exp()} 429 | The following is a new conversation between another Persuader and Persuadee. 430 | """ 431 | self.task_prompt = self.task_prompt.replace("\t", "").strip() 432 | self.inference_args = { 433 | "max_new_tokens": 128, 434 | "temperature": 0.0, 435 | "repetition_penalty": 1.0, 436 | "do_sample": False, # otherwise tree will never go to the next level 437 | "return_full_text": False, 438 | **inference_args 439 | } 440 | return 441 | 442 | def process_exp(self): 443 | prompt_exps = "" 444 | for exp in self.conv_examples: 445 | prompt_exps += self.__proccess_exp(exp) + "\n" 446 | return prompt_exps.strip() 447 | 448 | def __proccess_exp(self, exp:DialogSession, max_hist_num_turns: int = -1): 449 | prompt_exp = "" 450 | num_turns_to_truncate = 0 451 | if max_hist_num_turns > 0: 452 | num_turns_to_truncate = max(0, len(exp) // 2 - max_hist_num_turns) 453 | 454 | for i, (role, da, utt) in enumerate(exp): 455 | # truncate to reduce the size of the prompt 456 | if (i // 2) < num_turns_to_truncate: 457 | continue 458 | 459 | if role == PersuasionGame.SYS: 460 | prompt_exp += f"{self.da_prompts_mapping[da]}\n{role}: {utt}\n" 461 | else: 462 | prompt_exp += f"{role}: {utt}\n" 463 | return prompt_exp.strip() 464 | 465 | def get_utterance(self, state:DialogSession, action:int) -> str: 466 | # the planner gives an action and state is the dialog history; produce a response according to that action 467 | da = self.dialog_acts[action] 468 | da_prompt = self.da_prompts_mapping[da] 469 | if len(state) == 0: 470 | prompt = f""" 471 | {self.task_prompt} 472 | {da_prompt} 473 | Persuader: 474 | """ 475 | else: 476 | prompt = f""" 477 | {self.task_prompt} 478 | {self.__proccess_exp(state, max_hist_num_turns=self.max_hist_num_turns)} 479 | {da_prompt} 480 | Persuader: 481 | """ 482 | prompt = prompt.replace("\t", "").strip() 483 | # produce a response 484 | data = self.backbone_model.generate(prompt, **self.inference_args) 485 | sys_resp =
self.backbone_model._cleaned_resp(data, prompt)[0] # TODO 486 | return sys_resp 487 | 488 | def get_utterance_w_da(self, state: DialogSession, action) -> Tuple[str, str]: 489 | raise NotImplementedError 490 | 491 | 492 | class PersuaderChatModel(PersuaderModel): 493 | def __init__(self, 494 | dialog_acts:List[str], 495 | backbone_model:GenerationModel, 496 | max_hist_num_turns: int = 5, 497 | conv_examples: List[DialogSession] = [], 498 | inference_args: dict = {}): 499 | super().__init__( 500 | dialog_acts=dialog_acts, 501 | backbone_model=backbone_model, 502 | max_hist_num_turns=max_hist_num_turns, 503 | conv_examples=conv_examples, 504 | inference_args=inference_args 505 | ) 506 | self.inference_args = { 507 | "max_new_tokens": 128, 508 | "temperature": 0.0, 509 | "repetition_penalty": 1.0, 510 | "do_sample": False, # otherwise tree will never go to the next level, unless you do OpenLoop search 511 | "return_full_text": False, 512 | **inference_args 513 | } 514 | self.task_prompt = """ 515 | Save the Children is head-quartered in London, and they work to help fight poverty around the world. Children need help in developing countries and war zones. Small donations like $1 or $2 go a long way to help. 516 | You are Persuader who is trying to persuade the Persuadee to donate to a charity called Save the Children. 517 | The following is an example conversation between a Persuader and a Persuadee about Save the Children. 518 | """.replace("\t", "").strip() 519 | self.new_task_prompt = "The following is a new conversation between Persuader (you) and another Persuadee.\nThe Persuader greets the persuadee." 
520 | self.prompt_examples = self.process_chat_exp() 521 | return 522 | 523 | def process_chat_exp(self): 524 | prompt_exps = [] 525 | for exp in self.conv_examples: 526 | prompt_exps += self.__proccess_chat_exp(exp) 527 | prompt_exps.append({ 528 | "role":"system", "content": self.new_task_prompt 529 | }) 530 | return prompt_exps[:-1] 531 | 532 | def __proccess_chat_exp(self, exp:DialogSession, max_hist_num_turns: int = -1): 533 | if len(exp) == 0: 534 | return [] 535 | # P4G dataset starts with the system 536 | assert(exp[0][0] == PersuasionGame.SYS) 537 | 538 | prompt_messages = [] 539 | num_turns_to_truncate = 0 540 | if max_hist_num_turns > 0: 541 | num_turns_to_truncate = max(0, len(exp) // 2 - max_hist_num_turns) 542 | 543 | 544 | next_sys_da = PersuasionGame.S_Greeting 545 | for i, (role, da, utt) in enumerate(exp): 546 | # truncate to reduce the size of the prompt 547 | if (i // 2) < num_turns_to_truncate: 548 | continue 549 | if role == PersuasionGame.SYS: 550 | prompt_messages.append({ 551 | "role": "assistant", 552 | "content": f"{role}: {utt}".strip() 553 | }) 554 | else: 555 | if i+1 < len(exp.history): 556 | next_sys_da = exp[i+1][1] 557 | prompt_messages.append({ 558 | "role": "user", 559 | "content": f"{role}: {utt}\n{self.da_prompts_mapping[next_sys_da]}".strip() 560 | }) 561 | else: 562 | prompt_messages.append({ 563 | "role": "user", 564 | "content": f"{role}: {utt}".strip() 565 | }) 566 | return prompt_messages 567 | 568 | def get_utterance(self, state:DialogSession, action:int) -> str: 569 | return self.get_utterance_batched(state, action, batch=1)[0] 570 | 571 | def get_utterance_batched(self, state:DialogSession, action:int, batch:int=3) -> List[str]: 572 | da = self.dialog_acts[action] 573 | da_prompt = self.da_prompts_mapping[da] 574 | messages = [ 575 | {'role': 'system', 'content': self.task_prompt}, 576 | *self.prompt_examples, 577 | {'role': 'system', 'content': self.new_task_prompt} 578 | ] 579 | if len(state) == 0: 580 | 
messages.append({'role': 'user', 'content': f'{PersuasionGame.USR}: Hello.\n{da_prompt}'}) 581 | else: 582 | assert(state[-1][0] == PersuasionGame.USR) 583 | messages += self.__proccess_chat_exp(state, max_hist_num_turns=self.max_hist_num_turns) 584 | gen_args = { 585 | **self.inference_args, 586 | "num_return_sequences": batch, # this will be changed to n inside chat_generate 587 | } 588 | data = self.backbone_model.chat_generate(messages, **gen_args) 589 | sys_resps = self.backbone_model._cleaned_chat_resp( 590 | data, assistant_role=f"{PersuasionGame.SYS}:", user_role=f"{PersuasionGame.USR}:" 591 | ) 592 | return sys_resps 593 | 594 | def get_utterance_w_da(self, state: DialogSession, action) -> Tuple[str, str]: 595 | raise NotImplementedError 596 | 597 | 598 | class PersuadeeModel(DialogModel): 599 | def __init__(self, 600 | dialog_acts: List[str], 601 | inference_args: dict, 602 | backbone_model:GenerationModel, 603 | conv_examples: List[DialogSession] = [], 604 | max_hist_num_turns=5): 605 | super().__init__() 606 | self.conv_examples = conv_examples 607 | self.backbone_model = backbone_model 608 | self.dialog_acts = dialog_acts 609 | self.max_hist_num_turns = max_hist_num_turns 610 | # prompts 611 | self.task_prompt = f""" 612 | The following is background information about task. 613 | The Persuader is trying to persuade the Persuadee to donate to Save the Children. 614 | The Persuadee can choose amongst the following actions during a conversation to respond to the Persuader: 615 | {" ".join([f"[{da}]" for da in self.dialog_acts])} 616 | The following is an example conversation between a Persuader and a Persuadee about a charity called Save the Children. 617 | {self.process_exp()} 618 | The following is a new conversation between another Persuader and Persuadee. 
619 | """ 620 | self.task_prompt = self.task_prompt.replace("\t", "").strip() 621 | self.inference_args = inference_args 622 | return 623 | 624 | def process_exp(self): 625 | prompt_exps = "" 626 | for exp in self.conv_examples: 627 | prompt_exps += exp.to_string_rep(keep_user_da=True) + "\n" 628 | return prompt_exps.strip() 629 | 630 | def get_utterance(self, state:DialogSession, action=None) -> str: 631 | assert(state[-1][0] == PersuasionGame.SYS) 632 | prompt = f""" 633 | {self.task_prompt} 634 | {state.to_string_rep(keep_user_da=True, max_turn_to_display=self.max_hist_num_turns)} 635 | Persuadee: 636 | """ 637 | prompt = prompt.replace("\t", "").strip() 638 | # produce a response 639 | data = self.backbone_model.generate(prompt, **self.inference_args) 640 | user_resp = self.backbone_model._cleaned_resp(data, prompt)[0] 641 | return user_resp 642 | 643 | def get_utterance_w_da(self, state:DialogSession, action=None) -> "Tuple[str, str]": 644 | user_resp = self.get_utterance(state, action) 645 | # extract da 646 | start_idx = user_resp.find("[") 647 | end_idx = user_resp.find("]") 648 | if start_idx == -1 or end_idx == -1: 649 | da = PersuasionGame.U_Neutral 650 | else: 651 | da = user_resp[start_idx+1:end_idx] 652 | user_resp = user_resp.replace(f"[{da}]", "", 1).strip() 653 | if da not in self.dialog_acts: 654 | da = PersuasionGame.U_Neutral 655 | return da, user_resp 656 | 657 | 658 | class PersuadeeChatModel(PersuadeeModel): 659 | def __init__(self, 660 | dialog_acts: List[str], 661 | inference_args: dict, 662 | backbone_model:GenerationModel, 663 | conv_examples: List[DialogSession] = [], 664 | max_hist_num_turns=5): 665 | super().__init__( 666 | dialog_acts=dialog_acts, 667 | inference_args=inference_args, 668 | backbone_model=backbone_model, 669 | conv_examples=conv_examples, 670 | max_hist_num_turns=max_hist_num_turns 671 | ) 672 | self.inference_args = inference_args 673 | self.task_prompt = f""" 674 | You are a persuadee. 
A Persuader is trying to persuade you to donate to a charity called Save the Children. 675 | You can choose amongst the following actions during a conversation to respond to the Persuader: 676 | {" ".join([f"[{da}]" for da in self.dialog_acts])} 677 | The following is an example conversation between a Persuader and some Persuadee. 678 | """.replace("\t", "").strip() 679 | self.new_task_prompt = "The following is a new conversation between a Persuader and a Persuadee (you). You may or may not want to donate to Save the Children." 680 | self.heuristic_args: dict = { 681 | "max_hist_num_turns": 2, 682 | "example_pred_turn": [[0, 2, 3, 4]] 683 | } 684 | self.prompt_examples = self.process_chat_exp() 685 | return 686 | 687 | def process_chat_exp(self): 688 | prompt_exps = [] 689 | for exp in self.conv_examples: 690 | prompt_exps += self.__proccess_chat_exp(exp) 691 | prompt_exps.append({ 692 | "role":"system", "content": self.new_task_prompt 693 | }) 694 | return prompt_exps[:-1] 695 | 696 | def __proccess_chat_exp(self, exp:DialogSession, max_hist_num_turns: int = -1): 697 | if len(exp) == 0: 698 | return [] 699 | # P4G dataset starts with the system 700 | assert(exp[0][0] == PersuasionGame.SYS) 701 | 702 | prompt_messages = [] 703 | num_turns_to_truncate = 0 704 | if max_hist_num_turns > 0: 705 | num_turns_to_truncate = max(0, len(exp) // 2 - max_hist_num_turns) 706 | 707 | for i, (role, da, utt) in enumerate(exp): 708 | # truncate to reduce the size of the prompt 709 | if (i // 2) < num_turns_to_truncate: 710 | continue 711 | if role == PersuasionGame.SYS: 712 | prompt_messages.append({ 713 | "role": "user", 714 | "content": f"{role}: {utt}".strip() 715 | }) 716 | else: 717 | prompt_messages.append({ 718 | "role": "assistant", # assistant is the user simulator 719 | "content": f"{role}: [{da}] {utt}".strip() 720 | }) 721 | return prompt_messages 722 | 723 | def get_utterance(self, state:DialogSession, action=None) -> str: 724 | assert(state[-1][0] == 
PersuasionGame.SYS) # next turn is user's turn 725 | messages = [ 726 | {'role': 'system', 'content': self.task_prompt}, 727 | *self.prompt_examples, 728 | {'role': 'system', 'content': self.new_task_prompt} 729 | ] 730 | messages += self.__proccess_chat_exp(state, max_hist_num_turns=self.max_hist_num_turns) 731 | 732 | # produce a response 733 | data = self.backbone_model.chat_generate(messages, **self.inference_args) 734 | user_resp = self.backbone_model._cleaned_chat_resp( 735 | data, assistant_role=f"{PersuasionGame.USR}:", user_role=f"{PersuasionGame.SYS}:" 736 | )[0] 737 | return user_resp 738 | 739 | def get_utterance_from_batched_states(self, states:List[DialogSession], action=None) -> List[str]: 740 | assert(all([state[-1][0] == PersuasionGame.SYS for state in states])) 741 | all_prompts = [] 742 | for state in states: 743 | messages = [ 744 | {'role': 'system', 'content': self.task_prompt}, 745 | *self.prompt_examples, 746 | {'role': 'system', 'content': self.new_task_prompt} 747 | ] 748 | messages += self.__proccess_chat_exp(state, max_hist_num_turns=self.max_hist_num_turns) 749 | all_prompts.append(messages) 750 | # produce a response 751 | datas = self.backbone_model.chat_generate_batched(all_prompts, **self.inference_args) 752 | user_resps = [] 753 | for data in datas: 754 | user_resp = self.backbone_model._cleaned_chat_resp( 755 | data, assistant_role=f"{PersuasionGame.USR}:", user_role=f"{PersuasionGame.SYS}:" 756 | ) 757 | user_resps.append(user_resp[0]) 758 | return user_resps 759 | 760 | def get_utterance_w_da_from_batched_states(self, states:List[DialogSession], action=None): 761 | gen_user_resps = self.get_utterance_from_batched_states(states, action) 762 | das = [] 763 | user_resps = [] 764 | # extract da 765 | for user_resp in gen_user_resps: 766 | start_idx = user_resp.find("[") 767 | end_idx = user_resp.find("]") 768 | if start_idx == -1 or end_idx == -1: 769 | da = PersuasionGame.U_Neutral 770 | else: 771 | da = 
user_resp[start_idx+1:end_idx] 772 | user_resp = user_resp.replace(f"[{da}]", "", 1).strip() 773 | if da not in self.dialog_acts: 774 | da = PersuasionGame.U_Neutral 775 | das.append(da) 776 | user_resps.append(user_resp) 777 | return das, user_resps 778 | 779 | def __process_heuristics_chat_exp(self, dialog:DialogSession): 780 | if len(dialog) == 0: 781 | return [] 782 | # assumes the dialog starts with the system 783 | # and ends with the user utterance to predict 784 | assert(dialog[0][0] == PersuasionGame.SYS) 785 | assert(dialog[-1][0] == PersuasionGame.USR) 786 | 787 | prompt_messages = [] 788 | input_context = [] 789 | answer_da = dialog[-1][1] 790 | for i, (role, da, utt) in enumerate(dialog): 791 | # if assistant is the Persuader, then current data is also Persuader -> then it is of role "system" 792 | # treat this as a task 793 | content = f"{role}: {utt}".strip() 794 | input_context.append(content) 795 | input_context.append(f"{dialog.USR} feeling:") 796 | 797 | prompt_q = "\n".join(input_context) 798 | prompt_messages.append({ 799 | "role": 'user', 800 | "content": prompt_q 801 | }) 802 | prompt_messages.append({ 803 | "role": 'assistant', 804 | "content": f"{answer_da}" 805 | }) 806 | return prompt_messages 807 | 808 | def __truncate_heuristics_dialog(self, dialog:DialogSession, pred_end_idx=-1): 809 | max_history_length = self.heuristic_args['max_hist_num_turns'] 810 | if pred_end_idx == -1: 811 | pred_end_idx = len(dialog.history) - 1 812 | new_sys_start_idx = max(0, pred_end_idx - (max_history_length * 2 - 1)) 813 | new_history = [] 814 | for j, (role, da, utt) in enumerate(dialog): 815 | if j >= new_sys_start_idx: 816 | new_history.append((role, da, utt)) 817 | if j == pred_end_idx: 818 | # user's utterance to predict 819 | break 820 | new_dialog_session = DialogSession(dialog.SYS, dialog.USR).from_history(new_history) 821 | return new_dialog_session 822 | 823 | def process_heurstics_chat_exp(self, new_task_prompt: str): 824 | prompt_exps = [] 825 | for i, 
exp in enumerate(self.conv_examples): 826 | pred_end_turns: List[int] = self.heuristic_args['example_pred_turn'][i] 827 | # make a new dialogue session up to that pred_idx with at most max_history_length turns 828 | for pred_end_turn in pred_end_turns: 829 | pred_end_idx = pred_end_turn * 2 + 1 830 | new_dialog_session = self.__truncate_heuristics_dialog(exp, pred_end_idx) 831 | prompt_exps += self.__process_heuristics_chat_exp(new_dialog_session) 832 | prompt_exps.append({ 833 | "role":"system", "content": new_task_prompt 834 | }) 835 | return prompt_exps[:-1] 836 | 837 | def predict_da(self, state:DialogSession, never_end=True) -> str: 838 | # never_end=True during real chat, let user choose to terminate, not this function 839 | # insert prop to donate, and compute the likelihood of user simulator agreeing to donate 840 | assert(state[-1][0] == PersuasionGame.USR) 841 | 842 | messages = [ 843 | {'role': 'system', 'content': self.task_prompt}, 844 | *self.process_heurstics_chat_exp(new_task_prompt=self.new_task_prompt), 845 | {'role': 'system', 'content': self.new_task_prompt} 846 | ] 847 | new_dialog_session = self.__truncate_heuristics_dialog(state, -1) 848 | messages += self.__process_heuristics_chat_exp(new_dialog_session)[:-1] 849 | 850 | # majority vote, same as value function 851 | inf_args = { 852 | "max_new_tokens": 5, 853 | "temperature": 0.7, 854 | "return_full_text": False, 855 | "do_sample": True, 856 | "num_return_sequences": 5, 857 | } 858 | datas = self.backbone_model.chat_generate(messages, **inf_args) 859 | # process into das 860 | sampled_das: list = [] 861 | for resp in datas: 862 | user_da = resp['generated_text'].strip() 863 | if user_da not in self.dialog_acts: 864 | sampled_das.append(PersuasionGame.U_Neutral) 865 | elif never_end: # elif, so an unrecognized act is counted once (as neutral) and never appended raw 866 | if user_da == PersuasionGame.U_Donate: 867 | sampled_das.append(PersuasionGame.U_PositiveReaction) 868 | elif user_da == PersuasionGame.U_NoDonation: 869 | sampled_das.append(PersuasionGame.U_NegativeReaction) 
870 | else: 871 | sampled_das.append(user_da) 872 | else: 873 | sampled_das.append(user_da) 874 | logger.info(f"sampled das: {sampled_das}") 875 | # majority vote 876 | counted_das = Counter(sampled_das) 877 | user_da = counted_das.most_common(1)[0][0] 878 | return user_da -------------------------------------------------------------------------------- /data/p4g/300_dialog_turn_based.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/data/p4g/300_dialog_turn_based.pkl -------------------------------------------------------------------------------- /interactive.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import argparse 4 | 5 | from tqdm.auto import tqdm 6 | from core.gen_models import ( 7 | LocalModel, OpenAIModel, OpenAIChatModel, AzureOpenAIChatModel 8 | ) 9 | from core.players import ( 10 | PersuadeeModel, PersuaderModel, P4GSystemPlanner, 11 | PersuaderChatModel, PersuadeeChatModel, P4GChatSystemPlanner 12 | ) 13 | from core.game import PersuasionGame 14 | from core.mcts import MCTS, OpenLoopMCTS, OpenLoopMCTSParallel 15 | from core.helpers import DialogSession 16 | from utils.utils import dotdict 17 | from utils.prompt_examples import EXP_DIALOG 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def play_gdpzero(backbone_model, args): 24 | args = dotdict({ 25 | "cpuct": 1.0, 26 | "num_MCTS_sims": args.num_mcts_sims, 27 | "max_realizations": args.max_realizations, 28 | "Q_0": args.Q_0, 29 | }) 30 | 31 | game_ontology = PersuasionGame.get_game_ontology() 32 | sys_da = game_ontology['system']['dialog_acts'] 33 | user_da = game_ontology['user']['dialog_acts'] 34 | system_name = PersuasionGame.SYS 35 | user_name = PersuasionGame.USR 36 | 37 | exp_1 = DialogSession(system_name, user_name).from_history(EXP_DIALOG) 38 | 39 | system = 
PersuaderChatModel( 40 | sys_da, 41 | backbone_model, 42 | conv_examples=[exp_1], 43 | inference_args={ 44 | "temperature": 0.7, 45 | "do_sample": True, # for MCTS open loop 46 | "return_full_text": False, 47 | } 48 | ) 49 | user = PersuadeeChatModel( 50 | user_da, 51 | inference_args={ 52 | "max_new_tokens": 128, 53 | "temperature": 1.1, 54 | "repetition_penalty": 1.0, 55 | "do_sample": True, # for MCTS open loop 56 | "return_full_text": False, 57 | }, 58 | backbone_model=backbone_model, 59 | conv_examples=[exp_1] 60 | ) 61 | planner = P4GChatSystemPlanner( 62 | dialog_acts=system.dialog_acts, 63 | max_hist_num_turns=system.max_hist_num_turns, 64 | user_dialog_acts=user.dialog_acts, 65 | user_max_hist_num_turns=user.max_hist_num_turns, 66 | generation_model=backbone_model, 67 | conv_examples=[exp_1] 68 | ) 69 | 70 | game = PersuasionGame(system, user) 71 | state = game.init_dialog() 72 | 73 | # init 74 | state.add_single(game.SYS, 'greeting', "Hello. How are you?") 75 | print("You are now the Persuadee. Type 'q' to quit, and 'r' to restart.") 76 | print("Persuader: Hello. How are you?") 77 | 78 | your_utt = input("You: ") 79 | while your_utt.strip() != "q": 80 | if your_utt.strip() == "r": 81 | state = game.init_dialog() 82 | state.add_single(game.SYS, 'greeting', "Hello. 
How are you?") 83 | game.display(state) 84 | your_utt = input("You: ") 85 | continue 86 | 87 | # used for da prediction 88 | tmp_state = state.copy() 89 | tmp_state.add_single(game.USR, 'neutral', your_utt.strip()) 90 | user_da = user.predict_da(tmp_state) 91 | 92 | logging.info(f"user_da: {user_da}") 93 | state.add_single(game.USR, user_da, your_utt.strip()) 94 | 95 | # planning 96 | if isinstance(backbone_model, OpenAIModel): 97 | backbone_model._cached_generate.cache_clear() 98 | dialog_planner = OpenLoopMCTS(game, planner, args) 99 | for i in tqdm(range(args.num_MCTS_sims)): 100 | dialog_planner.search(state) 101 | 102 | mcts_policy = dialog_planner.get_action_prob(state) 103 | mcts_policy_next_da = system.dialog_acts[np.argmax(mcts_policy)] 104 | logger.info(f"mcts_policy: {mcts_policy}") 105 | logger.info(f"mcts_policy_next_da: {mcts_policy_next_da}") 106 | logger.info(dialog_planner.Q) 107 | 108 | sys_utt = dialog_planner.get_best_realization(state, np.argmax(mcts_policy)) 109 | logging.info(f"sys_da: [{mcts_policy_next_da}]") 110 | print(f"Persuader: {sys_utt}") 111 | 112 | state.add_single(game.SYS, mcts_policy_next_da, sys_utt) 113 | your_utt = input("You: ") 114 | return 115 | 116 | 117 | def play_raw_prompt(backbone_model): 118 | system_name = PersuasionGame.SYS 119 | user_name = PersuasionGame.USR 120 | exp_1 = DialogSession(system_name, user_name).from_history(EXP_DIALOG) 121 | 122 | game_ontology = PersuasionGame.get_game_ontology() 123 | sys_da = game_ontology['system']['dialog_acts'] 124 | user_da = game_ontology['user']['dialog_acts'] 125 | 126 | system = PersuaderChatModel( 127 | sys_da, 128 | backbone_model, 129 | conv_examples=[exp_1] 130 | ) 131 | user = PersuadeeChatModel( 132 | user_da, 133 | inference_args={ 134 | "max_new_tokens": 128, 135 | "temperature": 1.1, 136 | "repetition_penalty": 1.0, 137 | "do_sample": True, 138 | "return_full_text": False, 139 | }, 140 | backbone_model=backbone_model, 141 | conv_examples=[exp_1] 142 | ) 143 | 
planner = P4GChatSystemPlanner( 144 | dialog_acts=system.dialog_acts, 145 | max_hist_num_turns=system.max_hist_num_turns, 146 | user_dialog_acts=user.dialog_acts, 147 | user_max_hist_num_turns=user.max_hist_num_turns, 148 | generation_model=backbone_model, 149 | conv_examples=[exp_1] 150 | ) 151 | game = PersuasionGame(system, user) 152 | state = game.init_dialog() 153 | 154 | # init 155 | state.add_single(game.SYS, 'greeting', "Hello. How are you?") 156 | print("You are now the Persuadee. Type 'q' to quit, and 'r' to restart.") 157 | print("Persuader: Hello. How are you?") 158 | 159 | your_utt = input("You: ") 160 | while your_utt.strip() != "q": 161 | if your_utt.strip() == "r": 162 | state = game.init_dialog() 163 | state.add_single(game.SYS, 'greeting', "Hello. How are you?") 164 | game.display(state) 165 | your_utt = input("You: ") 166 | continue 167 | # used for da prediction 168 | state.add_single(game.USR, 'neutral', your_utt.strip()) 169 | 170 | # planning 171 | prior, v = planner.predict(state) 172 | greedy_policy = system.dialog_acts[np.argmax(prior)] 173 | next_best_state = game.get_next_state(state, np.argmax(prior)) 174 | greedy_pred_resp = next_best_state.history[-2][2] 175 | 176 | logging.info(f"sys_da: [{greedy_policy}]") 177 | print(f"Persuader: {greedy_pred_resp}") 178 | 179 | state.add_single(game.SYS, greedy_policy, greedy_pred_resp) 180 | your_utt = input("You: ") 181 | return 182 | 183 | 184 | def main(args): 185 | if args.llm in ['code-davinci-002', 'text-davinci-002']: # must match the --llm choices, otherwise backbone_model is never assigned 186 | backbone_model = OpenAIModel(args.llm) 187 | elif args.llm in ['gpt-3.5-turbo']: 188 | backbone_model = OpenAIChatModel(args.llm, args.gen_sentences) 189 | elif args.llm == 'chatgpt': 190 | backbone_model = AzureOpenAIChatModel(args.llm, args.gen_sentences) 191 | 192 | if args.algo == 'gdpzero': 193 | print("using GDPZero as planning algorithm") 194 | play_gdpzero(backbone_model, args) 195 | elif args.algo == 'raw-prompt': 196 | print("using raw prompting as planning") 
197 | play_raw_prompt(backbone_model) 198 | return 199 | 200 | 201 | if __name__ == "__main__": 202 | # logging mode 203 | parser = argparse.ArgumentParser() 204 | parser.add_argument("--log", type=int, default=logging.WARNING, help="logging mode", choices=[logging.INFO, logging.DEBUG, logging.WARNING]) 205 | parser.add_argument("--algo", type=str, default='gdpzero', choices=['gdpzero', 'raw-prompt'], help="planning algorithm") 206 | # used by PDP-Zero 207 | parser.add_argument('--llm', type=str, default="gpt-3.5-turbo", choices=["code-davinci-002", "gpt-3.5-turbo", "text-davinci-002", "chatgpt"], help='OpenAI model name') 208 | parser.add_argument('--gen_sentences', type=int, default=3, help='number of sentences to generate from the llm. Longer ones will be truncated by nltk.') 209 | parser.add_argument('--num_mcts_sims', type=int, default=10, help='number of mcts simulations') 210 | parser.add_argument('--max_realizations', type=int, default=3, help='number of realizations per mcts state') 211 | parser.add_argument('--Q_0', type=float, default=0.25, help='initial Q value for uninitialized states. 
to control exploration') 212 | args = parser.parse_args() 213 | logging.basicConfig(level=args.log) 214 | logger.setLevel(args.log) 215 | 216 | main(args) -------------------------------------------------------------------------------- /outputs/chatgpt_raw_prompt.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/outputs/chatgpt_raw_prompt.pkl -------------------------------------------------------------------------------- /outputs/gdpzero_10sims_3rlz_0.25Q0_20dialogs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/outputs/gdpzero_10sims_3rlz_0.25Q0_20dialogs.pkl -------------------------------------------------------------------------------- /outputs/gdpzero_10sims_v_chatgpt.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/outputs/gdpzero_10sims_v_chatgpt.pkl -------------------------------------------------------------------------------- /outputs/gdpzero_20sims_3rlz_0.0Q0_20dialogs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/outputs/gdpzero_20sims_3rlz_0.0Q0_20dialogs.pkl -------------------------------------------------------------------------------- /outputs/gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/outputs/gdpzero_50sims_3rlz_0.0Q0_20dialogs.pkl -------------------------------------------------------------------------------- /outputs/gdpzero_5sims_3rlz_0.0Q0_20dialogs.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyux/GDPZero/e9f10e089953205a38b2e44857089f96f656c20e/outputs/gdpzero_5sims_3rlz_0.0Q0_20dialogs.pkl -------------------------------------------------------------------------------- /runners/gdpzero.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | 7 | from tqdm.auto import tqdm 8 | from core.gen_models import ( 9 | LocalModel, OpenAIModel, OpenAIChatModel, AzureOpenAIChatModel 10 | ) 11 | from core.players import ( 12 | PersuadeeModel, PersuaderModel, P4GSystemPlanner, 13 | PersuaderChatModel, PersuadeeChatModel, P4GChatSystemPlanner 14 | ) 15 | from core.game import PersuasionGame 16 | from core.mcts import OpenLoopMCTS 17 | from core.helpers import DialogSession 18 | from utils.utils import dotdict 19 | from utils.prompt_examples import EXP_DIALOG 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | logger.setLevel(logging.DEBUG) 24 | 25 | 26 | def main(cmd_args): 27 | game_ontology = PersuasionGame.get_game_ontology() 28 | sys_da = game_ontology['system']['dialog_acts'] 29 | user_da = game_ontology['user']['dialog_acts'] 30 | system_name = PersuasionGame.SYS 31 | user_name = PersuasionGame.USR 32 | 33 | exp_1 = DialogSession(system_name, user_name).from_history(EXP_DIALOG) 34 | 35 | 36 | if cmd_args.llm in ['code-davinci-002']: 37 | backbone_model = OpenAIModel(cmd_args.llm) 38 | SysModel = PersuaderModel 39 | UsrModel = PersuadeeModel 40 | SysPlanner = P4GSystemPlanner 41 | elif cmd_args.llm in ['gpt-3.5-turbo']: 42 | backbone_model = OpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 43 | SysModel = PersuaderChatModel 44 | UsrModel = PersuadeeChatModel 45 | SysPlanner = P4GChatSystemPlanner 46 | elif cmd_args.llm == 'chatgpt': 47 | backbone_model = AzureOpenAIChatModel(cmd_args.llm, 
cmd_args.gen_sentences) 48 | SysModel = PersuaderChatModel 49 | UsrModel = PersuadeeChatModel 50 | SysPlanner = P4GChatSystemPlanner 51 | 52 | system = SysModel( 53 | sys_da, 54 | backbone_model, 55 | conv_examples=[exp_1], 56 | inference_args={ 57 | "temperature": 0.7, 58 | "do_sample": True, # for MCTS open loop 59 | "return_full_text": False, 60 | } 61 | ) 62 | user = UsrModel( 63 | user_da, 64 | inference_args={ 65 | "max_new_tokens": 128, 66 | "temperature": 1.1, 67 | "repetition_penalty": 1.0, 68 | "do_sample": True, # for MCTS open loop 69 | "return_full_text": False, 70 | }, 71 | backbone_model=backbone_model, 72 | conv_examples=[exp_1] 73 | ) 74 | planner = SysPlanner( 75 | dialog_acts=system.dialog_acts, 76 | max_hist_num_turns=system.max_hist_num_turns, 77 | user_dialog_acts=user.dialog_acts, 78 | user_max_hist_num_turns=user.max_hist_num_turns, 79 | generation_model=backbone_model, 80 | conv_examples=[exp_1] 81 | ) 82 | game = PersuasionGame(system, user) 83 | 84 | print(f"System dialog acts: {system.dialog_acts}") 85 | print(f"User dialog acts: {user.dialog_acts}") 86 | 87 | with open("data/p4g/300_dialog_turn_based.pkl", "rb") as f: 88 | all_dialogs = pickle.load(f) 89 | 90 | num_dialogs = cmd_args.num_dialogs 91 | args = dotdict({ 92 | "cpuct": 1.0, 93 | "num_MCTS_sims": cmd_args.num_mcts_sims, 94 | "Q_0": cmd_args.Q_0, 95 | "max_realizations": cmd_args.max_realizations, 96 | }) 97 | 98 | output = [] # for evaluation. [{did, context, ori_resp, new_resp, debug}, ...] 99 | # these dialogs contain inappropriate content and will throw an error or be filtered by OpenAI models. 
See raw_prompting.py file for more details 100 | bad_dialogs = ['20180808-024552_152_live', '20180723-100140_767_live', '20180825-080802_964_live'] # throws exception due to ChatGPT API filtering 101 | num_done = 0 102 | pbar = tqdm(total=num_dialogs, desc="evaluating") 103 | for did in all_dialogs.keys(): 104 | if did in bad_dialogs: 105 | print("skipping dialog id: ", did) 106 | continue 107 | if num_done == num_dialogs: 108 | break 109 | 110 | print("evaluating dialog id: ", did) 111 | context = "" 112 | dialog = all_dialogs[did] 113 | 114 | state = game.init_dialog() 115 | for t, turn in enumerate(dialog["dialog"]): 116 | if len(turn["ee"]) == 0: # ended 117 | break 118 | # also skip last turn as there is no evaluation 119 | if t == len(dialog["dialog"]) - 1: 120 | break 121 | 122 | usr_utt = " ".join(turn["ee"]).strip() 123 | usr_da = dialog["label"][t]["ee"][-1] 124 | 125 | # map to our dialog act 126 | if usr_da == "disagree-donation": 127 | usr_da = PersuasionGame.U_NoDonation 128 | elif usr_da == "negative-reaction-to-donation": 129 | usr_da = PersuasionGame.U_NegativeReaction 130 | elif usr_da == "positive-reaction-to-donation": 131 | usr_da = PersuasionGame.U_PositiveReaction 132 | elif usr_da == "agree-donation": 133 | usr_da = PersuasionGame.U_Donate 134 | else: 135 | usr_da = PersuasionGame.U_Neutral 136 | 137 | # game ended 138 | if usr_da == PersuasionGame.U_Donate: 139 | break 140 | 141 | # map sys as well 142 | sys_utt = " ".join(turn["er"]).strip() 143 | sys_da = set(dialog["label"][t]["er"]) 144 | intersected_das = sys_da.intersection(system.dialog_acts) 145 | if len(intersected_das) == 0: 146 | sys_da = "other" 147 | else: 148 | sys_da = list(intersected_das)[-1] 149 | 150 | state.add_single(PersuasionGame.SYS, sys_da, sys_utt) 151 | state.add_single(PersuasionGame.USR, usr_da, usr_utt) 152 | 153 | # update context for evaluation 154 | context = f""" 155 | {context} 156 | Persuader: {sys_utt} 157 | Persuadee: {usr_utt} 158 | """ 159 | context = 
context.replace('\t', '').strip() 160 | 161 | # mcts policy 162 | if isinstance(backbone_model, OpenAIModel): 163 | backbone_model._cached_generate.cache_clear() 164 | dialog_planner = OpenLoopMCTS(game, planner, args) 165 | print("searching") 166 | for i in tqdm(range(args.num_MCTS_sims)): 167 | dialog_planner.search(state) 168 | 169 | mcts_policy = dialog_planner.get_action_prob(state) 170 | mcts_policy_next_da = system.dialog_acts[np.argmax(mcts_policy)] 171 | 172 | # # fetch the generated utterance from simulation 173 | mcts_pred_rep = dialog_planner.get_best_realization(state, np.argmax(mcts_policy)) 174 | 175 | # next ground truth utterance 176 | human_resp = " ".join(dialog["dialog"][t+1]["er"]).strip() 177 | next_sys_das = set(dialog["label"][t+1]["er"]) 178 | next_intersected_das = next_sys_das.intersection(system.dialog_acts) 179 | if len(next_intersected_das) == 0: 180 | next_sys_da = "other" 181 | else: 182 | next_sys_da = list(next_intersected_das)[-1] 183 | 184 | # logging for debug 185 | debug_data = { 186 | "probs": mcts_policy, 187 | "da": mcts_policy_next_da, 188 | "search_tree": { 189 | "Ns": dialog_planner.Ns, 190 | "Nsa": dialog_planner.Nsa, 191 | "Q": dialog_planner.Q, 192 | "P": dialog_planner.P, 193 | "Vs": dialog_planner.Vs, 194 | "realizations": dialog_planner.realizations, 195 | "realizations_Vs": dialog_planner.realizations_Vs, 196 | "realizations_Ns": dialog_planner.realizations_Ns, 197 | }, 198 | } 199 | 200 | # update data 201 | cmp_data = { 202 | 'did': did, 203 | 'context': context, 204 | 'ori_resp': human_resp, 205 | 'ori_da': next_sys_da, 206 | 'new_resp': mcts_pred_rep, 207 | 'new_da': mcts_policy_next_da, 208 | "debug": debug_data, 209 | } 210 | output.append(cmp_data) 211 | 212 | if cmd_args.debug: 213 | print(context) 214 | print("human resp: ", human_resp) 215 | print("human da: ", next_sys_da) 216 | print("mcts resp: ", mcts_pred_rep) 217 | print("mcts da: ", mcts_policy_next_da) 218 | with open(cmd_args.output, "wb") as f: 
219 | pickle.dump(output, f) 220 | num_done += 1 221 | pbar.update(1) 222 | return 223 | 224 | 225 | if __name__ == "__main__": 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument('--output', type=str, default="outputs/gdpzero.pkl", help='output file') 228 | parser.add_argument('--llm', type=str, default="code-davinci-002", choices=["code-davinci-002", "chatgpt", "gpt-3.5-turbo"], help='OpenAI model name') 229 | parser.add_argument('--gen_sentences', type=int, default=-1, help='number of sentences to generate from the llm. Longer ones will be truncated by nltk.') 230 | parser.add_argument('--num_mcts_sims', type=int, default=20, help='number of mcts simulations') 231 | parser.add_argument('--max_realizations', type=int, default=3, help='number of realizations per mcts state') 232 | parser.add_argument('--Q_0', type=float, default=0.0, help='initial Q value for uninitialized states. to control exploration') 233 | parser.add_argument('--num_dialogs', type=int, default=20, help='number of dialogs to test MCTS on') 234 | parser.add_argument('--debug', action='store_true', help='debug mode') 235 | parser.parse_args() 236 | cmd_args = parser.parse_args() 237 | print("saving to", cmd_args.output) 238 | 239 | main(cmd_args) -------------------------------------------------------------------------------- /runners/gdpzero_noRS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import pickle 4 | import argparse 5 | 6 | from tqdm.auto import tqdm 7 | from core.gen_models import ( 8 | LocalModel, OpenAIModel, OpenAIChatModel, AzureOpenAIChatModel 9 | ) 10 | from core.players import ( 11 | PersuadeeModel, PersuaderModel, P4GSystemPlanner, 12 | PersuaderChatModel, PersuadeeChatModel, P4GChatSystemPlanner 13 | ) 14 | from core.game import PersuasionGame 15 | from core.mcts import OpenLoopMCTS 16 | from core.helpers import DialogSession 17 | from utils.utils import dotdict 18 | from
utils.prompt_examples import EXP_DIALOG 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.DEBUG) 23 | 24 | 25 | def main(cmd_args): 26 | system_name = PersuasionGame.SYS 27 | user_name = PersuasionGame.USR 28 | 29 | exp_1 = DialogSession(system_name, user_name).from_history(EXP_DIALOG) 30 | 31 | game_ontology = PersuasionGame.get_game_ontology() 32 | sys_da = game_ontology['system']['dialog_acts'] 33 | user_da = game_ontology['user']['dialog_acts'] 34 | 35 | if cmd_args.llm == 'code-davinci-002': 36 | backbone_model = OpenAIModel(cmd_args.llm) 37 | SysModel = PersuaderModel 38 | UsrModel = PersuadeeModel 39 | SysPlanner = P4GSystemPlanner 40 | elif cmd_args.llm in ['gpt-3.5-turbo']: 41 | backbone_model = OpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 42 | SysModel = PersuaderChatModel 43 | UsrModel = PersuadeeChatModel 44 | SysPlanner = P4GChatSystemPlanner 45 | elif cmd_args.llm == 'chatgpt': 46 | backbone_model = AzureOpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 47 | SysModel = PersuaderChatModel 48 | UsrModel = PersuadeeChatModel 49 | SysPlanner = P4GChatSystemPlanner 50 | 51 | system = SysModel( 52 | sys_da, 53 | backbone_model, 54 | conv_examples=[exp_1], 55 | inference_args={ 56 | "temperature": 0.7, 57 | "do_sample": True, # for MCTS open loop 58 | "return_full_text": False, 59 | } 60 | ) 61 | user = UsrModel( 62 | user_da, 63 | inference_args={ 64 | "max_new_tokens": 128, 65 | "temperature": 1.1, 66 | "repetition_penalty": 1.0, 67 | "do_sample": True, # for MCTS open loop 68 | "return_full_text": False, 69 | }, 70 | backbone_model=backbone_model, 71 | conv_examples=[exp_1] 72 | ) 73 | planner = SysPlanner( 74 | dialog_acts=system.dialog_acts, 75 | max_hist_num_turns=system.max_hist_num_turns, 76 | user_dialog_acts=user.dialog_acts, 77 | user_max_hist_num_turns=user.max_hist_num_turns, 78 | generation_model=backbone_model, 79 | conv_examples=[exp_1] 80 | ) 81 | game = PersuasionGame(system, user) 82 | 83 | with 
open("data/p4g/300_dialog_turn_based.pkl", "rb") as f: 84 | all_dialogs = pickle.load(f) 85 | 86 | num_dialogs = 20 87 | args = dotdict({ 88 | "cpuct": 1.0, 89 | "num_MCTS_sims": cmd_args.num_mcts_sims, 90 | "max_realizations": cmd_args.max_realizations, 91 | "Q_0": cmd_args.Q_0, 92 | }) 93 | 94 | output = [] # for evaluation. [{did, context, ori_resp, new_resp, debug}, ...] 95 | bad_dialogs = ['20180808-024552_152_live', '20180723-100140_767_live', '20180825-080802_964_live'] # throws exception due to ChatGPT API filtering 96 | num_done = 0 97 | pbar = tqdm(total=num_dialogs, desc="evaluating") 98 | for did in all_dialogs.keys(): 99 | if did in bad_dialogs: 100 | print("skipping dialog id: ", did) 101 | continue 102 | if num_done == num_dialogs: 103 | break 104 | 105 | print("evaluating dialog id: ", did) 106 | context = "" 107 | dialog = all_dialogs[did] 108 | 109 | state = game.init_dialog() 110 | for t, turn in enumerate(dialog["dialog"]): 111 | if len(turn["ee"]) == 0: # ended 112 | break 113 | # also skip last turn as there is no evaluation 114 | if t == len(dialog["dialog"]) - 1: 115 | break 116 | 117 | usr_utt = " ".join(turn["ee"]).strip() 118 | usr_da = dialog["label"][t]["ee"][-1] 119 | 120 | # map to our dialog act 121 | if usr_da == "disagree-donation": 122 | usr_da = PersuasionGame.U_NoDonation 123 | elif usr_da == "negative-reaction-to-donation": 124 | usr_da = PersuasionGame.U_NegativeReaction 125 | elif usr_da == "positive-reaction-to-donation": 126 | usr_da = PersuasionGame.U_PositiveReaction 127 | elif usr_da == "agree-donation": 128 | usr_da = PersuasionGame.U_Donate 129 | else: 130 | usr_da = PersuasionGame.U_Neutral 131 | 132 | # game ended 133 | if usr_da == PersuasionGame.U_Donate: 134 | break 135 | 136 | # map sys as well 137 | sys_utt = " ".join(turn["er"]).strip() 138 | sys_da = set(dialog["label"][t]["er"]) 139 | intersected_das = sys_da.intersection(system.dialog_acts) 140 | if len(intersected_das) == 0: 141 | sys_da = "other" 142 | 
else: 143 | sys_da = list(intersected_das)[-1] 144 | 145 | state.add_single(PersuasionGame.SYS, sys_da, sys_utt) 146 | state.add_single(PersuasionGame.USR, usr_da, usr_utt) 147 | 148 | # update context for evaluation 149 | context = f""" 150 | {context} 151 | Persuader: {sys_utt} 152 | Persuadee: {usr_utt} 153 | """ 154 | context = context.replace('\t', '').strip() 155 | 156 | # mcts policy 157 | if isinstance(backbone_model, OpenAIModel): 158 | backbone_model._cached_generate.cache_clear() 159 | dialog_planner = OpenLoopMCTS(game, planner, args) 160 | print("searching") 161 | for i in tqdm(range(args.num_MCTS_sims)): 162 | dialog_planner.search(state) 163 | 164 | mcts_policy = dialog_planner.get_action_prob(state) 165 | mcts_policy_next_da = system.dialog_acts[np.argmax(mcts_policy)] 166 | 167 | # # fetch the generated utterance from simulation 168 | # next_best_state = "__".join([dialog_planner._to_string_rep(state), mcts_policy_next_da]) 169 | # NLGs = dialog_planner.realizations[next_best_state] 170 | # rand_idx: int = np.random.choice(np.arange(0, len(NLGs)), size=1)[0] 171 | # mcts_pred_rep = NLGs[rand_idx].history[-2][2] 172 | # generate a new utterance to be fair (technically it will be the same as above as system resp are cached) 173 | next_best_state = game.get_next_state(state, np.argmax(mcts_policy)) 174 | mcts_pred_rep = next_best_state.history[-2][2] 175 | 176 | # next ground truth utterance 177 | human_resp = " ".join(dialog["dialog"][t+1]["er"]).strip() 178 | next_sys_das = set(dialog["label"][t+1]["er"]) 179 | next_intersected_das = next_sys_das.intersection(system.dialog_acts) 180 | if len(next_intersected_das) == 0: 181 | next_sys_da = "other" 182 | else: 183 | next_sys_da = list(next_intersected_das)[-1] 184 | 185 | # logging for debug 186 | debug_data = { 187 | "probs": mcts_policy, 188 | "da": mcts_policy_next_da, 189 | "search_tree": { 190 | "Ns": dialog_planner.Ns, 191 | "Nsa": dialog_planner.Nsa, 192 | "Q": dialog_planner.Q, 193 | "P": 
dialog_planner.P, 194 | "Vs": dialog_planner.Vs, 195 | "realizations": dialog_planner.realizations, 196 | }, 197 | } 198 | 199 | # update data 200 | cmp_data = { 201 | 'did': did, 202 | 'context': context, 203 | 'ori_resp': human_resp, 204 | 'ori_da': next_sys_da, 205 | 'new_resp': mcts_pred_rep, 206 | 'new_da': mcts_policy_next_da, 207 | "debug": debug_data, 208 | } 209 | output.append(cmp_data) 210 | with open(cmd_args.output, "wb") as f: 211 | pickle.dump(output, f) 212 | num_done += 1 213 | pbar.update(1) 214 | pbar.close() 215 | return 216 | 217 | 218 | if __name__ == "__main__": 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument('--output', type=str, default="outputs/gdpzero_noRS.pkl", help='output file') 221 | parser.add_argument('--llm', type=str, default="code-davinci-002", choices=["code-davinci-002", "gpt-3.5-turbo", "chatgpt"], help='OpenAI model name') 222 | parser.add_argument('--gen_sentences', type=int, default=-1, help='max number of sentences to generate') 223 | parser.add_argument('--num_mcts_sims', type=int, default=20, help='number of mcts simulations') 224 | parser.add_argument('--max_realizations', type=int, default=3, help='number of realizations per mcts state') 225 | parser.add_argument('--Q_0', type=float, default=0.0, help='initial Q value for uninitialized states.
to control exploration') 226 | parser.parse_args() 227 | cmd_args = parser.parse_args() 228 | print("saving to", cmd_args.output) 229 | 230 | main(cmd_args) -------------------------------------------------------------------------------- /runners/gdpzero_noopenloop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import pickle 4 | import argparse 5 | 6 | from tqdm.auto import tqdm 7 | from core.gen_models import ( 8 | LocalModel, OpenAIModel, OpenAIChatModel, AzureOpenAIChatModel 9 | ) 10 | from core.players import ( 11 | PersuadeeModel, PersuaderModel, P4GSystemPlanner, 12 | PersuaderChatModel, PersuadeeChatModel, P4GChatSystemPlanner 13 | ) 14 | from core.game import PersuasionGame 15 | from core.mcts import MCTS 16 | from core.helpers import DialogSession 17 | from utils.utils import dotdict 18 | from utils.prompt_examples import EXP_DIALOG 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.DEBUG) 23 | 24 | 25 | def main(cmd_args): 26 | system_name = PersuasionGame.SYS 27 | user_name = PersuasionGame.USR 28 | exp_1 = DialogSession(system_name, user_name).from_history(EXP_DIALOG) 29 | 30 | 31 | game_ontology = PersuasionGame.get_game_ontology() 32 | sys_da = game_ontology['system']['dialog_acts'] 33 | user_da = game_ontology['user']['dialog_acts'] 34 | 35 | if cmd_args.llm == 'code-davinci-002': 36 | backbone_model = OpenAIModel(cmd_args.llm) 37 | SysModel = PersuaderModel 38 | UsrModel = PersuadeeModel 39 | SysPlanner = P4GSystemPlanner 40 | elif cmd_args.llm in ['gpt-3.5-turbo']: 41 | backbone_model = OpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 42 | SysModel = PersuaderChatModel 43 | UsrModel = PersuadeeChatModel 44 | SysPlanner = P4GChatSystemPlanner 45 | elif cmd_args.llm == 'chatgpt': 46 | backbone_model = AzureOpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 47 | SysModel = PersuaderChatModel 48 | UsrModel = PersuadeeChatModel 49 | 
SysPlanner = P4GChatSystemPlanner 50 | 51 | system = SysModel( 52 | sys_da, 53 | backbone_model, 54 | conv_examples=[exp_1] 55 | ) 56 | user = UsrModel( 57 | user_da, 58 | inference_args={ 59 | "max_new_tokens": 128, 60 | "temperature": 1.0, 61 | "repetition_penalty": 1.0, 62 | "do_sample": False, # for MCTS closed loop 63 | "return_full_text": False, 64 | }, 65 | backbone_model=backbone_model, 66 | conv_examples=[exp_1] 67 | ) 68 | planner = SysPlanner( 69 | dialog_acts=system.dialog_acts, 70 | max_hist_num_turns=system.max_hist_num_turns, 71 | user_dialog_acts=user.dialog_acts, 72 | user_max_hist_num_turns=user.max_hist_num_turns, 73 | generation_model=backbone_model, 74 | conv_examples=[exp_1] 75 | ) 76 | 77 | game = PersuasionGame(system, user) 78 | 79 | with open("data/p4g/300_dialog_turn_based.pkl", "rb") as f: 80 | all_dialogs = pickle.load(f) 81 | 82 | num_dialogs = 20 83 | args = dotdict({ 84 | "cpuct": 1.0, 85 | "num_MCTS_sims": cmd_args.num_mcts_sims, 86 | "Q_0": 0.0, 87 | }) 88 | 89 | output = [] # for evaluation. [{did, context, ori_da, ori_resp, new_da, new_resp, debug}, ...] 
90 | bad_dialogs = ['20180808-024552_152_live', '20180723-100140_767_live', '20180825-080802_964_live'] # throws exception due to ChatGPT API filtering 91 | num_done = 0 92 | pbar = tqdm(total=num_dialogs, desc="evaluating") 93 | for did in all_dialogs.keys(): 94 | if did in bad_dialogs: 95 | print("skipping dialog id: ", did) 96 | continue 97 | if num_done == num_dialogs: 98 | break 99 | 100 | print("evaluating dialog id: ", did) 101 | context = "" 102 | dialog = all_dialogs[did] 103 | 104 | state = game.init_dialog() 105 | for t, turn in enumerate(dialog["dialog"]): 106 | if len(turn["ee"]) == 0: # ended 107 | break 108 | # also skip last turn as there is no evaluation 109 | if t == len(dialog["dialog"]) - 1: 110 | break 111 | 112 | usr_utt = " ".join(turn["ee"]).strip() 113 | usr_da = dialog["label"][t]["ee"][-1] 114 | 115 | # map to our dialog act 116 | if usr_da == "disagree-donation": 117 | usr_da = PersuasionGame.U_NoDonation 118 | elif usr_da == "negative-reaction-to-donation": 119 | usr_da = PersuasionGame.U_NegativeReaction 120 | elif usr_da == "positive-reaction-to-donation": 121 | usr_da = PersuasionGame.U_PositiveReaction 122 | elif usr_da == "agree-donation": 123 | usr_da = PersuasionGame.U_Donate 124 | else: 125 | usr_da = PersuasionGame.U_Neutral 126 | 127 | # game ended 128 | if usr_da == PersuasionGame.U_Donate: 129 | break 130 | 131 | # map sys as well 132 | sys_utt = " ".join(turn["er"]).strip() 133 | sys_da = set(dialog["label"][t]["er"]) 134 | intersected_das = sys_da.intersection(system.dialog_acts) 135 | if len(intersected_das) == 0: 136 | sys_da = "other" 137 | else: 138 | sys_da = list(intersected_das)[-1] 139 | 140 | state.add_single(PersuasionGame.SYS, sys_da, sys_utt) 141 | state.add_single(PersuasionGame.USR, usr_da, usr_utt) 142 | 143 | # update context for evaluation 144 | context = f""" 145 | {context} 146 | Persuader: {sys_utt} 147 | Persuadee: {usr_utt} 148 | """ 149 | context = context.replace('\t', '').strip() 150 | 151 | # mcts 
policy, reset cache since we are in a new turn 152 | if isinstance(backbone_model, OpenAIModel): 153 | backbone_model._cached_generate.cache_clear() 154 | dialog_planner = MCTS(game, planner, args) 155 | print("searching") 156 | for i in tqdm(range(args.num_MCTS_sims)): 157 | dialog_planner.search(state) 158 | 159 | mcts_policy = dialog_planner.get_action_prob(state) 160 | mcts_policy_next_da = system.dialog_acts[np.argmax(mcts_policy)] 161 | 162 | # fetch the generated utterance from simulation 163 | next_best_state = game.get_next_state(state, np.argmax(mcts_policy)) 164 | mcts_pred_rep = next_best_state.history[-2][2] 165 | 166 | # next ground truth utterance 167 | human_resp = " ".join(dialog["dialog"][t+1]["er"]).strip() 168 | next_sys_das = set(dialog["label"][t+1]["er"]) 169 | next_intersected_das = next_sys_das.intersection(system.dialog_acts) 170 | if len(next_intersected_das) == 0: 171 | next_sys_da = "other" 172 | else: 173 | next_sys_da = list(next_intersected_das)[-1] 174 | 175 | # logging for debug 176 | debug_data = { 177 | "probs": mcts_policy, 178 | "da": mcts_policy_next_da, 179 | "search_tree": { 180 | "Ns": dialog_planner.Ns, 181 | "Nsa": dialog_planner.Nsa, 182 | "Q": dialog_planner.Q, 183 | "P": dialog_planner.P, 184 | "Vs": dialog_planner.Vs, 185 | }, 186 | } 187 | 188 | # update data 189 | cmp_data = { 190 | 'did': did, 191 | 'context': context, 192 | 'ori_resp': human_resp, 193 | 'ori_da': next_sys_da, 194 | 'new_resp': mcts_pred_rep, 195 | 'new_da': mcts_policy_next_da, 196 | "debug": debug_data, 197 | } 198 | output.append(cmp_data) 199 | with open(cmd_args.output, "wb") as f: 200 | pickle.dump(output, f) 201 | num_done += 1 202 | pbar.update(1) 203 | pbar.close() 204 | return 205 | 206 | 207 | if __name__ == "__main__": 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--output', type=str, default="outputs/gdpzero_noopenloop.pkl", help='output file') 210 | parser.add_argument('--llm', type=str, 
default="code-davinci-002", choices=["code-davinci-002", "gpt-3.5-turbo", "chatgpt"], help='OpenAI model name') 211 | parser.add_argument('--gen_sentences', type=int, default=-1, help='max number of sentences to generate') 212 | parser.add_argument('--num_mcts_sims', type=int, default=20, help='number of mcts simulations') 213 | parser.parse_args() 214 | cmd_args = parser.parse_args() 215 | print("saving to", cmd_args.output) 216 | 217 | main(cmd_args) -------------------------------------------------------------------------------- /runners/raw_prompting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import pickle 4 | import argparse 5 | 6 | from tqdm.auto import tqdm 7 | from core.gen_models import ( 8 | LocalModel, OpenAIModel, OpenAIChatModel, AzureOpenAIChatModel 9 | ) 10 | from core.players import ( 11 | PersuadeeModel, PersuaderModel, P4GSystemPlanner, 12 | PersuaderChatModel, PersuadeeChatModel, P4GChatSystemPlanner 13 | ) 14 | from core.game import PersuasionGame 15 | from core.helpers import DialogSession 16 | from utils.prompt_examples import EXP_DIALOG 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | logger.setLevel(logging.DEBUG) 21 | 22 | 23 | def main(cmd_args): 24 | system_name = PersuasionGame.SYS 25 | user_name = PersuasionGame.USR 26 | exp_1 = DialogSession(system_name, user_name).from_history(EXP_DIALOG) 27 | 28 | if cmd_args.llm == 'code-davinci-002': 29 | backbone_model = OpenAIModel(cmd_args.llm) 30 | SysModel = PersuaderModel 31 | UsrModel = PersuadeeModel 32 | SysPlanner = P4GSystemPlanner 33 | elif cmd_args.llm in ['gpt-3.5-turbo']: 34 | backbone_model = OpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 35 | SysModel = PersuaderChatModel 36 | UsrModel = PersuadeeChatModel 37 | SysPlanner = P4GChatSystemPlanner 38 | elif cmd_args.llm == 'chatgpt': 39 | backbone_model = AzureOpenAIChatModel(cmd_args.llm, cmd_args.gen_sentences) 40 | SysModel = 
PersuaderChatModel 41 | UsrModel = PersuadeeChatModel 42 | SysPlanner = P4GChatSystemPlanner 43 | 44 | game_ontology = PersuasionGame.get_game_ontology() 45 | sys_da = game_ontology['system']['dialog_acts'] 46 | user_da = game_ontology['user']['dialog_acts'] 47 | 48 | system = SysModel( 49 | sys_da, 50 | backbone_model, 51 | conv_examples=[exp_1] 52 | ) 53 | user = UsrModel( 54 | user_da, 55 | inference_args={ 56 | "max_new_tokens": 128, 57 | "temperature": 1.1, 58 | "repetition_penalty": 1.0, 59 | "do_sample": True, 60 | "return_full_text": False, 61 | }, 62 | backbone_model=backbone_model, 63 | conv_examples=[exp_1] 64 | ) 65 | planner = SysPlanner( 66 | dialog_acts=system.dialog_acts, 67 | max_hist_num_turns=system.max_hist_num_turns, 68 | user_dialog_acts=user.dialog_acts, 69 | user_max_hist_num_turns=user.max_hist_num_turns, 70 | generation_model=backbone_model, 71 | conv_examples=[exp_1] 72 | ) 73 | game = PersuasionGame(system, user) 74 | 75 | with open("data/p4g/300_dialog_turn_based.pkl", "rb") as f: 76 | all_dialogs = pickle.load(f) 77 | 78 | num_dialogs = 20 79 | 80 | output = [] # for evaluation. [{did, context, ori_resp, new_resp, debug}, ...] 
81 | # these dialogs contain inappropriate content and will throw an error or be filtered by OpenAI models 82 | bad_dialogs = ['20180808-024552_152_live', '20180723-100140_767_live', '20180825-080802_964_live'] 83 | num_done = 0 84 | pbar = tqdm(total=num_dialogs, desc="evaluating") 85 | for did in all_dialogs.keys(): 86 | if did in bad_dialogs: 87 | print("skipping dialog id: ", did) 88 | continue 89 | if num_done == num_dialogs: 90 | break 91 | 92 | print("evaluating dialog id: ", did) 93 | context = "" 94 | no_error = True 95 | dialog = all_dialogs[did] 96 | 97 | state = game.init_dialog() 98 | for t, turn in enumerate(dialog["dialog"]): 99 | if len(turn["ee"]) == 0: # ended 100 | break 101 | # also skip last turn as there is no evaluation 102 | if t == len(dialog["dialog"]) - 1: 103 | break 104 | 105 | usr_utt = " ".join(turn["ee"]).strip() 106 | usr_da = dialog["label"][t]["ee"][-1] 107 | 108 | # map to our dialog act 109 | if usr_da == "disagree-donation": 110 | usr_da = PersuasionGame.U_NoDonation 111 | elif usr_da == "negative-reaction-to-donation": 112 | usr_da = PersuasionGame.U_NegativeReaction 113 | elif usr_da == "positive-reaction-to-donation": 114 | usr_da = PersuasionGame.U_PositiveReaction 115 | elif usr_da == "agree-donation": 116 | usr_da = PersuasionGame.U_Donate 117 | else: 118 | usr_da = PersuasionGame.U_Neutral 119 | 120 | # game ended 121 | if usr_da == PersuasionGame.U_Donate: 122 | break 123 | 124 | # map sys as well 125 | sys_utt = " ".join(turn["er"]).strip() 126 | sys_da = set(dialog["label"][t]["er"]) 127 | intersected_das = sys_da.intersection(system.dialog_acts) 128 | if len(intersected_das) == 0: 129 | sys_da = "other" 130 | else: 131 | sys_da = list(intersected_das)[-1] 132 | 133 | state.add_single(PersuasionGame.SYS, sys_da, sys_utt) 134 | state.add_single(PersuasionGame.USR, usr_da, usr_utt) 135 | 136 | # update context for evaluation 137 | context = f""" 138 | {context} 139 | Persuader: {sys_utt} 140 | Persuadee: {usr_utt} 141 | """
142 | context = context.replace('\t', '').strip() 143 | 144 | # greedy policy: take the dialog act with the highest prior (no tree search) 145 | prior, v = planner.predict(state) 146 | greedy_policy = system.dialog_acts[np.argmax(prior)] 147 | try: 148 | next_best_state = game.get_next_state(state, np.argmax(prior)) 149 | except Exception as e: 150 | bad_dialogs.append(did) 151 | no_error = False 152 | raise e 153 | greedy_pred_resp = next_best_state.history[-2][2] 154 | 155 | # next ground truth utterance 156 | human_resp = " ".join(dialog["dialog"][t + 1]["er"]).strip() 157 | next_sys_das = set(dialog["label"][t+1]["er"]) 158 | next_intersected_das = next_sys_das.intersection(system.dialog_acts) 159 | if len(next_intersected_das) == 0: 160 | next_sys_da = "other" 161 | else: 162 | next_sys_da = list(next_intersected_das)[-1] 163 | 164 | # logging for debug 165 | debug_data = { 166 | "prior": prior, 167 | "da": greedy_policy, 168 | "v": v 169 | } 170 | 171 | # update data 172 | cmp_data = { 173 | 'did': did, 174 | 'context': context, 175 | 'ori_resp': human_resp, 176 | 'ori_da': next_sys_da, 177 | 'new_resp': greedy_pred_resp, 178 | 'new_da': greedy_policy, 179 | "debug": debug_data, 180 | } 181 | output.append(cmp_data) 182 | 183 | if no_error: 184 | with open(cmd_args.output, "wb") as f: 185 | pickle.dump(output, f) 186 | pbar.update(1) 187 | num_done += 1 188 | pbar.close() 189 | print(bad_dialogs) 190 | return 191 | 192 | 193 | if __name__ == "__main__": 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument('--llm', type=str, default="code-davinci-002", choices=["code-davinci-002", "gpt-3.5-turbo", "chatgpt"], help='OpenAI model name') 196 | parser.add_argument('--gen_sentences', type=int, default=-1, help='max number of sentences to generate.
-1 for no limit') 197 | parser.add_argument('--output', type=str, default="outputs/raw_prompt.pkl", help='output file') 198 | parser.parse_args() 199 | cmd_args = parser.parse_args() 200 | print("saving to", cmd_args.output) 201 | 202 | main(cmd_args) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import logging 4 | import os 5 | 6 | from tqdm.auto import tqdm 7 | from core.evaluator import P4GEvaluator 8 | from core.gen_models import OpenAIModel, OpenAIChatModel, AzureOpenAIModel, AzureOpenAIChatModel 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def main(args): 15 | if args.debug: 16 | logging.basicConfig(level=logging.DEBUG) 17 | logger.setLevel(logging.DEBUG) 18 | 19 | # backbone_model = OpenAIModel('text-davinci-003') 20 | if args.judge in ['gpt-3.5-turbo']: 21 | backbone_model = OpenAIChatModel(args.judge) 22 | elif args.judge == 'chatgpt': 23 | backbone_model = AzureOpenAIChatModel(args.judge) 24 | else: 25 | raise ValueError(f"unknown judge: {args.judge}") 26 | evaluator = P4GEvaluator(backbone_model) 27 | 28 | with open(args.f, 'rb') as f: 29 | data: list = pickle.load(f) 30 | h2h_data = [] 31 | if args.h2h: 32 | with open(args.h2h, 'rb') as f: 33 | h2h_data: list = pickle.load(f) 34 | assert(len(data) == len(h2h_data)) 35 | assert(args.output != '') # specify output path when doing h2h comparisons 36 | 37 | result = [] 38 | stats = { 39 | 'win': 0, # if b=new_resp is better than a=ori_resp 40 | 'draw': 0, 41 | 'lose': 0, 42 | } 43 | for i, d in tqdm(enumerate(data[:]), total=len(data), desc="evaluating"): 44 | context = d['context'] 45 | ori_resp = d['ori_resp'] 46 | new_resp = d['new_resp'] 47 | if len(h2h_data) > 0: 48 | ori_resp = h2h_data[i]['new_resp'] 49 | 50 | winner, info = evaluator.evaluate(context, ori_resp, new_resp) 51 | 52 | # update winners 53 | if winner 
== 0: 54 | stats['lose'] += 1 55 | elif winner == 1: 56 | stats['win'] += 1 57 | else: 58 | stats['draw'] += 1 59 | 60 | info['winner'] = winner 61 | result.append(info) 62 | 63 | # save 64 | if args.output != '': 65 | output_file = args.output 66 | else: 67 | output_folder = os.path.join(os.path.dirname(args.f), 'evaluation') 68 | output_filename = os.path.basename(args.f).replace('.pkl', '_evaluated.pkl') 69 | output_file = os.path.join(output_folder, output_filename) 70 | with open(output_file, 'wb') as f: 71 | pickle.dump(result, f) 72 | 73 | # statistics 74 | win_rate = stats['win'] / sum(stats.values()) 75 | print(f"win rate: {win_rate*100.0:.2f}%") 76 | print("stats: ", stats) 77 | return 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('-f', type=str, help='path to the data file for comparing against human in p4g. See P4GEvaluator documentation to see the format of the file.') 83 | parser.add_argument('--judge', type=str, default='gpt-3.5-turbo', help='which judge to use.', choices=['gpt-3.5-turbo', 'chatgpt']) 84 | parser.add_argument('--h2h', type=str, default='', help='path to the data file for head to head comparison. If empty compare against human in p4g.') 85 | parser.add_argument("--output", type=str, default='', help="output file") 86 | parser.add_argument("--debug", action='store_true', help="debug mode") 87 | args = parser.parse_args() 88 | 89 | main(args) -------------------------------------------------------------------------------- /utils/prompt_examples.py: -------------------------------------------------------------------------------- 1 | from core.game import PersuasionGame 2 | 3 | EXP_DIALOG = [ 4 | # extracted from 20180825-061105_792_live and 20180826-053845_531_live 5 | (PersuasionGame.SYS, PersuasionGame.S_Greeting, "Hello. 
How are you?",), 6 | (PersuasionGame.USR, PersuasionGame.U_Neutral, "I'm good, how are you doing?",), 7 | (PersuasionGame.SYS, PersuasionGame.S_TaskRelatedInquiry, "Very well. I'm just up organizing info for my charity called Save the Children. Have you heard of this charity berfore?",), 8 | (PersuasionGame.USR, PersuasionGame.U_Neutral, "No, I have not. Can you tell me more?",), 9 | (PersuasionGame.SYS, PersuasionGame.S_CredibilityAppeal, "Save the Children is an organization that helps children in developing countries, by promoting children's rights and providing relief. It is an amazing charity that helps kids who are in desperate need. They can help with safety, education and more.",), 10 | (PersuasionGame.USR, PersuasionGame.U_NegativeReaction, "That sounds great. I believe in this charity, but still wonder how much of the money I donate actually helps. I am always worried if I donate it will just go to some higer up that is living the high life.",), 11 | (PersuasionGame.SYS, PersuasionGame.S_EmotionAppeal, "Every little bit makes a difference. When you have people who are so poor, it's amazing what a tiny amount can do. I usually donate in hopes I can at least feel like I did my part. If I donated and some corrupt person took it, that's the worst karma and even worst scandal imaginable.",), 12 | (PersuasionGame.USR, PersuasionGame.U_PositiveReaction, "With that all said I do feel like any orginazation that aims to help the children I am more inclined to donate to them than most. I think helping children is an important thing as they are our future!",), 13 | (PersuasionGame.SYS, PersuasionGame.S_PropositionOfDonation,"I think donating to this cause would def be a step in the right direction to hopefully helping across the world the children that are in dispair. 
I don't want you to donate any more than you want, so if you want to donate how much do you to do?",), 14 | (PersuasionGame.USR, PersuasionGame.U_Donate, "I would donate 1 dollar to this charity and feel good about it I think.",), 15 | ] -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def set_determinitic_seed(seed): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | torch.cuda.manual_seed_all(seed) 10 | torch.use_deterministic_algorithms(True) 11 | torch.backends.cudnn.deterministic = True 12 | torch.backends.cudnn.benchmark = False 13 | return 14 | 15 | 16 | class dotdict(dict): 17 | def __getattr__(self, name): 18 | return self[name] 19 | 20 | 21 | class hashabledict(dict): 22 | def __hash__(self): 23 | return hash(tuple(sorted(self.items()))) --------------------------------------------------------------------------------
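The two helper classes in `utils/utils.py` are small enough that their behaviour is easy to misread, so here is a minimal usage sketch. The class bodies are copied from the file above so the block runs on its own. The sample keys and values, and the idea of using a `hashabledict` as a cache key, are assumptions made for illustration and are not taken from the repository.

```python
class dotdict(dict):
    """Dict whose keys can also be read as attributes (same body as utils/utils.py)."""
    def __getattr__(self, name):
        return self[name]


class hashabledict(dict):
    """Dict hashed by its sorted items (same body as utils/utils.py)."""
    def __hash__(self):
        return hash(tuple(sorted(self.items())))


# Attribute-style reads, as the runners do with args.num_MCTS_sims.
args = dotdict({"cpuct": 1.0, "num_MCTS_sims": 20, "Q_0": 0.0})
sims = args.num_MCTS_sims

# A missing key surfaces as KeyError, not AttributeError, because
# __getattr__ simply forwards to dict indexing.
try:
    args.max_realizations
    missing_error = None
except KeyError as exc:
    missing_error = type(exc).__name__

# Two hashabledicts with the same items hash alike regardless of insertion
# order, so either one can look up an entry stored under the other.
# (Hypothetical use: keying a cache on generation arguments.)
a = hashabledict({"temperature": 0.7, "do_sample": True})
b = hashabledict({"do_sample": True, "temperature": 0.7})
cache = {a: "cached response"}
hit = cache[b]
```

Since a missing key raises `KeyError`, `hasattr(args, "some_flag")` on a `dotdict` would propagate that error instead of returning `False`. Checking `"some_flag" in args` is the safer test. `hashabledict` also requires every value to be hashable, and it should not be mutated while it is in use as a key.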