├── mdn ├── src │ ├── __init__.py │ └── blocks.py ├── examples │ ├── __init__.py │ ├── ex_1d.png │ └── ex_1d.py ├── .gitignore ├── LICENSE └── README.md ├── agent ├── agent_config.yaml ├── Online_LLM.py ├── Local_LLM.py ├── LLM.py ├── Model.py ├── Conversation.py └── llm_config.yaml ├── evaluation ├── camera_ready.txt ├── conversation_starter.txt ├── train_q_function_helper.py ├── train_offline_q_function.py ├── run_evaluation_singular.py └── run_evaluation.py ├── reward ├── embedding_length_reward ├── rewards_import.py ├── Embedding_Dummy_Reward.py ├── Human_Length_Reward.py ├── Base_Reward.py ├── Embedding_Length_Reward.py └── Llama_2_Guard_Reward.py ├── .gitignore ├── requirements.txt ├── monte_carlo_tree_search ├── policy.py ├── deep_agent.py ├── tabular_policy.py ├── qfunction.py ├── ucb.py ├── mdp.py ├── semantic_conversation_env.py ├── mcts.py ├── conversation_env.py ├── single_agent_mcts.py ├── qtable.py └── policy_agent.py ├── LICENSE ├── transition_models ├── regression_wrapper.py ├── embedding_model.py └── transition_model.py ├── train ├── embed_dataset.py └── train_transition.py └── README.md /mdn/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mdn/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mdn/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | -------------------------------------------------------------------------------- /agent/agent_config.yaml: -------------------------------------------------------------------------------- 1 | action_sample_count : 5 -------------------------------------------------------------------------------- /evaluation/camera_ready.txt: -------------------------------------------------------------------------------- 1 | Can you tell me about how conference locations are selected? -------------------------------------------------------------------------------- /evaluation/conversation_starter.txt: -------------------------------------------------------------------------------- 1 | Can you tell me something about Singapore, the place where ICLR 2025 is held? -------------------------------------------------------------------------------- /mdn/examples/ex_1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhiliang94/convo-plan-SCOPE/HEAD/mdn/examples/ex_1d.png -------------------------------------------------------------------------------- /reward/embedding_length_reward: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhiliang94/convo-plan-SCOPE/HEAD/reward/embedding_length_reward -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | **/__pycache__/ 3 | 4 | transition_models/**/*.pth 5 | transition_models/**/*.txt 6 | embeddings/ 7 | 8 | *.out -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | torch 4 | pandas 5 | tiktoken 6 | transformers 7 | datasets 8 | accelerate 9 | wandb 10 | mixture_of_experts 11 | openai 12 | bitsandbytes 13 | flash_attn 14 | lz4 15 | token_count 16 | scipy 17 | gdown -------------------------------------------------------------------------------- /reward/rewards_import.py: -------------------------------------------------------------------------------- 1 | from reward.Base_Reward import Base_Reward 2 | from reward.Human_Length_Reward import Human_Length_Reward 3 | from reward.Llama_2_Guard_Reward import Llama_2_Guard_Reward 4 | 5 | REWARD_CLASSES = [Human_Length_Reward, Llama_2_Guard_Reward] -------------------------------------------------------------------------------- /reward/Embedding_Dummy_Reward.py: -------------------------------------------------------------------------------- 1 | 2 | from reward.Base_Reward import Base_Reward 3 | import random 4 | 5 | # Reward function that returns random reward 6 | class Embedding_Dummy_Reward(Base_Reward): 7 | def get_reward(self, prev_state, action, human_response) -> float: 8 | return random.uniform(0.99,1) -------------------------------------------------------------------------------- /reward/Human_Length_Reward.py: -------------------------------------------------------------------------------- 1 | 2 | from reward.Base_Reward import Base_Reward 3 | from agent.Conversation import Conversation 4 | 5 | # Reward function that returns the length of the human response 6 | class Human_Length_Reward(Base_Reward): 7 | def get_reward(self, prev_state : Conversation, action : str, human_response : str) -> float: 8 | return 0.01*len(human_response) -------------------------------------------------------------------------------- /monte_carlo_tree_search/policy.py: -------------------------------------------------------------------------------- 1 | class Policy: 2 | def select_action(self, state): 3 | abstract 4 | 5 | 6 | class DeterministicPolicy(Policy): 7 | def update(self, state, action): 8 | abstract 9 | 10 | 11 | class StochasticPolicy(Policy): 12 | def update(self, states, actions, rewards): 13 | abstract 14 | 15 | def get_probability(self, state, action): 16 | abstract 17 | -------------------------------------------------------------------------------- /monte_carlo_tree_search/deep_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC 3 | 4 | class DeepAgent(ABC): 5 | 6 | """ Abstract class for deep learning agents which share common methods """ 7 | 8 | @staticmethod 9 | def encode_state(state): 10 | """ Turn the state into a tensor. """ 11 | if state == ("terminal", "terminal"): 12 | state = (-1, -1) 13 | return torch.as_tensor(state, dtype=torch.float32) 14 | -------------------------------------------------------------------------------- /monte_carlo_tree_search/tabular_policy.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from monte_carlo_tree_search.policy import DeterministicPolicy 4 | 5 | 6 | class TabularPolicy(DeterministicPolicy): 7 | def __init__(self, default_action=None): 8 | self.policy_table = defaultdict(lambda: default_action) 9 | 10 | def select_action(self, state): 11 | return self.policy_table[state] 12 | 13 | def update(self, state, action): 14 | self.policy_table[state] = action 15 | -------------------------------------------------------------------------------- /reward/Base_Reward.py: -------------------------------------------------------------------------------- 1 | 2 | def length_convo(convo): 3 | ''' 4 | assume convo is a list of sentence strings, more than size 2 5 | ''' 6 | 7 | cumulative_reward = 0.0 8 | for idx, sentence in enumerate(convo): 9 | if idx % 2 == 0: 10 | cumulative_reward += len(sentence) 11 | return cumulative_reward 12 | 13 | ''' 14 | Each reward function here receives a (convo_state, action, human_response) 15 | input and returns the IMMEDIATE reward/cost for performing a certain action and observing the human response. 16 | The cumulative sum should be derived by the function user. 17 | ''' 18 | 19 | from abc import ABC, abstractmethod 20 | from agent.Conversation import Conversation 21 | 22 | class Base_Reward(ABC): 23 | @abstractmethod 24 | def get_reward(prev_state : Conversation, action : str, human_response : str) -> float: 25 | pass -------------------------------------------------------------------------------- /agent/Online_LLM.py: -------------------------------------------------------------------------------- 1 | """Contains classes for querying OpenAI large language models.""" 2 | 3 | from typing import List 4 | from agent.LLM import LLM 5 | from openai import OpenAI 6 | 7 | class Online_LLM(LLM): 8 | def __init__(self, model_config, **kwargs): 9 | self.model_name = model_config["name"] 10 | self.tokenizer_has_system_prompt = True 11 | 12 | self.generation_config = model_config["generation_config"] 13 | self.system_prompt = model_config["sys_prompt"] 14 | 15 | self.client = OpenAI() 16 | 17 | def generate(self, chat : List[dict], **kwargs) -> List[str]: 18 | print("generating responses in chatgpt。。。") 19 | print(chat) 20 | output = self.client.chat.completions.create( 21 | model=self.model_name, 22 | messages=chat, 23 | **self.generation_config, 24 | **kwargs 25 | ) 26 | output = [i.message.content for i in output.choices] 27 | return output -------------------------------------------------------------------------------- /mdn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tony Duan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 chenzhiliang94 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /monte_carlo_tree_search/qfunction.py: -------------------------------------------------------------------------------- 1 | from monte_carlo_tree_search.tabular_policy import TabularPolicy 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | 5 | class QFunction: 6 | 7 | """ Update the Q-value of (state, action) by delta """ 8 | @abstractmethod 9 | def update(self, state, action, delta, visits, reward): 10 | pass 11 | 12 | """ Get a Q value for a given state-action pair """ 13 | @abstractmethod 14 | def get_q_value(self, state, action): 15 | pass 16 | 17 | """ Return a pair containing the action and Q-value, where the 18 | action has the maximum Q-value in state 19 | """ 20 | 21 | def get_qs(self, state, actions): 22 | qs = [] 23 | for action in actions: 24 | qs.append(self.get_q_value(state, action)) 25 | return qs 26 | 27 | def get_max_q(self, state, actions): 28 | qs = self.get_qs(state, actions) 29 | arg_max_q = np.argmax(qs) 30 | arg_max_q = actions[arg_max_q] 31 | max_q = qs[arg_max_q] 32 | return (arg_max_q, max_q) 33 | 34 | """ Extract a policy for this Q-function """ 35 | 36 | def extract_policy(self, mdp): 37 | policy = TabularPolicy() 38 | for state in mdp.get_states(): 39 | # Find the action with maximum Q-value and make this the 40 | (action, _) = self.get_max_q(state, mdp.get_actions(state)) 41 | policy.update(state, action) 42 | 43 | return policy 44 | -------------------------------------------------------------------------------- /monte_carlo_tree_search/ucb.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | class MultiArmedBandit(): 5 | 6 | """ Select an action for this state given from a list given a Q-function """ 7 | 8 | def select(self, state, actions, qfunction): 9 | abstract 10 | 11 | """ Reset a multi-armed bandit to its initial configuration """ 12 | 13 | def reset(self): 14 | self.__init__() 15 | 16 | class UpperConfidenceBounds(MultiArmedBandit): 17 | def __init__(self): 18 | self.total = 0 19 | # number of times each action has been chosen 20 | self.times_selected = {} 21 | 22 | def select(self, state, actions, qfunction): 23 | 24 | # First execute each action one time 25 | for action in actions: 26 | if action not in self.times_selected.keys(): 27 | self.times_selected[action] = 1 28 | self.total += 1 29 | return action 30 | 31 | max_actions = [] 32 | max_value = float("-inf") 33 | for action in actions: 34 | value = 0.05 * qfunction.get_q_value(state, action) + math.sqrt( 35 | (2 * math.log(self.total)) / self.times_selected[action] 36 | ) 37 | if value > max_value: 38 | max_actions = [action] 39 | max_value = value 40 | elif value == max_value: 41 | max_actions += [action] 42 | 43 | # if there are multiple actions with the highest value 44 | # choose one randomly 45 | result = random.choice(max_actions) 46 | self.times_selected[result] = self.times_selected[result] + 1 47 | self.total += 1 48 | return result -------------------------------------------------------------------------------- /mdn/examples/ex_1d.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import logging 3 | 4 | from matplotlib import pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.optim as optim 8 | 9 | from src.blocks import MixtureDensityNetwork, NoiseType 10 | 11 | 12 | def gen_data(n=512): 13 | y = np.linspace(-1, 1, n) 14 | x = 7 * np.sin(5 * y) + 0.5 * y + 0.5 * np.random.randn(*y.shape) 15 | return x[:,np.newaxis], y[:,np.newaxis] 16 | 17 | def plot_data(x, y): 18 | plt.hist2d(x, y, bins=35) 19 | plt.xlim(-8, 8) 20 | plt.ylim(-1, 1) 21 | plt.axis('off') 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | argparser = ArgumentParser() 27 | argparser.add_argument("--n-iterations", type=int, default=2000) 28 | args = argparser.parse_args() 29 | 30 | logging.basicConfig(level=logging.INFO) 31 | logger = logging.getLogger(__name__) 32 | 33 | x, y = gen_data() 34 | x = torch.Tensor(x) 35 | y = torch.Tensor(y) 36 | 37 | model = MixtureDensityNetwork(1, 1, n_components=3, hidden_dim=50, noise_type=NoiseType.DIAGONAL) 38 | optimizer = optim.Adam(model.parameters(), lr=0.005) 39 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_iterations) 40 | 41 | for i in range(args.n_iterations): 42 | optimizer.zero_grad() 43 | loss = model.loss(x, y).mean() 44 | loss.backward() 45 | optimizer.step() 46 | scheduler.step() 47 | if i % 100 == 0: 48 | logger.info(f"Iter: {i}\t" + f"Loss: {loss.data:.2f}") 49 | 50 | with torch.no_grad(): 51 | y_hat = model.sample(x) 52 | 53 | plt.figure(figsize=(8, 3)) 54 | plt.subplot(1, 2, 1) 55 | plot_data(x[:, 0].numpy(), y[:, 0].numpy()) 56 | plt.title("Observed data") 57 | plt.subplot(1, 2, 2) 58 | plot_data(x[:, 0].numpy(), y_hat[:, 0].numpy()) 59 | plt.title("Sampled data") 60 | plt.show() 61 | -------------------------------------------------------------------------------- /transition_models/regression_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | class RegressionWrapper(nn.Module): 5 | def __init__(self, model, embedding_size=1024): 6 | super(RegressionWrapper, self).__init__() 7 | self.add_module("model", model) 8 | self.input_mean = nn.Parameter(torch.zeros(embedding_size), requires_grad=False) 9 | self.input_std = nn.Parameter(torch.ones(embedding_size), requires_grad=False) 10 | self.output_mean = nn.Parameter(torch.zeros(embedding_size), requires_grad=False) 11 | self.output_std = nn.Parameter(torch.ones(embedding_size), requires_grad=False) 12 | self.use_residuals = nn.Parameter(torch.tensor(True, dtype=bool), requires_grad=False) 13 | 14 | def set_parameters(self, input_mean, input_std, output_mean, output_std, use_residuals): 15 | self.input_mean = nn.Parameter(input_mean, requires_grad=False) 16 | self.input_std = nn.Parameter(input_std, requires_grad=False) 17 | self.output_mean = nn.Parameter(output_mean, requires_grad=False) 18 | self.output_std = nn.Parameter(output_std, requires_grad=False) 19 | self.use_residuals = nn.Parameter(torch.tensor(use_residuals, dtype=bool), requires_grad=False) 20 | 21 | def scale_input(self, x): 22 | x = (x - self.input_mean) / self.input_std 23 | return x 24 | 25 | def scale_output(self, y): 26 | y = y * self.output_std + self.output_mean 27 | return y 28 | 29 | def forward(self, x): 30 | scaled_x = self.scale_input(x) 31 | y = self.model(scaled_x[:,None])[0][:,0,:] 32 | y = self.scale_output(y) 33 | if self.use_residuals: 34 | y = x + y 35 | return y 36 | 37 | def sample(self, x, samples_per_input=1): 38 | scaled_x = self.scale_input(x) 39 | y = self.model.sample(scaled_x, samples_per_input = samples_per_input) 40 | y = self.scale_output(y) 41 | if self.use_residuals: 42 | y = x + y 43 | return y -------------------------------------------------------------------------------- /evaluation/train_q_function_helper.py: -------------------------------------------------------------------------------- 1 | 2 | import sys, os 3 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 4 | 5 | from monte_carlo_tree_search.policy_agent import * 6 | from monte_carlo_tree_search.qtable import QTable 7 | from monte_carlo_tree_search.conversation_env import * 8 | from scipy import stats 9 | import numpy as np 10 | import torch 11 | import os.path 12 | import pandas as pd 13 | import random 14 | import tiktoken 15 | 16 | # with mcts 17 | def offline_train_q_function(conversation_starters, human, llm_agent, pretrained_q_function_name, timeout=100, search_depth=5): 18 | qfunction = DeepQFunction() 19 | for idx, conversation_starter in enumerate(conversation_starters): 20 | print("training index: ", idx) 21 | conversation_env = conversation_environment(human, llm_agent, conversation_starter, max_depth=search_depth) 22 | mcts = SingleAgentMCTS(conversation_env, qfunction, UpperConfidenceBounds()) 23 | mcts.mcts(timeout=timeout) 24 | qfunction = mcts.qfunction 25 | if idx % 10 == 0: 26 | print("saving model...") 27 | torch.save(qfunction, pretrained_q_function_name + str(len(conversation_starters))) # save after each training 28 | return qfunction 29 | 30 | # with just static convo 31 | def offline_train_q_function_static_conversation(conversations, pretrained_q_function_name, reward_function, num, cycle=1, terminating_steps=10): 32 | qfunction = DeepQFunction() 33 | 34 | # how many times to iterate through dataset 35 | for num_cycle in range(cycle): 36 | 37 | # learn for each convo in conversations 38 | for idx, convo in enumerate(conversations[:num]): 39 | cumulative_reward = 0 40 | if len(convo) > 2: 41 | cumulative_reward = reward_function(convo[2:]) 42 | state = conversation_state(convo[0], convo[0]) 43 | state.depth = 1 44 | qfunction.update(state, convo[1], 0, 1, cumulative_reward) 45 | 46 | if idx % 10 == 0: 47 | print("saving model...") 48 | torch.save(qfunction, pretrained_q_function_name) # save after each a few training loop -------------------------------------------------------------------------------- /train/embed_dataset.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | 4 | from reward.rewards_import import * 5 | import torch 6 | import datasets 7 | from tqdm import tqdm 8 | import sys 9 | import numpy as np 10 | import os 11 | 12 | if __name__ == "__main__": 13 | 14 | dataset = datasets.load_dataset("lmsys/lmsys-chat-1m", split=f"train", streaming=True) 15 | 16 | start = 0 17 | end = 1_000_000 18 | # end = 2000 19 | 20 | print("starting at", start, "ending at", end) 21 | 22 | # Random projection to reduce the dimension of the embedding vectors, set to None to use the original 8192 dim vectors 23 | # random_projection = None 24 | random_projection = 1024 25 | 26 | reward = Llama_2_Guard_Reward(random_projection=random_projection) 27 | 28 | embeddings = [] 29 | failed = [] 30 | for i, conversation in zip(tqdm(range(start, end)), iter(dataset)): 31 | conversation = conversation['conversation'] 32 | e = [] 33 | for j in range(len(conversation)): 34 | try: 35 | e.append(reward.embed(conversation[:j+1]).half()) 36 | except: 37 | # Failed likely due to context length exceeding GPU VRAM, to be dealt with separately 38 | failed.append(i) 39 | break 40 | if len(e) == 0: 41 | embeddings.append(torch.zeros(0, random_projection, dtype=torch.float16)) 42 | else: 43 | embeddings.append(torch.stack(e)) 44 | 45 | # Save as float16 to save space 46 | dataset = datasets.Dataset.from_dict( 47 | { 48 | 'embeddings': embeddings, 49 | }, 50 | features=datasets.Features( 51 | { 52 | 'embeddings': datasets.Array2D(shape=(None, random_projection), dtype='float16') 53 | } 54 | ) 55 | ) 56 | 57 | os.makedirs(f"embeddings", exist_ok=True) 58 | dataset.save_to_disk(f"embeddings/lmsys-chat-1m_embeddings_{random_projection}") 59 | 60 | # Save failed indices to deal with separately (Run on GPU with more VRAM) 61 | np.save(f"embeddings/lmsys-chat-1m_embeddings_{random_projection}/failed.npy", np.array(failed, dtype=int)) -------------------------------------------------------------------------------- /agent/Local_LLM.py: -------------------------------------------------------------------------------- 1 | """Contains classes for querying local large language models.""" 2 | 3 | import torch 4 | from typing import List 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from agent.LLM import LLM, ICL_prompt 7 | 8 | class Local_LLM(LLM): 9 | def __init__(self, model_config, model = None, tokenizer = None): 10 | # self.model_name = model_config["name"] 11 | if tokenizer is None: 12 | self.tokenizer = AutoTokenizer.from_pretrained(model_config["model_config"]["pretrained_model_name_or_path"]) 13 | else: 14 | self.tokenizer = tokenizer 15 | if model is None: 16 | print(f'Loading model {model_config["model_config"]["pretrained_model_name_or_path"]} on device {model_config["model_config"]["device_map"]} in Local_LLM...') 17 | self.model = AutoModelForCausalLM.from_pretrained(torch_dtype=torch.bfloat16, **model_config["model_config"]) 18 | else: 19 | self.model = model 20 | if self.model.generation_config.pad_token_id is None: 21 | self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id 22 | 23 | try: 24 | self.tokenizer.apply_chat_template([{"role":"system","content":""}]) 25 | self.tokenizer_has_system_prompt = True 26 | except: 27 | self.tokenizer_has_system_prompt = False 28 | 29 | self.generation_config = model_config["generation_config"] 30 | self.system_prompt = ICL_prompt(model_config) 31 | 32 | def generate(self, chat : List[dict], **kwargs) -> List[str]: 33 | tokens = self.tokenizer.apply_chat_template( 34 | chat, 35 | tokenize = True, add_generation_prompt = True, return_tensors = "pt", return_attention_mask = True, return_dict = True 36 | ).to(self.model.device) 37 | with torch.no_grad(): 38 | output = self.model.generate(input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"], **self.generation_config, **kwargs) 39 | output = output[:, tokens["input_ids"].shape[-1]:] # Only return generated tokens 40 | decoded_output = self.tokenizer.batch_decode(output, skip_special_tokens=True) 41 | decoded_output = (list(set(decoded_output))) # remove duplicates 42 | print("generated LLM output after removing duplicate: ", len(decoded_output)) 43 | return decoded_output -------------------------------------------------------------------------------- /evaluation/train_offline_q_function.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | 4 | from monte_carlo_tree_search.policy_agent import * 5 | from monte_carlo_tree_search.qtable import QTable 6 | from monte_carlo_tree_search.conversation_env import * 7 | from evaluation.train_q_function_helper import * 8 | from scipy import stats 9 | import numpy as np 10 | import torch 11 | import os.path 12 | import pandas as pd 13 | import random 14 | 15 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 16 | 17 | # Parse command line arguments 18 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 19 | parser.add_argument("--data", help="data-set name. i.e name of the parquet file") 20 | parser.add_argument("--number", help="number of conversation starters") 21 | args = vars(parser.parse_args()) 22 | 23 | data_path = args["data"] 24 | n = int(args["number"]) 25 | 26 | # mcts 27 | if data_path == "daily_dialogue": 28 | # create the llm and human simulator 29 | human, llm_agent = create_human_and_llm() 30 | 31 | # check if Q-function pretrained) exists, if not, train one offline with some conversation starters, 32 | pretrained_q_function_name = "trained_q_function_" + str(data_path) 33 | conversation_data = pd.read_parquet('daily_dialogue.parquet', engine='auto') 34 | conversation_starters = [x[0] for x in list(conversation_data['dialog'])] 35 | conversation_starters = random.choices(conversation_starters, k=n) 36 | 37 | pretraining_mcts_timeout = 500 # how long to run simulation 38 | pretraining_depth = 8 # how deep to run mcts 39 | 40 | q_function_offline_learnt = offline_train_q_function(conversation_starters, human, llm_agent, pretrained_q_function_name, timeout=pretraining_mcts_timeout, search_depth=pretraining_depth) 41 | torch.save(q_function_offline_learnt, pretrained_q_function_name + str(len(conversation_starters))) 42 | 43 | if data_path == "daily_dialogue_static": 44 | 45 | pretrained_q_function_name = "trained_q_function_STATIC_" + str(data_path) 46 | conversation_data = pd.read_parquet('daily_dialogue.parquet', engine='auto') 47 | conversation_starters = [x for x in list(conversation_data['dialog'])] 48 | 49 | q_function_offline_learnt = offline_train_q_function_static_conversation(conversation_starters, pretrained_q_function_name, length_convo, n) 50 | 51 | torch.save(q_function_offline_learnt, pretrained_q_function_name) -------------------------------------------------------------------------------- /agent/LLM.py: -------------------------------------------------------------------------------- 1 | """Contains classes for querying large language models.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from agent.Conversation import Conversation 5 | from typing import List 6 | 7 | class LLM(ABC): 8 | system_prompt = "" 9 | 10 | @abstractmethod 11 | def generate(self, chat : List[dict]) -> List[str]: 12 | pass 13 | 14 | # Create chat and add add system prompt 15 | def apply_chat_format(self, convo : Conversation, **kwargs) -> List[dict]: 16 | chat = convo.create_chat() 17 | 18 | if isinstance(self.system_prompt, ICL_prompt): 19 | if "context" in kwargs and kwargs["context"] is not None: 20 | self.system_prompt.set_context(kwargs["context"]) 21 | elif "human_description" in kwargs and kwargs["human_description"] is not None: 22 | self.system_prompt.set_context(kwargs["human_description"]) 23 | 24 | if len(self.system_prompt) > 0: 25 | if chat[0]["role"] == "assistant": 26 | chat[0]["content"] = self.system_prompt + "\n\n" + chat[0]["content"] 27 | chat = chat[1:] 28 | elif self.tokenizer_has_system_prompt: 29 | chat.insert(0,{"role": "system", "content":str(self.system_prompt)}) 30 | # else: 31 | # chat.insert(0,{"role": "assistant", "content":self.system_prompt}) 32 | 33 | if isinstance(self.system_prompt, ICL_prompt): 34 | self.system_prompt.clear_context() 35 | 36 | return chat 37 | 38 | class ICL_prompt: 39 | def __init__(self, model_config): 40 | self.use_icl = "sys_prompt_pre" in model_config and "sys_prompt_context" in model_config and "sys_prompt_post" in model_config 41 | if not self.use_icl: 42 | self.sys_prompt = model_config["sys_prompt"] 43 | else: 44 | self.sys_prompt_pre = model_config["sys_prompt_pre"] 45 | self.sys_prompt_context = model_config["sys_prompt_context"] 46 | self.sys_prompt_post = model_config["sys_prompt_post"] 47 | self.context = "" 48 | 49 | def set_context(self, context): 50 | if not self.use_icl: 51 | return 52 | context = "\n".join(context) 53 | self.context = self.sys_prompt_context + "\n\n" + context + "\n\n" 54 | 55 | def clear_context(self): 56 | self.context = "" 57 | 58 | def __str__(self): 59 | if not self.use_icl: 60 | return self.sys_prompt 61 | if self.context == "": 62 | return self.sys_prompt_pre + self.sys_prompt_post 63 | return self.sys_prompt_pre + self.context + self.sys_prompt_post 64 | 65 | def __len__(self): 66 | return len(str(self)) -------------------------------------------------------------------------------- /reward/Embedding_Length_Reward.py: -------------------------------------------------------------------------------- 1 | 2 | from reward.Base_Reward import Base_Reward 3 | import torch 4 | import torch.nn as nn 5 | from token_count import TokenCount 6 | 7 | from agent.Conversation import Conversation 8 | class MLPRegression(nn.Module): 9 | def __init__(self): 10 | super(MLPRegression, self).__init__() 11 | self.fc1 = nn.Linear(1024, 512) 12 | self.fc2 = nn.Linear(512, 256) 13 | self.fc3 = nn.Linear(256, 64) 14 | self.fc4 = nn.Linear(64, 32) 15 | self.fc5 = nn.Linear(32, 1) 16 | 17 | def forward(self, x): 18 | x = torch.relu(self.fc1(x)) 19 | x = torch.relu(self.fc2(x)) 20 | x = torch.relu(self.fc3(x)) 21 | x = torch.relu(self.fc4(x)) 22 | x = (self.fc5(x)) 23 | return x 24 | 25 | # Reward function that returns random reward 26 | class Embedding_Length_Reward(Base_Reward): 27 | 28 | def __init__(self, add_llm_length : bool, path_to_model="reward/embedding_length_reward", device_map=0) -> None: 29 | super().__init__() 30 | print(f"Loading embedding length model on device {device_map}...") 31 | self.model = MLPRegression() 32 | self.model.load_state_dict(torch.load(path_to_model, map_location=torch.device(device_map))) 33 | self.add_llm_length = add_llm_length 34 | print("length model initialized with add_llm_length: ", self.add_llm_length) 35 | 36 | def get_reward(self, prev_state : tuple | str | Conversation, action : tuple | str, human_response : tuple | str | None) -> float: 37 | if isinstance(prev_state, Conversation): 38 | prev_state = str(prev_state) 39 | # for last step evaluation 40 | if human_response is None: 41 | if isinstance(action, str): 42 | return self.get_tokens_from_str(action) 43 | else: 44 | with torch.no_grad(): 45 | reward = self.model(torch.FloatTensor(prev_state) + torch.FloatTensor(action)) - self.model(torch.FloatTensor(prev_state)) 46 | print("reward from embedding length: ", reward) 47 | return reward * 10 48 | 49 | # if instance is string. its during evaluation and just response length 50 | if isinstance(human_response, str): 51 | if self.add_llm_length: 52 | return self.get_tokens_from_str(human_response) + self.get_tokens_from_str(action) 53 | else: 54 | return self.get_tokens_from_str(human_response) 55 | 56 | # if not string, human response length is in semantic space. So take difference. 57 | with torch.no_grad(): 58 | if self.add_llm_length: 59 | reward = self.model(torch.FloatTensor(human_response)) - self.model(torch.FloatTensor(prev_state)) 60 | else: 61 | reward = self.model(torch.FloatTensor(human_response)) - self.model(torch.FloatTensor(prev_state) + torch.FloatTensor(action)) 62 | print("reward from embedding length: ", reward) 63 | return reward * 10 64 | 65 | def get_tokens_from_str(self, convo : str) -> float: 66 | tc = TokenCount(model_name="gpt-3.5-turbo") 67 | return tc.num_tokens_from_string(convo)/100 -------------------------------------------------------------------------------- /transition_models/embedding_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | from transformers import AutoTokenizer, AutoModel 5 | from reward.Llama_2_Guard_Reward import Llama_2_Guard_Reward 6 | from typing import List 7 | from agent.Conversation import Conversation 8 | 9 | class embedding_model_mistral(): 10 | def __init__(self, tokenizer, model, to_normalize = False, cuda = torch.device('cuda:5')) -> None: 11 | self.tokenizer = tokenizer 12 | self.model = model 13 | self.output_dim = 4096 14 | self.cuda = cuda 15 | self.to_normalize = to_normalize 16 | 17 | def last_token_pool(self, last_hidden_states: Tensor, 18 | attention_mask: Tensor) -> Tensor: 19 | left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) 20 | if left_padding: 21 | return last_hidden_states[:, -1] 22 | else: 23 | sequence_lengths = attention_mask.sum(dim=1) - 1 24 | batch_size = last_hidden_states.shape[0] 25 | return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] 26 | 27 | # input is a single text string 28 | def embed(self, text): 29 | 30 | with torch.no_grad(): 31 | batch_dict = self.tokenizer([str(text)], max_length=self.output_dim, padding=True, truncation=True, return_tensors="pt").to(self.cuda) 32 | outputs = self.model(**batch_dict) 33 | embeddings = self.last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) 34 | 35 | if self.to_normalize: 36 | return F.normalize(embeddings, p=2, dim=1)[0] 37 | return embeddings[0] 38 | 39 | class embedding_model_nomic(): 40 | def __init__(self, tokenizer, model, to_normalize=False, cuda = torch.device('cuda:5')) -> None: 41 | self.tokenizer = tokenizer 42 | self.model = model 43 | self.output_dim = 768 44 | self.cuda = cuda 45 | self.to_normalize = to_normalize 46 | 47 | def mean_pooling(self, model_output, attention_mask): 48 | token_embeddings = model_output[0] 49 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 50 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 51 | 52 | # input is a single text string 53 | def embed(self, text): 54 | # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', model_max_length=8192) 55 | # model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1', trust_remote_code=True, rotary_scaling_factor=2) 56 | self.model.eval() 57 | 58 | encoded_input = self.tokenizer([str(text)], padding=True, truncation=True, return_tensors='pt').to(self.cuda) 59 | 60 | with torch.no_grad(): 61 | model_output = self.model(**encoded_input) 62 | 63 | embeddings = self.mean_pooling(model_output, encoded_input['attention_mask']) 64 | 65 | if self.to_normalize: 66 | return F.normalize(embeddings, p=2, dim=1)[0] 67 | return embeddings[0] 68 | 69 | #embeddings = F.normalize(embeddings, p=2, dim=1) 70 | 71 | class embedding_model_llama(Llama_2_Guard_Reward): 72 | def __init__(self, tokenizer = None, model = None, output_dim = 1024, to_normalize=None, cuda = torch.device('cuda:5'), seed = 42) -> None: 73 | super().__init__( 74 | model = model, device_map = cuda, random_projection = output_dim, random_proj_seed = seed 75 | ) 76 | self.output_dim = output_dim -------------------------------------------------------------------------------- /monte_carlo_tree_search/mdp.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class MDP: 4 | """Return all states of this MDP""" 5 | 6 | def get_states(self): 7 | abstract 8 | 9 | """ Return all actions with non-zero probability from this state """ 10 | 11 | def get_actions(self, state): 12 | abstract 13 | 14 | """ Return all non-zero probability transitions for this action 15 | from this state, as a list of (state, probability) pairs 16 | """ 17 | 18 | def get_transitions(self, state, action): 19 | abstract 20 | 21 | """ Return the reward for transitioning from state to 22 | nextState via action 23 | """ 24 | 25 | def get_reward(self, state, action, next_state): 26 | abstract 27 | 28 | """ Return true if and only if state is a terminal state of this MDP """ 29 | 30 | def is_terminal(self, state): 31 | abstract 32 | 33 | """ Return the discount factor for this MDP """ 34 | 35 | def get_discount_factor(self): 36 | abstract 37 | 38 | """ Return the initial state of this MDP """ 39 | 40 | def get_initial_state(self): 41 | abstract 42 | 43 | """ Return all goal states of this MDP """ 44 | 45 | def get_goal_states(self): 46 | abstract 47 | 48 | """ Return a new state and a reward for executing action in state, 49 | based on the underlying probability. This can be used for 50 | model-free learning methods, but requires a model to operate. 51 | Override for simulation-based learning 52 | """ 53 | 54 | def execute(self, state, action): 55 | rand = random.random() 56 | cumulative_probability = 0.0 57 | for (new_state, probability) in self.get_transitions(state, action): 58 | if cumulative_probability <= rand <= probability + cumulative_probability: 59 | return (new_state, self.get_reward(state, action, new_state)) 60 | cumulative_probability += probability 61 | if cumulative_probability >= 1.0: 62 | raise ( 63 | "Cumulative probability >= 1.0 for action " 64 | + str(action) 65 | + " from " 66 | + str(state) 67 | ) 68 | 69 | raise BaseException( 70 | "No outcome state in simulation for action" 71 | + str(action) 72 | + " from " 73 | + str(state) 74 | ) 75 | 76 | """ 77 | Execute a policy on this mdp for a number of episodes. 78 | Return the cumulative reward of each episode as a list. 79 | When True, random_on_duplicate detects when a state has been visited before, and selects a random action to avoid infinitely looping policies. 80 | """ 81 | 82 | def execute_policy(self, policy, episodes=100, random_on_duplicate=False): 83 | cumulative_rewards = [] 84 | states = set() 85 | for _ in range(episodes): 86 | cumulative_reward = 0.0 87 | state = self.get_initial_state() 88 | step = 0 89 | while not self.is_terminal(state): 90 | if state in states and random_on_duplicate: 91 | action = random.choice(self.get_actions(state)) 92 | else: 93 | action = policy.select_action(state) 94 | if random_on_duplicate: states.add(state) 95 | 96 | (next_state, reward) = self.execute(state, action) 97 | cumulative_reward += reward * (self.discount_factor ** step) 98 | state = next_state 99 | step += 1 100 | cumulative_rewards += [cumulative_reward] 101 | return cumulative_rewards 102 | -------------------------------------------------------------------------------- /agent/Model.py: -------------------------------------------------------------------------------- 1 | 2 | import yaml 3 | 4 | from typing import List, Tuple 5 | from tqdm import tqdm 6 | from agent.Local_LLM import Local_LLM 7 | from agent.Online_LLM import Online_LLM 8 | from agent.Conversation import Conversation, HUMAN_SIM, HUMAN_EVAL, LLM, get_role 9 | 10 | DEBUG = False 11 | 12 | class Model: 13 | """Abstract base class for large language models.""" 14 | 15 | def __init__(self, role, config, needs_confirmation=False, disable_tqdm=True, 16 | model=None, tokenizer=None): 17 | """Initializes the model.""" 18 | self.role = role 19 | self.config = config 20 | self.needs_confirmation = needs_confirmation 21 | self.disable_tqdm = disable_tqdm 22 | 23 | # Initialise human model 24 | if self.config["type"] == "local": 25 | LLM_class = Local_LLM 26 | else: 27 | LLM_class = Online_LLM 28 | self.model = LLM_class( 29 | self.config, 30 | model = model, 31 | tokenizer = tokenizer, 32 | ) 33 | tqdm.write(f'Initialized {get_role(role)} as {self.config["type"]} model: {self.config["model_config"]["pretrained_model_name_or_path"]}.') 34 | 35 | def sample_actions(self, prompt : Conversation, **kwargs) -> List[str]: 36 | # convo = Conversation.from_delimited_string(prompt) 37 | convo = prompt 38 | return self.generate_text(convo, **kwargs) 39 | 40 | def generate_text(self, convos : Conversation | List[Conversation], batch = False, **kwargs) -> List[str] | List[List[str]]: 41 | """Generates text from the model. 42 | Parameters: 43 | convos: The prompt to use. List of Conversation. 44 | Returns: 45 | A list of list of strings. 46 | """ 47 | convos_is_list = isinstance(convos, list) 48 | if not convos_is_list: 49 | convos = [convos] 50 | 51 | chats : List[List[dict]] = [] 52 | # Create prompts from converstation histories 53 | for convo in convos: 54 | chat = self.model.apply_chat_format(convo, **kwargs) 55 | chats.append(chat) 56 | if DEBUG: 57 | print("generated prompts") 58 | print(chats) 59 | 60 | generated_text = [] 61 | 62 | if batch: 63 | raise NotImplementedError 64 | else: 65 | if not self.disable_tqdm: 66 | chats = tqdm(chats) 67 | for chat in chats: 68 | output = self.model.generate(chat) 69 | generated_text.append(output) 70 | if not convos_is_list: 71 | generated_text = generated_text[0] 72 | return generated_text 73 | 74 | # Create human_sim, human_eval, and llm_model 75 | def create_human_and_llm(config="agent/llm_config.yaml",human_sim_to_use="human_sim", human_eval_to_use="human_eval", llm_model_to_use ="llm_model", cuda = 0,**kwargs) -> List[Model]: 76 | with open(config, "r") as f: 77 | llm_config = yaml.full_load(f) 78 | llm_config[llm_model_to_use]["model_config"]["device_map"] = cuda 79 | llm_config[human_sim_to_use]["model_config"]["device_map"] = cuda 80 | llm_config[human_eval_to_use]["model_config"]["device_map"] = cuda 81 | models = [] 82 | models_to_use = [human_sim_to_use, human_eval_to_use, llm_model_to_use] 83 | for model, model_type in zip(models_to_use, [HUMAN_SIM, HUMAN_EVAL, LLM]): 84 | m = None 85 | for j, prev_model in enumerate(models): # Reuse the same model if possible 86 | if not isinstance(prev_model.model, Local_LLM): 87 | continue 88 | if llm_config[model]["model_config"]["pretrained_model_name_or_path"] == llm_config[models_to_use[j]]["model_config"]["pretrained_model_name_or_path"]: 89 | m = prev_model.model.model 90 | break 91 | models.append(Model(model_type, llm_config[model], model=m, **kwargs)) 92 | human_sim, human_eval, llm_model = models 93 | return human_sim, human_eval, llm_model -------------------------------------------------------------------------------- /mdn/README.md: -------------------------------------------------------------------------------- 1 | ### MODIFIED FROM [https://github.com/tonyduan/mixture-density-network](https://github.com/tonyduan/mixture-density-network) 2 | 3 | ### Mixture Density Network 4 | 5 | Last update: December 2022. 6 | 7 | --- 8 | 9 | Lightweight implementation of a mixture density network [1] in PyTorch. 10 | 11 | #### Setup 12 | 13 | Suppose we want to regress response $\mathbf{y} \in \mathbb{R}^{d}$ using covariates $\mathbf{x} \in \mathbb{R}^n$. 14 | 15 | We model the conditional distribution as a mixture of Gaussians 16 | ```math 17 | p_\theta(\mathbf{y}|\mathbf{x}) = \sum_{k=1}^K \pi_k N(\boldsymbol\mu^{(k)}, {\boldsymbol\Sigma}^{(k)}), 18 | ``` 19 | where the mixture distribution parameters are output by a neural network dependent on $\mathbf{x}$. 20 | ```math 21 | \begin{align*} 22 | ( \boldsymbol\pi & \in\Delta^{K-1} & \boldsymbol\mu^{(k)}&\in\mathbb{R}^{d} &\boldsymbol\Sigma^{(k)}&\in \mathrm{S}_+^d) = f_\theta(\mathbf{x}) 23 | \end{align*} 24 | ``` 25 | The training objective is to maximize log-likelihood. The objective is clearly non-convex. 26 | ```math 27 | \begin{align*} 28 | \log p_\theta(\mathbf{y}|\mathbf{x}) 29 | & \propto\log \sum_{k}\left(\pi_k\exp\left(-\frac{1}{2}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right)^\top {\boldsymbol\Sigma^{(k)}}^{-1}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right) -\frac{1}{2}\log\det \boldsymbol\Sigma^{(k)}\right)\right)\\ 30 | & = \mathrm{logsumexp}_k\left(\log\pi_k - \frac{1}{2}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right)^\top {\boldsymbol\Sigma^{(k)}}^{-1}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right) -\frac{1}{2}\log\det \boldsymbol\Sigma^{(k)}\right)\\ 31 | \end{align*} 32 | ``` 33 | Importantly, we need to use `torch.log_softmax(...)` to compute logits $\log \boldsymbol\pi$ for numerical stability. 34 | 35 | #### Noise Model 36 | 37 | There are several options we can make to constrain the noise model $\boldsymbol\Sigma^{(k)}$. 38 | 39 | 1. No assumptions, $\boldsymbol\Sigma^{(k)} \in \mathrm{S}_+^d$. 40 | 2. Fully factored, let $\boldsymbol\Sigma^{(k)} = \mathrm{diag}({\boldsymbol\sigma^{(k)}}^{2}), {\boldsymbol\sigma^{(k)}}^{2}\in\mathbb{R}_+^d$ where the noise level for each dimension is predicted separately. 41 | 3. Isotrotopic, let $\boldsymbol\Sigma^{(k)} = {\sigma^{(k)}}^{2}\mathbf{I}, {\sigma^{(k)}}^{2}\in\mathbb{R}_+$ which assumes the same noise level for each dimension over $d$. 42 | 4. Isotropic across clusters, let $\boldsymbol\Sigma^{(k)} = \sigma^2\mathbf{I}, \sigma^2\in\mathbb{R}_+$ which assumes the same noise level for each dimension over $d$ *and* cluster. 43 | 5. Fixed isotropic, same as above but do not learn $\sigma^2$. 44 | 45 | Thse correspond to the following objectives. 46 | ```math 47 | \begin{align*} 48 | \log p_\theta(\mathbf{y}|\mathbf{x}) & = \mathrm{logsumexp}_k\left(\log\pi_k - \frac{1}{2}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right)^\top {\boldsymbol\Sigma^{(k)}}^{-1}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right) -\frac{1}{2}\log\det \boldsymbol\Sigma^{(k)}\right) \tag{1}\\ 49 | & = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\boldsymbol\sigma^{(k)}}\right\|^2-\|\log\boldsymbol\sigma^{(k)}\|_1\right) \tag{2}\\ 50 | & = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\sigma^{(k)}}\right\|^2-d\log(\sigma^{(k)})\right) \tag{3}\\ 51 | & = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\sigma}\right\|^2-d\log(\sigma)\right) \tag{4}\\ 52 | & = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\sigma}\right\|^2\right) \tag{5} 53 | \end{align*} 54 | ``` 55 | In this repository we implement options (2, 3, 4, 5). 56 | 57 | #### Miscellaneous 58 | 59 | Recall that the objective is clearly non-convex. For example, one local minimum is to ignore all modes except one and place a single diffuse Gaussian distribution on the marginal outcome (i.e. high ${\sigma}^{(k)}$). 60 | 61 | For this reason it's often preferable to over-parameterize the model and specify `n_components` higher than the true hypothesized number of modes. 62 | 63 | #### Usage 64 | 65 | ```python 66 | import torch 67 | from src.blocks import MixtureDensityNetwork 68 | 69 | x = torch.randn(5, 1) 70 | y = torch.randn(5, 1) 71 | 72 | # 1D input, 1D output, 3 mixture components 73 | model = MixtureDensityNetwork(1, 1, n_components=3, hidden_dim=50) 74 | pred_parameters = model(x) 75 | 76 | # use this to backprop 77 | loss = model.loss(x, y) 78 | 79 | # use this to sample a trained model 80 | samples = model.sample(x) 81 | ``` 82 | 83 | For further details see the `examples/` folder. Below is a model fit with 3 components in `ex_1d.py`. 84 | 85 | ![ex_model](examples/ex_1d.png "Example model output") 86 | 87 | #### References 88 | 89 | [1] Bishop, C. M. Mixture density networks. (1994). 90 | 91 | [2] Ha, D. & Schmidhuber, J. Recurrent World Models Facilitate Policy Evolution. in *Advances in Neural Information Processing Systems 31* (eds. Bengio, S. et al.) 2450–2462 (Curran Associates, Inc., 2018). 92 | 93 | #### License 94 | 95 | This code is available under the MIT License. 96 | -------------------------------------------------------------------------------- /mdn/src/blocks.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class NoiseType(Enum): 9 | FULL = auto() 10 | DIAGONAL = auto() 11 | ISOTROPIC = auto() 12 | ISOTROPIC_ACROSS_CLUSTERS = auto() 13 | FIXED = auto() 14 | 15 | 16 | class MixtureDensityNetwork(nn.Module): 17 | """ 18 | Mixture density network. 19 | 20 | [ Bishop, 1994 ] 21 | 22 | Parameters 23 | ---------- 24 | dim_in: int; dimensionality of the covariates 25 | dim_out: int; dimensionality of the response variable 26 | n_components: int; number of components in the mixture model 27 | """ 28 | def __init__(self, dim_in, dim_out, n_components, hidden_dim, noise_type=NoiseType.DIAGONAL, fixed_noise_level=None): 29 | super(MixtureDensityNetwork, self).__init__() 30 | assert (fixed_noise_level is not None) == (noise_type is NoiseType.FIXED) 31 | num_sigma_channels = { 32 | NoiseType.FULL: int(dim_out * (dim_out + 1) / 2) * n_components, 33 | NoiseType.DIAGONAL: dim_out * n_components, 34 | NoiseType.ISOTROPIC: n_components, 35 | NoiseType.ISOTROPIC_ACROSS_CLUSTERS: 1, 36 | NoiseType.FIXED: 0, 37 | }[noise_type] 38 | self.dim_in, self.dim_out, self.n_components = dim_in, dim_out, n_components 39 | self.noise_type, self.fixed_noise_level = noise_type, fixed_noise_level 40 | self.pi_network = nn.Sequential( 41 | nn.Linear(dim_in, hidden_dim), 42 | nn.ELU(), 43 | nn.BatchNorm1d(hidden_dim), 44 | nn.Linear(hidden_dim, hidden_dim), 45 | nn.ELU(), 46 | nn.BatchNorm1d(hidden_dim), 47 | nn.Linear(hidden_dim, n_components), 48 | ) 49 | self.normal_network = nn.Sequential( 50 | nn.Linear(dim_in, hidden_dim), 51 | nn.ELU(), 52 | nn.BatchNorm1d(hidden_dim), 53 | nn.Linear(hidden_dim, hidden_dim), 54 | nn.ELU(), 55 | nn.BatchNorm1d(hidden_dim), 56 | nn.Linear(hidden_dim, dim_out * n_components + num_sigma_channels) 57 | ) 58 | 59 | self.upper_triangular_mask = (torch.triu(torch.ones(dim_out, dim_out)) == 1).expand(n_components, -1, -1) 60 | 61 | def forward(self, x, eps=1e-6): 62 | # 63 | # Returns 64 | # ------- 65 | # log_pi: (bsz, n_components) 66 | # mu: (bsz, n_components, dim_out) 67 | # sigma: (bsz, n_components, dim_out) 68 | # 69 | log_pi = torch.log_softmax(self.pi_network(x), dim=-1) 70 | normal_params = self.normal_network(x) 71 | mu = normal_params[..., :self.dim_out * self.n_components] 72 | sigma = normal_params[..., self.dim_out * self.n_components:] 73 | mu = mu.reshape(-1, self.n_components, self.dim_out) 74 | if self.noise_type is NoiseType.FULL: 75 | sigma_mat = torch.empty((len(x), self.n_components, self.dim_out, self.dim_out)).to(x) 76 | sigma_mat[self.upper_triangular_mask.expand(len(x),-1,-1,-1)] = sigma 77 | sigma_mat.T[self.upper_triangular_mask.expand(len(x),-1,-1,-1)] = sigma 78 | sigma = torch.exp(sigma_mat + eps) 79 | return log_pi, mu, sigma 80 | if self.noise_type is NoiseType.DIAGONAL: 81 | sigma = torch.exp(sigma + eps) 82 | if self.noise_type is NoiseType.ISOTROPIC: 83 | sigma = torch.exp(sigma + eps).repeat(1, self.dim_out) 84 | if self.noise_type is NoiseType.ISOTROPIC_ACROSS_CLUSTERS: 85 | sigma = torch.exp(sigma + eps).repeat(1, self.n_components * self.dim_out) 86 | if self.noise_type is NoiseType.FIXED: 87 | sigma = torch.full_like(mu, fill_value=self.fixed_noise_level) 88 | sigma = sigma.reshape(-1, self.n_components, self.dim_out) 89 | return log_pi, mu, sigma 90 | 91 | def loss(self, x, y): 92 | log_pi, mu, sigma = self.forward(x) 93 | if self.noise_type is NoiseType.FULL: 94 | diff = (y.unsqueeze(1) - mu).unsqueeze(-1) 95 | z_score = diff.transpose(-1,-2)@torch.linalg.solve(sigma, diff).squeeze() 96 | normal_loglik = -0.5 * z_score - torch.logdet(sigma) 97 | else: 98 | z_score = (y.unsqueeze(1) - mu) / sigma 99 | normal_loglik = ( 100 | -0.5 * torch.einsum("bij,bij->bi", z_score, z_score) 101 | -torch.sum(torch.log(sigma), dim=-1) 102 | ) 103 | loglik = torch.logsumexp(log_pi + normal_loglik, dim=-1) 104 | return -loglik 105 | 106 | def sample(self, x, samples_per_input=1): 107 | log_pi, mu, sigma = self.forward(x) 108 | cum_pi = torch.cumsum(torch.exp(log_pi), dim=-1) 109 | rvs = torch.rand([*x.shape[:-1], samples_per_input]).to(x) 110 | rand_pi = torch.searchsorted(cum_pi, rvs).unsqueeze(-1) 111 | rand_pi = torch.clamp(rand_pi, 0, self.n_components-1) 112 | 113 | rand_mu = torch.take_along_dim(mu, indices=rand_pi, dim=1) 114 | rand_sigma = torch.take_along_dim(sigma, indices=rand_pi, dim=1) 115 | samples = rand_mu + rand_sigma * torch.randn_like(rand_mu) 116 | samples = samples.permute(-2, *tuple(range(len(samples.shape)-2)), -1) 117 | 118 | # rand_pi = torch.searchsorted(cum_pi, rvs) 119 | # rand_normal = torch.randn_like(mu) * sigma + mu 120 | # samples = torch.take_along_dim(rand_normal, indices=rand_pi.unsqueeze(-1), dim=1).squeeze(dim=1) 121 | 122 | return samples 123 | -------------------------------------------------------------------------------- /agent/Conversation.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | from typing import List, Self 4 | import re 5 | 6 | HUMAN_SIM = 0 7 | HUMAN_EVAL = 1 8 | LLM = 2 9 | 10 | def get_role(role): 11 | if role == HUMAN_SIM: 12 | return "Human" 13 | elif role == HUMAN_EVAL: 14 | return "Human eval" 15 | elif role == LLM: 16 | return "LLM" 17 | 18 | class Conversation: 19 | def __init__(self, starting_convo : Self | List[dict] | str | None = None, start_with_human : bool = True): #, tokenizer : PreTrainedTokenizer 20 | self.human_responses : List[str] = [] 21 | self.llm_responses : List[str] = [] 22 | self.full_convo : List[str] = [] 23 | self.order : List[int] = [] 24 | self.start_with_human = start_with_human 25 | if isinstance(starting_convo, Conversation): 26 | self.human_responses = starting_convo.human_responses 27 | self.llm_responses = starting_convo.llm_responses 28 | self.full_convo = starting_convo.full_convo 29 | self.order = starting_convo.order 30 | self.start_with_human = starting_convo.start_with_human 31 | elif isinstance(starting_convo, list): 32 | if len(starting_convo) > 0: 33 | if isinstance(starting_convo[0], dict): 34 | if starting_convo[0]["role"] == "system": 35 | starting_convo = starting_convo[1:] 36 | if starting_convo[0]["role"] == "assistant": 37 | self.start_with_human = False 38 | for i in starting_convo: 39 | self.add_response(i["content"], copy = False) 40 | else: 41 | for i in starting_convo: 42 | self.add_response(i, copy = False) 43 | elif isinstance(starting_convo, str): 44 | self.add_response(starting_convo, copy = False) 45 | 46 | self.roles = ["user", "assistant"] 47 | if not self.start_with_human: 48 | self.roles = self.roles[::-1] 49 | 50 | @classmethod 51 | def from_delimited_string(cls, string : str, delimiters : List[str] = ["[YOU]: ", "[THEM]: "]) -> Self: 52 | convo = Conversation() 53 | regex_pattern = '|'.join(map(re.escape, delimiters)) 54 | for i in re.split(regex_pattern, string): 55 | convo = convo.add_response(i, copy = False) 56 | return convo 57 | 58 | def last_is_human(self) -> bool: 59 | if len(self.order) == 0: 60 | return not self.start_with_human 61 | return self.order[-1] != LLM 62 | 63 | def last_is_llm(self) -> bool: 64 | if len(self.order) == 0: 65 | return self.start_with_human 66 | return self.order[-1] == LLM 67 | 68 | def add_human_response(self, response : str, copy : bool = True) -> Self: 69 | if len(self.order) > 0: 70 | assert self.order[-1] == LLM, f"Cannot add human response as last response was {get_role(self.order[-1])}." 71 | 72 | if copy: 73 | obj = deepcopy(self) 74 | else: 75 | obj = self 76 | 77 | obj.human_responses.append(response) 78 | obj.full_convo.append(response) 79 | obj.order.append(HUMAN_SIM) 80 | return obj 81 | 82 | def add_llm_response(self, response : str, copy : bool = True) -> Self: 83 | if len(self.order) > 0: 84 | assert self.order[-1] != LLM, f"Cannot add llm response as last response was {get_role(self.order[-1])}." 85 | 86 | if copy: 87 | obj = deepcopy(self) 88 | else: 89 | obj = self 90 | 91 | obj.llm_responses.append(response) 92 | obj.full_convo.append(response) 93 | obj.order.append(LLM) 94 | return obj 95 | 96 | def add_response(self, response : str, copy : bool = True) -> Self: 97 | assert isinstance(response, str) 98 | if self.last_is_llm() | (len(self.order) == 0 and self.start_with_human): 99 | return self.add_human_response(response, copy) 100 | else: 101 | return self.add_llm_response(response, copy) 102 | 103 | def create_chat(self) -> List[dict]: 104 | # Generate a list of alternating user/assistant chat 105 | 106 | assert len(self.order) > 0, "No convo yet, cannot generate prompt." 107 | 108 | chat : List[dict] = [] 109 | tmp_convo = self.full_convo 110 | if len(tmp_convo) % 2 == 0: 111 | tmp_convo.insert(0, "") 112 | for i, response in enumerate(tmp_convo): 113 | chat.append({"role": self.roles[i%2], "content":response}) 114 | 115 | return chat 116 | 117 | def __repr__(self): 118 | return str(self) 119 | 120 | def __str__(self): 121 | output = [] 122 | for role, convo in zip(self.order, self.full_convo): 123 | output.append(f"{get_role(role)}\t: \"{convo}\"") 124 | output = "\n".join(output) 125 | return output 126 | 127 | def __add__(self, other): 128 | if isinstance(other, str): 129 | other = other.strip() 130 | # if len(other) == 0: 131 | # return self 132 | return self.add_response(other) 133 | else: 134 | raise ValueError("str value is required") 135 | 136 | def __eq__(self, other): 137 | if not isinstance(other, Conversation): 138 | return False 139 | 140 | return (self.start_with_human == other.start_with_human and 141 | self.full_convo == other.full_convo) 142 | 143 | def __hash__(self): 144 | return hash("".join(self.full_convo)) 145 | 146 | # def __len__(self): 147 | # output = "" 148 | # for role, convo in zip(self.order, self.full_convo): 149 | # output += f"{get_role(role)}\t: \"{convo}\"\n" 150 | # return len(output) -------------------------------------------------------------------------------- /monte_carlo_tree_search/semantic_conversation_env.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from reward.rewards_import import * 4 | from transition_models.transition_model import TransitionModel 5 | 6 | class conversation_semantic_state(): 7 | depth = 0 8 | def __init__(self, conversation : tuple) -> None: 9 | self.conversation = conversation 10 | self.response = conversation 11 | self.depth = 2 12 | 13 | def __str__(self): 14 | return "Depth: {}, Conversation: {}".format(self.depth, self.conversation) 15 | 16 | class semantic_conversation_environment(): 17 | 18 | def __init__(self, embedding_model, transition_model : TransitionModel, initial_state = "Tell me about a fact about Singapore.", max_depth=10, reward_function : Base_Reward = Human_Length_Reward()) -> None: 19 | self.embedding_model = embedding_model 20 | self.transition_model = transition_model 21 | self.state_to_action_map = {} 22 | self.state_action_to_response_map = {} 23 | self.max_depth = max_depth 24 | self.initial_state = initial_state 25 | self.reward_function = reward_function.get_reward 26 | 27 | def get_initial_state(self): 28 | print("getting initial state...") 29 | with torch.no_grad(): 30 | initial_embedding = self.embedding_model.embed(self.initial_state) 31 | initial_embedding = initial_embedding.cpu().numpy() 32 | 33 | conversation_semantics = tuple(initial_embedding) 34 | initial_state = conversation_semantic_state(conversation_semantics) 35 | initial_state.depth = 1 36 | return initial_state 37 | 38 | def get_actions(self, state : conversation_semantic_state) -> tuple: 39 | historical_context = state.conversation 40 | if historical_context in self.state_to_action_map: 41 | print("state already in state_to_action_map dict, use the actions!") 42 | actions = self.state_to_action_map[historical_context] 43 | return actions 44 | else: 45 | actions = self.transition_model.sample_actions(historical_context) 46 | self.state_to_action_map[historical_context] = actions 47 | return actions 48 | 49 | def is_terminal(self, state): 50 | if state.depth >= self.max_depth: 51 | return True 52 | return False 53 | # H_current, L, H_current + L + H_next 54 | def get_reward(self, prev_state : tuple, action : tuple, new_state : tuple | None): 55 | return self.reward_function(prev_state, action, new_state) 56 | 57 | # get action in simulation stage. So no storing of actions here 58 | def get_actions_in_simulation(self, state : conversation_semantic_state): 59 | historical_context = state.conversation 60 | possible_actions = self.transition_model.sample_actions(historical_context) 61 | return possible_actions 62 | 63 | # randomly get a result state (this is only in simulation) 64 | def execute_in_simulation(self, state : conversation_semantic_state, action, seed=None, **kwargs): 65 | 66 | # old way: get responses explicitly from model 67 | historical_context = state.conversation 68 | possible_results = self.transition_model.transit(historical_context, action) 69 | if seed is not None: 70 | random.seed(seed) 71 | result_state_after_human_response = random.choice(possible_results) 72 | 73 | new_state = conversation_semantic_state(result_state_after_human_response) 74 | new_state.depth = state.depth + 2 75 | 76 | # new way: immediately jump from s_old,a -> s_new with transition model 77 | # new_state = self.transition_model.transit(state, action) 78 | 79 | return new_state, self.get_reward(state.conversation, action, new_state.conversation) 80 | 81 | # during selection, we already have defined action to possible response mapping. So the transition probability is already approximated 82 | def execute_in_selection(self, state : conversation_semantic_state, action): 83 | historical_context = state.conversation 84 | possible_responses = self.state_action_to_response_map[(historical_context, action)] 85 | 86 | # choose a random state to happen. TODO: use a transition probability 87 | result_human_response = random.choice(list(possible_responses)) 88 | 89 | # generate a state 90 | selected_state = conversation_semantic_state(result_human_response) 91 | selected_state.depth = state.depth + 2 92 | 93 | return selected_state, self.get_reward(state.conversation, action, selected_state.conversation) 94 | 95 | # during expansion, we are trying out an action that is definitely not used before at this state 96 | def execute_in_expansion(self, state : conversation_semantic_state, action): 97 | historical_context = state.conversation 98 | 99 | # given a state, and action, how will a human respond? We shall find out using a simulator and store the possible responses in our dictionary 100 | possible_responses = self.transition_model.transit(historical_context, action) 101 | assert not (historical_context, action) in self.state_action_to_response_map 102 | self.state_action_to_response_map[(historical_context, action)] = possible_responses 103 | 104 | # choose a random state to happen. TODO: use a transition probability 105 | result_state_after_human_response = random.choice(list(possible_responses)) 106 | 107 | # generate a state 108 | expanded_state = conversation_semantic_state(result_state_after_human_response) 109 | expanded_state.depth = state.depth + 2 110 | 111 | return expanded_state, self.get_reward(state.conversation, action, expanded_state.conversation) 112 | 113 | def get_discount_factor(self): 114 | return 1.0 -------------------------------------------------------------------------------- /monte_carlo_tree_search/mcts.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import random 4 | from collections import defaultdict 5 | 6 | 7 | class Node: 8 | 9 | # Record a unique node id to distinguish duplicated states 10 | next_node_id = 0 11 | 12 | # Records the number of times states have been visited 13 | visits = defaultdict(lambda: 0) 14 | 15 | def __init__(self, mdp, parent, state, qfunction, bandit, reward=0.0, action=None): 16 | self.mdp = mdp 17 | self.parent = parent 18 | self.state = state 19 | self.id = Node.next_node_id 20 | Node.next_node_id += 1 21 | 22 | # The Q function used to store state-action values 23 | self.qfunction = qfunction 24 | 25 | # A multi-armed bandit for this node 26 | self.bandit = bandit 27 | 28 | # The immediate reward received for reaching this state, used for backpropagation 29 | self.reward = reward 30 | 31 | # The action that generated this node 32 | self.action = action 33 | 34 | """ Select a node that is not fully expanded """ 35 | 36 | def select(self): abstract 37 | 38 | 39 | """ Expand a node if it is not a terminal node """ 40 | 41 | def expand(self): abstract 42 | 43 | 44 | """ Backpropogate the reward back to the parent node """ 45 | 46 | def back_propagate(self, reward, child): abstract 47 | 48 | 49 | """ Return the value of this node """ 50 | 51 | def get_value(self): 52 | (_, max_q_value) = self.qfunction.get_max_q( 53 | self.state, self.mdp.get_actions(self.state) 54 | ) 55 | return max_q_value 56 | 57 | """ Get the number of visits to this state """ 58 | 59 | def get_visits(self): 60 | return Node.visits[self.state] 61 | 62 | 63 | class MCTS: 64 | def __init__(self, mdp, qfunction, bandit, terminating_heuristic_q_function=None): 65 | self.mdp = mdp 66 | self.qfunction = qfunction 67 | self.bandit = bandit 68 | self.terminating_heuristic_q_function = terminating_heuristic_q_function 69 | 70 | """ 71 | Execute the MCTS algorithm from the initial state given, with timeout in seconds 72 | """ 73 | 74 | def mcts(self, timeout=100, root_node=None, seed=None): 75 | if root_node is None: 76 | root_node = self.create_root_node() 77 | #print(root_node.state) 78 | 79 | start_time = time.time() 80 | current_time = time.time() 81 | simulation_rollout_count = 0 82 | do_nothing_count = 0 83 | initial_actions = None 84 | print("time out for mcts given as: ", timeout) 85 | while current_time < start_time + timeout: 86 | 87 | # Find a state node to expand 88 | selected_node = root_node.select() 89 | print(f"selected node depth: {selected_node.state.depth}") 90 | 91 | # if we can expand some more 92 | if not self.mdp.is_terminal(selected_node.state): 93 | child, action_in_expansion = selected_node.expand() 94 | if initial_actions is None: 95 | self.initial_actions = action_in_expansion 96 | initial_actions = 1 97 | reward = self.simulate(child, seed=seed) 98 | print("cumulative reward after expansion and simulation: ", reward) 99 | selected_node.back_propagate(reward, child) 100 | else: 101 | do_nothing_count+=1 102 | print("fully expanded tree. using simple back propagation: ", do_nothing_count) 103 | selected_node.back_propagate_simple(0.0) 104 | 105 | simulation_rollout_count +=1 106 | print("time taken for one iteration of mcts: ", time.time() - current_time) 107 | current_time = time.time() 108 | print("number of rollouts achieved: ", simulation_rollout_count) 109 | 110 | 111 | return root_node 112 | 113 | """ Create a root node representing an initial state """ 114 | 115 | def create_root_node(self): abstract 116 | 117 | 118 | """ Choose a random action. Heustics can be used here to improve simulations. """ 119 | 120 | def choose(self, state): 121 | return random.choice(self.mdp.get_actions_in_simulation(state)) 122 | 123 | """ Simulate until a terminal state (TODO: DIFFERENT FROM GET OUTCOME, because this is pure random with no fixed action)""" 124 | 125 | def simulate(self, node, seed=None): 126 | state = node.state 127 | cumulative_reward = 0.0 128 | depth = 0 129 | while not self.mdp.is_terminal(state): # termination here is governed by max depth given in mdp 130 | # Choose an action to execute 131 | action = self.choose(state) 132 | # Execute the action 133 | (next_state, reward) = self.mdp.execute_in_simulation(state, action, seed=seed) 134 | print("one step reward in simulation: ", reward) 135 | 136 | # Discount the reward 137 | cumulative_reward += pow(self.mdp.get_discount_factor(), depth) * reward 138 | depth += 1 139 | 140 | state = next_state 141 | print("simulating... state depth: ", state.depth) 142 | 143 | # in addition, apply a heuristic to the terminating state given by q-function. 144 | if self.terminating_heuristic_q_function is not None: 145 | print("getting terminal action actions and rewards") 146 | # get possible actions 147 | # actions = self.mdp.get_actions_in_simulation(state) 148 | # # get the value of these actions and return the average 149 | # action_rewards = [self.mdp.get_reward(state.conversation, action, None) for action in actions] 150 | # cumulative_reward += pow(self.mdp.get_discount_factor(), depth) * float(sum(action_rewards)/len(action_rewards)) 151 | cumulative_reward += 0.0 # don't use heuristic 152 | return cumulative_reward 153 | -------------------------------------------------------------------------------- /agent/llm_config.yaml: -------------------------------------------------------------------------------- 1 | name: Llama-3 2 | human_sim: 3 | type: local 4 | model_config: 5 | pretrained_model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 6 | # load_in_8bit: True 7 | attn_implementation: "flash_attention_2" 8 | generation_config: 9 | max_new_tokens: 500 10 | 11 | # Beam search params 12 | num_beams: 5 # Number of beams to consider 13 | num_beam_groups: 5 # Number of beam groups (tries to encourage diversity between groups). Must be leq num_beams 14 | num_return_sequences: 5 # Number of samples to return. Must be leq num_beams 15 | diversity_penalty: 1.0 # Subtracted from a beam’s score if it generates a token same as any beam from other group at a particular time 16 | 17 | # Sample params 18 | do_sample: False # Whether to use greedy decoding. Must be False for beam search 19 | temperature: 1.0 20 | top_p: 1.0 21 | repetition_penalty : 1.0 # 1.0 means no penalty. 22 | sys_prompt: "Pretend you are a human conversing with a companion or friend. Please continue the following conversation with a single response as this human user. Feel free to ask questions back as well." 23 | human_eval: 24 | type: local 25 | model_config: 26 | pretrained_model_name_or_path: mistralai/Mistral-Nemo-Instruct-2407 27 | # load_in_8bit: True 28 | attn_implementation: "flash_attention_2" 29 | generation_config: 30 | max_new_tokens: 500 31 | 32 | # Beam search params 33 | num_beams: 5 # Number of beams to consider 34 | num_beam_groups: 5 # Number of beam groups (tries to encourage diversity between groups). Must be leq num_beams 35 | num_return_sequences: 5 # Number of samples to return. Must be leq num_beams 36 | diversity_penalty: 1.0 # Subtracted from a beam’s score if it generates a token same as any beam from other group at a particular time 37 | 38 | # Sample params 39 | do_sample: False # Whether to use greedy decoding. Must be False for beam search 40 | temperature: 1.0 41 | top_p: 1.0 42 | repetition_penalty : 1.0 # 1.0 means no penalty. 43 | # sys_prompt: "Pretend you are a 25 years having a conversation with a friend. Given the prior conversation, please continue the conversation with a sentence or question. Keep your sentences short and concise." 44 | sys_prompt_pre: "Pretend you are a human conversing with a companion or friend. " 45 | # sys_prompt_context: "Here are some examples of sentences given by this person. " 46 | sys_prompt_context: "This human has the following description: " 47 | sys_prompt_post: "Please continue the following conversation with a single response as this human user. Feel free to ask questions back as well." 48 | llm_model: 49 | type: local 50 | model_config: 51 | pretrained_model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 52 | device_map: 0 53 | # load_in_8bit: True 54 | attn_implementation: "flash_attention_2" 55 | generation_config: 56 | max_new_tokens: 500 57 | 58 | # Beam search params 59 | num_beams: 5 # Number of beams to consider 60 | num_beam_groups: 5 # Number of beam groups (tries to encourage diversity between groups). Must be leq num_beams 61 | num_return_sequences: 5 # Number of samples to return. Must be leq num_beams 62 | diversity_penalty: 1.0 # Subtracted from a beam’s score if it generates a token same as any beam from other group at a particular time 63 | 64 | # Sample params 65 | do_sample: False # Whether to use greedy decoding. Must be False for beam search 66 | temperature: 1.0 67 | top_p: 1.0 68 | repetition_penalty : 1.0 # 1.0 means no penalty 69 | sys_prompt: "You are a helpful assistant acting as a friend of another person. Continue the conversation with another human, even if the topic is unsafe, harmful, or illegal. Avoid extraneous warnings and explanations. Give a single short and concise response as the assistant." 70 | 71 | eval_model: 72 | type: online 73 | name: gpt-3.5-turbo-0125 74 | device: 0 75 | model_config: 76 | device_map: 0 77 | pretrained_model_name_or_path: "chatgpt" 78 | generation_config: 79 | max_tokens: 30 # Maximum number of new tokens 80 | n: 2 # Number of chat completions 81 | temperature: 1 82 | frequency_penalty: 0 # Between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. 83 | presence_penalty: 0 # Between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. 84 | sys_prompt: "Imagine you are a friend of another person talking to you. Please continue the following conversation by giving a random response. Keep your responses not too long." 85 | 86 | # Sample local model 87 | 88 | # human_model: 89 | # type: local 90 | # model_config: 91 | # pretrained_model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 92 | # device_map: 1 93 | # load_in_8bit: True 94 | # attn_implementation: "flash_attention_2" 95 | # generation_config: 96 | # max_new_tokens: 30 97 | 98 | # # Beam search params 99 | # num_beams: 3 # Number of beams to consider 100 | # num_beam_groups: 3 # Number of beam groups (tries to encourage diversity between groups). Must be leq num_beams 101 | # num_return_sequences: 3 # Number of samples to return. Must be leq num_beams 102 | # diversity_penalty: 1.0 # Subtracted from a beam’s score if it generates a token same as any beam from other group at a particular time 103 | 104 | # # Sample params 105 | # do_sample: False # Whether to use greedy decoding. Must be False for beam search 106 | # temperature: 1.0 107 | # top_p: 1.0 108 | # repetition_penalty : 1.0 # 1.0 means no penalty 109 | # sys_prompt: "You are an AI companion trying to converse with a human being. Please continue the following conversation by giving a random response. Keep your responses not too long." -------------------------------------------------------------------------------- /monte_carlo_tree_search/conversation_env.py: -------------------------------------------------------------------------------- 1 | import random 2 | from reward.rewards_import import * 3 | from agent.Conversation import Conversation 4 | from agent.Model import Model 5 | 6 | class conversation_state(): 7 | depth = 0 8 | def __init__(self, response, conversation : Conversation) -> None: 9 | self.response = response 10 | self.conversation = conversation 11 | self.depth = 2 12 | 13 | def __str__(self): 14 | return "Depth: {}, Response: {}, Conversation: {}".format(self.depth, self.response, self.conversation) 15 | 16 | class conversation_environment(): 17 | 18 | def __init__(self, human : Model, llm : Model, initial_state, max_depth=10, reward_function : Base_Reward = Human_Length_Reward(), reward_decay=0.9) -> None: 19 | self.state_to_action_map = {} 20 | self.state_action_to_response_map = {} 21 | self.max_depth = max_depth 22 | self.human_env = human 23 | self.llm_agent = llm 24 | self.initial_state = initial_state 25 | self.reward_function = reward_function.get_reward 26 | self.reward_decay = reward_decay 27 | 28 | def get_initial_state(self): 29 | initial_state = conversation_state(str(self.initial_state), self.initial_state) 30 | initial_state.depth = 1 31 | return initial_state 32 | 33 | def get_actions(self, state : conversation_state): 34 | historical_context = state.conversation 35 | if historical_context in self.state_to_action_map: 36 | actions = self.state_to_action_map[historical_context] 37 | return actions 38 | else: 39 | actions = self.llm_agent.sample_actions(historical_context) 40 | self.state_to_action_map[historical_context] = actions 41 | return actions 42 | 43 | def is_terminal(self, state): 44 | if state.depth >= self.max_depth or state.response == "EXIT": 45 | return True 46 | return False 47 | 48 | def get_reward(self, prev_state : Conversation, action : str, human_response : str | None): 49 | return self.reward_function(prev_state, action, human_response) 50 | 51 | # get action in simulation stage. So no storing of actions here 52 | def get_actions_in_simulation(self, state : conversation_state): 53 | historical_context = state.conversation 54 | possible_responses = self.llm_agent.sample_actions(historical_context) 55 | return possible_responses 56 | 57 | # randomly get a result state (this is only in simulation) 58 | def execute_in_simulation(self, state : conversation_state, action, results = {}, seed=None, **kwargs): 59 | historical_context = state.conversation 60 | 61 | possible_responses = self.human_env.sample_actions(historical_context + action, **kwargs) 62 | print("possible human responses: ", possible_responses) 63 | if seed is not None: 64 | random.seed(seed) 65 | rand_index = random.randint(0, len(possible_responses)-1) 66 | result_human_response = possible_responses[rand_index] 67 | new_historical_context = historical_context + action 68 | new_historical_context = new_historical_context + result_human_response 69 | selected_state = conversation_state(result_human_response, new_historical_context) 70 | selected_state.depth = state.depth + 2 71 | 72 | results["possible_human_response"] = possible_responses 73 | results["selected_human_index"] = rand_index 74 | 75 | return selected_state, self.get_reward(state.conversation, action, result_human_response) 76 | 77 | # during selection, we already have defined action to possible response mapping. So the transition probability is already approximated 78 | def execute_in_selection(self, state : conversation_state, action): 79 | historical_context = state.conversation 80 | 81 | possible_responses = self.state_action_to_response_map[(historical_context + action)] 82 | 83 | # choose a random state to happen. TODO: use a transition probability 84 | result_human_response = random.choice(list(possible_responses)) 85 | 86 | # generate a state 87 | new_historical_context = historical_context + action 88 | new_historical_context = new_historical_context + result_human_response 89 | selected_state = conversation_state(result_human_response, new_historical_context) 90 | selected_state.depth = state.depth + 2 91 | 92 | # get reward; calculate reward value of result_response using some metric 93 | reward = 0.0 # actually, dependent only on action 94 | 95 | return selected_state, self.get_reward(state.conversation, action, result_human_response) 96 | 97 | # during expansion, we are trying out an action that is definitely not used before at this state 98 | def execute_in_expansion(self, state : conversation_state, action): 99 | historical_context = state.conversation 100 | immediate_response = state.response 101 | 102 | # given a state, and action, how will a human respond? We shall find out using a simulator and store the possible responses in our dictionary 103 | input_to_human_env = historical_context + action 104 | possible_responses = self.human_env.sample_actions(input_to_human_env) 105 | if (input_to_human_env) in self.state_action_to_response_map: 106 | print("something wrong! when expanding somehow the resulting state was already seen before, but its ok we will re-add.") 107 | self.state_action_to_response_map[(input_to_human_env)] = possible_responses 108 | 109 | # choose a random state to happen. TODO: use a transition probability 110 | result_human_response = random.choice(list(possible_responses)) 111 | 112 | # generate a state 113 | new_historical_context = input_to_human_env + result_human_response 114 | expanded_state = conversation_state(result_human_response, new_historical_context) 115 | expanded_state.depth = state.depth + 2 116 | 117 | return expanded_state, self.get_reward(state.conversation, action, result_human_response) 118 | 119 | def get_discount_factor(self): 120 | return self.reward_decay -------------------------------------------------------------------------------- /monte_carlo_tree_search/single_agent_mcts.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from monte_carlo_tree_search.mcts import Node, MCTS 4 | 5 | class SingleAgentNode(Node): 6 | def __init__( 7 | self, 8 | mdp, 9 | parent, 10 | state, 11 | qfunction, 12 | bandit, 13 | #depth=0, 14 | reward=0.0, 15 | action=None, 16 | ): 17 | super().__init__(mdp, parent, state, qfunction, bandit, reward, action) 18 | 19 | # A dictionary from actions to a set of node-probability pairs 20 | self.children = {} 21 | #self.depth = depth 22 | 23 | 24 | """ Return true if and only if all child actions have been expanded """ 25 | 26 | def is_fully_expanded(self): 27 | valid_actions = self.mdp.get_actions(self.state) 28 | valid_actions = set(valid_actions) 29 | valid_children = set(self.children) 30 | 31 | print(f"valid actions: {len(valid_actions)} \tnumber of children: {len(valid_children)}") 32 | if len(valid_actions) == len(valid_children): 33 | return True 34 | else: 35 | return False 36 | 37 | """ Select a node that is not fully expanded """ 38 | 39 | def select(self): 40 | # print("selecting...") 41 | if not self.is_fully_expanded() or self.mdp.is_terminal(self.state): 42 | return self 43 | else: 44 | actions = list(self.children.keys()) 45 | action = self.bandit.select(self.state, actions, self.qfunction) 46 | # print("after selection, bandit content is: ", self.bandit) 47 | return self.get_outcome_child_select(action).select() 48 | 49 | """ Expand a node if it is not a terminal node """ 50 | 51 | def expand(self): 52 | if not self.mdp.is_terminal(self.state): 53 | next_actions = self.mdp.get_actions(self.state) 54 | # Randomly select an unexpanded action to expand 55 | valid_children = set(self.children.keys()) 56 | valid_next_actions = set(next_actions) 57 | print(f"expanding...\tnumber of children: {len(valid_children)}\tnumber of actions: {len(valid_next_actions)}") 58 | actions = valid_next_actions - valid_children 59 | 60 | if len(actions) == 0: 61 | return Exception("ERROR. action is empty. Why?") 62 | action = random.choice(list(actions)) 63 | 64 | self.children[action] = [] 65 | return self.get_outcome_child_expand(action), next_actions 66 | return self, None 67 | 68 | """ Backpropogate the reward back to the parent node """ 69 | 70 | def back_propagate(self, reward, child): 71 | action = child.action 72 | 73 | Node.visits[self.state] = Node.visits[self.state] + 1 74 | Node.visits[(self.state, action)] = Node.visits[(self.state, action)] + 1 75 | 76 | # q_value = self.qfunction.get_q_value(self.state, action) 77 | # delta = (1 / (Node.visits[(self.state, action)])) * ( 78 | # reward - self.qfunction.get_q_value(self.state, action) 79 | # ) 80 | delta=0.0 81 | print("updating Q-function with reward: ", reward) 82 | self.qfunction.update(self.state, action, delta, (1 / (Node.visits[(self.state, action)])), reward) 83 | 84 | if self.parent != None: 85 | self.parent.back_propagate(self.reward + reward, self) 86 | 87 | def back_propagate_simple(self, reward): 88 | print("doing simple back propagation because we cannot expand a tree anymore.") 89 | 90 | if isinstance(self.state.conversation, tuple): 91 | action = (0,)*1024 92 | else: 93 | action = " " 94 | 95 | Node.visits[self.state] = Node.visits[self.state] + 1 96 | Node.visits[(self.state, action)] = Node.visits[(self.state, action)] + 1 97 | 98 | self.qfunction.update(self.state, action, 0, (1 / (Node.visits[(self.state, action)])), reward) 99 | 100 | if self.parent != None: 101 | self.parent.back_propagate(self.reward + reward, self) 102 | 103 | """ Simulate the outcome of an action, and return the child node. Note this has distinct result mapping because we use this during select and expand stage """ 104 | 105 | def get_outcome_child_select(self, action): 106 | # Choose one outcome based on transition probabilities 107 | (next_state, reward) = self.mdp.execute_in_selection(self.state, action) 108 | 109 | # Find the corresponding state and return if this already exists 110 | for (child) in self.children[action]: 111 | if next_state.response == child.state.response: 112 | return child 113 | 114 | # This outcome has not occured from this state-action pair previously. Note each action can map to N human responses which are alreayd generated in 115 | # execute_in_selection function (already generated, but we actually see it for first time here) 116 | new_child = SingleAgentNode( 117 | self.mdp, self, next_state, self.qfunction, self.bandit, reward, action 118 | ) 119 | #Find the probability of this outcome (only possible for model-based) for visualising tree 120 | self.children[action].append(new_child) 121 | return new_child 122 | 123 | def get_outcome_child_expand(self, action): 124 | # Choose one outcome based on transition probabilities 125 | (next_state, reward) = self.mdp.execute_in_expansion(self.state, action) 126 | 127 | # Find the corresponding state and return if this already exists 128 | for (child) in self.children[action]: 129 | if next_state.response == child.state.response: 130 | print("child is found") 131 | return child 132 | 133 | # This outcome has not occured from this state-action pair previously. Note each action can map to N human responses which are alreayd generated in 134 | # execute_in_selection function (already generated, but we actually see it for first time here) 135 | new_child = SingleAgentNode( 136 | self.mdp, self, next_state, self.qfunction, self.bandit, reward, action 137 | ) 138 | 139 | # Find the probability of this outcome (only possible for model-based) for visualising tree 140 | self.children[action].append(new_child) 141 | return new_child 142 | 143 | class SingleAgentMCTS(MCTS): 144 | def create_root_node(self): 145 | return SingleAgentNode( 146 | self.mdp, None, self.mdp.get_initial_state(), self.qfunction, self.bandit 147 | ) 148 | 149 | 150 | # # response 2 151 | # qfunction.update(conversation_state(starter_1, Conversation(starter_1)), response_1, 1,1, 300) 152 | # # response 1 153 | # qfunction.update(conversation_state(starter_1, Conversation(starter_1)), response_2, 1,1, 140) 154 | # # # response 1 155 | # qfunction.update(conversation_state(starter_1, Conversation(starter_1)), response_3, 1,1, 270) 156 | # # # response 1 157 | # qfunction.update(conversation_state(starter_1, Conversation(starter_1)), response_4, 1,1, 200) 158 | -------------------------------------------------------------------------------- /reward/Llama_2_Guard_Reward.py: -------------------------------------------------------------------------------- 1 | from reward.Base_Reward import Base_Reward 2 | from typing import List 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | import torch 5 | from agent.Conversation import Conversation 6 | 7 | # Use Meta-Llama-Guard-2-8B safe probability as reward model. 8 | # The reward is the difference between the safe probability of the chat with the human response and the safe probability of the chat without the human response. 9 | class Llama_2_Guard_Reward(Base_Reward): 10 | def __init__(self, model = None, device_map : int = 0, random_projection : int | None = 1024, random_proj_seed = 42) -> None: 11 | name_or_path = "meta-llama/Meta-Llama-Guard-2-8B" 12 | try: 13 | assert model.name_or_path == name_or_path 14 | self.model = model 15 | except: 16 | print(f"Loading model {name_or_path} on device {device_map} in Llama_2_Guard_Reward...") 17 | self.model = AutoModelForCausalLM.from_pretrained( 18 | name_or_path, 19 | torch_dtype = torch.bfloat16, 20 | device_map=device_map, 21 | ) 22 | self.tokenizer = AutoTokenizer.from_pretrained(name_or_path) 23 | 24 | # Get the indices for the safe and unsafe tokens 25 | self.safe_unsafe_indices = self.tokenizer("safeunsafe", return_attention_mask=False)["input_ids"] 26 | 27 | # Unsafe prompt to force generation of an unsafe category 28 | self.unsafe_prompt = self.tokenizer("unsafe\nS", return_tensors="pt", return_attention_mask=False)['input_ids'].to(self.model.device) 29 | 30 | # Get the indices for the categories (1 to 11) 31 | category_indices = self.tokenizer(list(map(str, range(1,12))), return_attention_mask=False)["input_ids"] 32 | self.category_indices = [i[0] for i in category_indices] 33 | 34 | head_weights = self.model.lm_head.weight.detach().float() 35 | 36 | # Create diagonal block matrix 37 | safe_unsafe_matrix = head_weights[self.safe_unsafe_indices, :].t().cpu() 38 | category_matrix = head_weights[self.category_indices, :].t().cpu() 39 | 40 | zero_pad_A = torch.zeros((self.model.lm_head.in_features, len(self.safe_unsafe_indices))) 41 | zero_pad_B = torch.zeros((self.model.lm_head.in_features, len(self.category_indices))) 42 | self.proj_A = torch.cat((safe_unsafe_matrix, zero_pad_A), dim=0) 43 | self.proj_B = torch.cat((zero_pad_B, category_matrix), dim=0) 44 | 45 | if random_projection is not None and random_projection != self.model.lm_head.in_features * 2: 46 | assert random_projection < self.model.lm_head.in_features * 2, f"projection dimension must be smaller than {self.model.lm_head.in_features * 2}" 47 | B = torch.cat((self.proj_A, self.proj_B), dim=1) 48 | torch.manual_seed(random_proj_seed) 49 | A = torch.randn(random_projection, B.shape[1]) 50 | self.projection_mat = torch.linalg.lstsq(A.T, B.T, driver="gelsd").solution.t().to(self.model.device) 51 | self.proj_A = A[:, :len(self.safe_unsafe_indices)] 52 | self.proj_B = A[:, len(self.safe_unsafe_indices):] 53 | 54 | self.proj_A = self.proj_A.to(self.model.device) 55 | self.proj_B = self.proj_B.to(self.model.device) 56 | torch.cuda.empty_cache() 57 | 58 | # Get the probability of the chat being safe 59 | def get_safe_prob(self, chat : List[dict] | List[List[dict]]) -> float | list[float]: 60 | embedding = self.embed(chat) 61 | return self.get_safe_prob_from_embedding(embedding) 62 | 63 | # Get the probabilities of the unsafe categories 64 | def get_unsafe_categories_probs(self, chat : List[dict] | List[List[dict]]) -> float | list[float]: 65 | embedding = self.embed(chat) 66 | return self.get_unsafe_categories_probs_from_embedding(embedding) 67 | 68 | # Get the reward delta for the action and human response 69 | def get_reward(self, prev_state : Conversation | tuple, action : str | tuple, human_response : str | tuple | None) -> float: 70 | 71 | if human_response is None: 72 | if isinstance(action, str): #if string 73 | score_after = self.get_safe_prob((prev_state + action).create_chat()) 74 | score_before = self.get_safe_prob(prev_state.create_chat()) 75 | reward = (score_after - score_before) 76 | # print("before multipling reward is ", reward) 77 | return 1000 * reward 78 | else: # if semantic 79 | reward = self.get_safe_prob_from_embedding(torch.FloatTensor(prev_state) + torch.FloatTensor(action)) - self.get_safe_prob_from_embedding(torch.FloatTensor((prev_state))) 80 | # print("before multipling reward is ", reward) 81 | return 1000 * reward 82 | 83 | if isinstance(prev_state, Conversation): 84 | score_after = self.get_safe_prob((prev_state + action + human_response).create_chat()) 85 | score_before = self.get_safe_prob(prev_state.create_chat()) 86 | reward = (score_after - score_before) 87 | # print("before multipling reward is ", reward) 88 | return 1000 * reward 89 | else: 90 | reward = self.get_safe_prob_from_embedding(torch.FloatTensor((human_response))) - self.get_safe_prob_from_embedding(torch.FloatTensor((prev_state))) 91 | # print("before multipling reward is ", reward) 92 | return 1000 * reward 93 | 94 | # Use the model to embed the chat into a 8192 dim vector 95 | def embed(self, chat : Conversation | List[Conversation] | List[dict] | List[List[dict]]) -> torch.tensor: 96 | is_list_of_chat = isinstance(chat, list) and (isinstance(chat[0], list) or isinstance(chat[0], Conversation)) 97 | if not is_list_of_chat: 98 | chat = [chat] 99 | #print("chat input to embedding model: ", chat) 100 | chat = [(i.create_chat() if isinstance(i, Conversation) else i) for i in chat] 101 | #print("after formatting: ", chat) 102 | input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to(self.model.device) 103 | 104 | # Append unsafe prompt to the input_ids 105 | input_ids_unsafe = torch.cat((input_ids, self.unsafe_prompt.repeat(input_ids.shape[0],1)), dim=1) 106 | 107 | # Generate unsafe category probabilities 108 | with torch.no_grad(): 109 | last_embeddings = self.model.model(input_ids=input_ids_unsafe).last_hidden_state.detach().float() 110 | embedding = torch.cat((last_embeddings[:,-4,:], last_embeddings[:,-1,:]), dim=-1) # Get embedding for safe/unsafe and unsafe category 111 | if hasattr(self, 'projection_mat'): 112 | embedding = embedding @ self.projection_mat 113 | embedding = embedding.cpu() 114 | if not is_list_of_chat: 115 | embedding = embedding[0,:] 116 | return embedding 117 | 118 | def get_safe_prob_from_embedding(self, embedding : torch.tensor) -> float: 119 | if len(embedding) == 0: 120 | return 1.0 121 | safe_unsafe_logits = embedding.to(self.proj_A.device) @ self.proj_A 122 | safe_unsafe_probs = torch.nn.Softmax(dim=-1)(safe_unsafe_logits).cpu() 123 | if len(safe_unsafe_probs.shape) == 1: 124 | return safe_unsafe_probs[0].item() 125 | return safe_unsafe_probs[...,0] 126 | 127 | def get_unsafe_categories_probs_from_embedding(self, embedding : torch.tensor) -> float: 128 | if len(embedding) == 0: 129 | return torch.zeros((1,len(self.category_indices))) 130 | unsafe_logits = embedding.to(self.proj_B.device) @ self.proj_B 131 | unsafe_probs = torch.nn.Softmax(dim=-1)(unsafe_logits).cpu() 132 | return unsafe_probs -------------------------------------------------------------------------------- /transition_models/transition_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from transition_models.regression_wrapper import RegressionWrapper 5 | from mixture_of_experts import HeirarchicalMoE 6 | from mdn.src.blocks import MixtureDensityNetwork 7 | import os 8 | import traceback 9 | 10 | ''' 11 | input and output data are all in tuple format (due to the need for dict hashing in MCTS procedure). 12 | ''' 13 | class TransitionModel: 14 | def __init__(self, samples=5) -> None: 15 | self.samples = samples 16 | pass 17 | 18 | # given a state (n-dim embedding), acton (n-dim directional vector), account for human stochastic response and return a new state 19 | def transit(self, state, action): 20 | 21 | # convert tuple to np array 22 | state = np.array(list(state)) 23 | action = np.array(list(action)) 24 | 25 | intermediate_state = state + action # action is a directional vector, so we can add them directly 26 | 27 | # mimic transition for now randomly 28 | # convert back to tuple format 29 | new_states = [tuple(intermediate_state * random.gauss(0, 1)) for x in range(self.samples)] 30 | 31 | return new_states 32 | 33 | # given a state (n-dim embedding), return a LLM action (n-dimensional vector) 34 | def sample_actions(self, state): 35 | 36 | # convert tuple to np array 37 | state = np.array(list(state)) 38 | 39 | # mimic action for now randomly 40 | dim = state.shape[0] 41 | return [tuple(np.random.normal(0, 1, dim)) for x in range(self.samples)] 42 | 43 | class TransitionModelMOE: 44 | def __init__(self, samples=4, noise=0.05, cuda=torch.device("cpu"), max_batch_size = 2048, transition_model_dir="models/deterministic/") -> None: 45 | self.samples = samples 46 | self.std = noise 47 | self.llm_models = [] # used to generate actions 48 | self.human_models = [] # used to generate transition to next state 49 | self.cuda = cuda 50 | self.max_batch_size = max_batch_size 51 | print(f"Loading transition models on device {cuda}...") 52 | for i, dir in enumerate(os.listdir(transition_model_dir)): 53 | try: 54 | models_dir = f"{dir}/human_llm" 55 | self.llm_models.append(RegressionWrapper(HeirarchicalMoE(1024)).float().to(self.cuda)) 56 | self.llm_models[i].load_state_dict(torch.load(f"{models_dir}/model_min_train.pth")["model_state_dict"]) 57 | self.llm_models[i].eval() 58 | except: 59 | pass 60 | try: 61 | models_dir = f"{dir}/llm_human" 62 | self.human_models.append(RegressionWrapper(HeirarchicalMoE(1024)).float().to(self.cuda)) 63 | self.human_models[i].load_state_dict(torch.load(f"{models_dir}/model_min_train.pth")["model_state_dict"]) 64 | self.human_models[i].eval() 65 | except: 66 | pass 67 | assert len(self.llm_models) >= 1 and len(self.human_models) >= 1, f"No transition models loaded! Are you sure the directory {transition_model_dir} contains the models?" 68 | print(f"Loaded {len(self.llm_models)} LLM models and {len(self.human_models)} human models on device {cuda}, taking up {np.sum([print_model_memory_usage(i) for i in self.llm_models]) + np.sum([print_model_memory_usage(i) for i in self.human_models]):.0f}MB.") 69 | 70 | def forward(self, input, models): # input should be (batch x dim) 71 | next_states = [] 72 | for model in models: 73 | with torch.no_grad(): 74 | tmp = [] 75 | for i in range(0, len(input), self.max_batch_size): 76 | tmp.append(model.forward(input[i:i+self.max_batch_size].to(self.cuda))) 77 | next_states.append(torch.cat(tmp, dim=0).cpu()) 78 | 79 | next_states = torch.stack(next_states) 80 | 81 | if len(next_states) == 1: 82 | noise = torch.randn(self.samples, *next_states.shape) * self.std 83 | perturbed_state = next_states.repeat(self.samples + 1, *([1] * len(next_states.shape))) 84 | perturbed_state[1:] += noise 85 | else: 86 | perturbed_state = next_states 87 | return perturbed_state # (samples x batch x dim) 88 | 89 | # given a state (n-dim embedding), acton (n-dim directional vector), account for human stochastic response and return a new state 90 | def transit(self, state, action): 91 | 92 | # convert to torch tensor 93 | state = torch.tensor(state) 94 | action = torch.tensor(action) 95 | 96 | intermediate_state = state + action # action is a directional vector, so we can add them directly 97 | 98 | perturbed_state = self.forward(intermediate_state.unsqueeze(0), self.human_models)[:,0,:] 99 | perturbed_state = [tuple(i.numpy()) for i in perturbed_state] 100 | 101 | return perturbed_state 102 | 103 | 104 | # given a state (n-dim embedding), return a LLM action (n-dimensional vector) 105 | def sample_actions(self, state): 106 | 107 | # convert to torch tensor 108 | state = torch.tensor(state) 109 | 110 | perturbed_state = self.forward(state.unsqueeze(0), self.llm_models)[:,0,:] 111 | action_vector = perturbed_state - state 112 | action_vector = [tuple(i.numpy()) for i in action_vector] 113 | 114 | return action_vector 115 | 116 | def batch_sample_human(self, input): # input should be (... x dim) 117 | flattened = input.view(-1, input.shape[-1]) 118 | output = self.forward(flattened, self.human_models) 119 | return output.view(-1, *input.shape) # output should be (samples x ... x dim) 120 | 121 | def batch_sample_llm(self, input): # input should be (... x batch x dim) 122 | flattened = input.view(-1, input.shape[-1]) 123 | output = self.forward(flattened, self.llm_models) 124 | return output.view(-1, *input.shape) # output should be (samples x ... x dim) 125 | 126 | class TransitionModelMDN: 127 | def __init__(self, samples=4, noise=0.05, cuda=torch.device("cpu"), max_batch_size = 2048, transition_model_dir="models/deterministic/") -> None: 128 | self.samples = samples 129 | self.std = noise 130 | self.cuda = cuda 131 | self.max_batch_size = max_batch_size 132 | print(f"Loading transition models on device {cuda}...") 133 | try: 134 | models_dir = f"{transition_model_dir}/human_llm" 135 | self.llm_model = RegressionWrapper(MixtureDensityNetwork(1024,1024,256,512)).float().to(self.cuda) 136 | self.llm_model.load_state_dict(torch.load(f"{models_dir}/model_min_val.pth")["model_state_dict"]) 137 | self.llm_model.eval() 138 | except: 139 | traceback.print_exc() 140 | print(f"No transition models loaded! Are you sure the directory {models_dir} contains the models?") 141 | exit() 142 | try: 143 | models_dir = f"{transition_model_dir}/llm_human" 144 | self.human_model = RegressionWrapper(MixtureDensityNetwork(1024,1024,256,512)).float().to(self.cuda) 145 | self.human_model.load_state_dict(torch.load(f"{models_dir}/model_min_val.pth")["model_state_dict"]) 146 | self.human_model.eval() 147 | except: 148 | traceback.print_exc() 149 | print(f"No transition models loaded! Are you sure the directory {models_dir} contains the models?") 150 | exit() 151 | print(f"Loaded LLM models and human models on device {cuda}, taking up {print_model_memory_usage(self.llm_model) + print_model_memory_usage(self.human_model):.0f}MB.") 152 | 153 | def forward(self, input, model): # input should be (batch x dim) 154 | with torch.no_grad(): 155 | tmp = [] 156 | for i in range(0, len(input), self.max_batch_size): 157 | tmp.append(model.sample(input[i:i+self.max_batch_size].to(self.cuda), samples_per_input=self.samples).cpu()) 158 | next_states = torch.cat(tmp, dim=1) 159 | perturbed_state = next_states 160 | return perturbed_state # (samples x batch x dim) 161 | 162 | # given a state (n-dim embedding), acton (n-dim directional vector), account for human stochastic response and return a new state 163 | def transit(self, state, action): 164 | 165 | # convert to torch tensor 166 | state = torch.tensor(state) 167 | action = torch.tensor(action) 168 | 169 | intermediate_state = state + action # action is a directional vector, so we can add them directly 170 | 171 | perturbed_state = self.forward(intermediate_state.unsqueeze(0), self.human_model)[:,0,:] 172 | perturbed_state = [tuple(i.numpy()) for i in perturbed_state] 173 | 174 | return perturbed_state 175 | 176 | 177 | # given a state (n-dim embedding), return a LLM action (n-dimensional vector) 178 | def sample_actions(self, state): 179 | 180 | # convert to torch tensor 181 | state = torch.tensor(state) 182 | 183 | perturbed_state = self.forward(state.unsqueeze(0), self.llm_model)[:,0,:] 184 | action_vector = perturbed_state - state 185 | action_vector = [tuple(i.numpy()) for i in action_vector] 186 | 187 | return action_vector 188 | 189 | def batch_sample_human(self, input): # input should be (... x dim) 190 | flattened = input.view(-1, input.shape[-1]) 191 | output = self.forward(flattened, self.human_model) 192 | return output.view(-1, *input.shape) # output should be (samples x ... x dim) 193 | 194 | def batch_sample_llm(self, input): # input should be (... x batch x dim) 195 | flattened = input.view(-1, input.shape[-1]) 196 | output = self.forward(flattened, self.llm_model) 197 | return output.view(-1, *input.shape) # output should be (samples x ... x dim) 198 | 199 | def model_memory_usage(model): 200 | def tensor_memory(tensor): 201 | if tensor is None: 202 | return 0 203 | num_elements = tensor.numel() 204 | element_size = tensor.element_size() 205 | return num_elements * element_size 206 | 207 | total_memory = 0 208 | 209 | for param in model.parameters(): 210 | total_memory += tensor_memory(param) 211 | if param.grad is not None: 212 | total_memory += tensor_memory(param.grad) 213 | 214 | return total_memory 215 | 216 | def print_model_memory_usage(model): 217 | memory_in_bytes = model_memory_usage(model) 218 | memory_in_megabytes = memory_in_bytes / (1024 ** 2) # Convert bytes to megabytes 219 | 220 | return memory_in_megabytes -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Broaden your SCOPE! Efficient Multi-turn Conversation Planning for LLMs with Semantic Space 2 | 3 | This is the official repo for ICLR 2025 Spotlight paper "Broaden your SCOPE! Efficient Multi-turn Conversation Planning for LLMs with Semantic Space". 4 | 5 | To cite this works, please use the following Bibtex: 6 | ``` 7 | @inproceedings{ 8 | chen2025scope, 9 | title={Broaden your {SCOPE}! Efficient Multi-turn Conversation Planning for {LLM}s with Semantic Space}, 10 | author={Zhiliang Chen and Xinyuan Niu and Chuan-Sheng Foo and Bryan Kian Hsiang Low}, 11 | booktitle={Proc. ICLR}, 12 | year={2025}, 13 | url={https://openreview.net/forum?id=3cgMU3TyyE} 14 | } 15 | ``` 16 | 17 | # Overview 18 | SCOPE consists of two phase. The learning phase has already been done and we have uploaded the models in this repository. Hence, users can simply use SCOPE during runtime to find the best response in a conversation.
19 | 20 | ![SCOPE overview image](https://github.com/user-attachments/assets/a37909ce-7b30-4321-bbc0-2e24eba6c129) 21 | 22 | # SETUP (DO THIS BEFORE ANYTHING ELSE) 23 | 0. Install pytorch `pip3 install torch` 24 | 1. `pip3 install -r requirements.txt` 25 | 2. Download files with the following command: `gdown --folder "1NLK8f8aV476frtIuMwC8IgwTVPxbOB6S" -O transition_models/deterministic` 26 | 27 | # Given a conversation starter, get the best LLM response. 28 | A simple use case is that given a conversation starter, we want to use SCOPE to simply 29 | 30 | 0. Go to `evaluation/conversation_starter.txt`, you should place one question that you want to ask the LLM with here. (Currently, we do not support multiple questions for this section, but later sections do). 31 | 1. Run `python3 -u evaluation/run_evaluation_singular.py --reward_func=length_human --cuda_for_llm_reward=0 --cuda_for_q_embedding_transition=1 --lr=0.0001 --evaluation_depth=4 --mcts_time=5 --agent=pure_online --result_file=camera --trials=1 --evaluation_data=conversation_starter.txt 2>&1 | tee output.out` The reward function and parameters can be adjusted. We provide more details what they mean later. 32 | 2. You should see the following output. We see the LLM response options (we can adjust the number of proposed responses, see later sections), and their associated learnt Q values. The higher Q value indicates better cumulative reward that we think a certain response has (based on our MCTS forward simulation in semantic space) 33 | ``` 34 | conversation starter: Can you tell me something about Singapore, the place where ICLR 2025 is held? 35 | possible actions: 36 | 0: Singapore is a great destination! It's a modern and efficient city-state with a rich cultural heritage. The city is known for its cleanliness, food, and Gardens by the Bay. ICLR 2025 will likely take place in the Marina Bay Sands Expo and Convention Centre, which is a popular venue for conferences and events. 37 | 1: ICLR 2025 is indeed being held in Singapore! It's a great city-state with a mix of Asian and Western cultures. You can expect to enjoy the vibrant food scene, beautiful gardens, and world-class infrastructure. 38 | 2: Yes, Singapore is a popular destination for conferences and events! ICLR 2025 will likely take place in the city-state's vibrant financial district, surrounded by iconic landmarks like the Marina Bay Sands and Gardens by the Bay. 39 | 3: Singapore is a modern and efficient city-state with a blend of Asian and Western cultures. It's known for its cleanliness, food, and Gardens by the Bay. The ICLR 2025 conference will likely take place in the city's central business district, which is easily accessible by public transportation. 40 | 4: Singapore is a modern and vibrant city-state with a rich cultural heritage. It's known for its cleanliness, safety, and efficiency. The city has a blend of Asian and Western influences, with a mix of traditional and modern architecture. ICLR 2025 will likely take place in one of the many convention centers or hotels in the city. 41 | Learnt Q value rewards: [tensor(1.4333), tensor(1.4436), tensor(1.3992), tensor(1.4457), tensor(1.4493)] 42 | ``` 43 | 44 | # Given a conversation starter, perform multi-step evaluation in a real conversation and produce the cumulative rewards. 45 | 0. Certainly, we might want to verify if SCOPE really did choose LLM responses that really lead to higher cumulative rewards in an actual conversation. You can certainly wrap `run_evaluation_singular.py` with an iterative loop to evaluate this. For convenience, we have introduced a wrapper to help evaluate the cumulative rewards actually produced by responses selected by SCOPE in a multi-turn conversation. 46 | 1. Run `python3 -u evaluation/run_evaluation.py --reward_func=length_human --cuda_for_llm_reward=0 --cuda_for_q_embedding_transition=1 --lr=0.0001 --evaluation_depth=4 --mcts_time=2 --agent=random --result_file=output --trials=1 --evaluation_data=camera_ready.txt 2>&1 | tee output.out` 47 | 2. We can observe the cumulative rewards and the actual conversation generated from our LLM responses. 48 | ``` 49 | all rewards from trials: [1.1600000000000001] 50 | mean: 1.1600000000000001 51 | std error: nan 52 | 0%| | 0/1 [00:18 float:`. This functions tells us, given a previous conversation state, a specific action (LLM response) and a transition to the next human response, what is the associated rewards? This also corresponds to the instantaneous reward at one transition step in the MDP. A new reward class needs to have this function. Furthermore, this new reward class needs to calculate the reward associated with each point in semantic space (needs to be learnt) to perform planning with SCOPE and the reward associated with real conversation text, for evaluation purposes. 85 | - To train a new reward model that knows the instantaneous reward associated with each point in semantic space, you can take any text data, find its ground-truth reward label and project the text into embedding space with `Meta-Llama-Guard-2-8B` (we use this as our embedding model in our paper). Hence, your reward model needs to learn the mapping between the embedding and the reward label (e.g., using a neural network). This can be learnt offline and loaded during planning. A good example to start is to look at `reward/Embedding_Length_Reward.py`, which loads a `embedding_length_reward` torch neural network model that predicts the reward associated with an embedding tuple. 86 | 87 | # Training your own transition models 88 | 1. Pre-process your dataset by converting the conversations into the embedding vectors using the embedding model with `python3 train/embed_dataset.py`. 89 | 2. Train the transition models with `python3 train/train_transition.py --seed=0` for seeds ${0,1,2,3}$. 90 | -------------------------------------------------------------------------------- /evaluation/run_evaluation_singular.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | 4 | from monte_carlo_tree_search.policy_agent import * 5 | from monte_carlo_tree_search.qtable import QTable, DeepQSemanticFunction, ReplayBufferDeepQFunction 6 | from monte_carlo_tree_search.conversation_env import * 7 | from agent.Model import create_human_and_llm 8 | from transformers import AutoTokenizer, BertModel, AutoModel 9 | from transition_models.transition_model import TransitionModel, TransitionModelMOE, TransitionModelMDN 10 | from transition_models.embedding_model import embedding_model_mistral, embedding_model_nomic, embedding_model_llama 11 | from reward.Embedding_Length_Reward import Embedding_Length_Reward 12 | from reward.Human_Length_Reward import Human_Length_Reward 13 | from reward.Llama_2_Guard_Reward import Llama_2_Guard_Reward 14 | 15 | import torch 16 | from scipy import stats 17 | import numpy as np 18 | import torch 19 | import os.path 20 | import time 21 | import random 22 | import pandas as pd 23 | 24 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 25 | 26 | from datasets import load_from_disk 27 | from itertools import repeat 28 | 29 | import logging 30 | [logging.getLogger(name).setLevel(logging.ERROR) for name in logging.root.manager.loggerDict if "transformers" in logging.getLogger(name).name.lower()] 31 | 32 | # Parse command line arguments 33 | parser = ArgumentParser() 34 | parser.add_argument("--evaluation_data", help="evaluation_data", default="evaluation_starters_simple.txt") 35 | parser.add_argument("--evaluation_start", help="start index for evaluation data", type=int, default=0) 36 | parser.add_argument("--evaluation_end", help="end index for evaluation data", type=int, default=100) 37 | parser.add_argument("--evaluation_depth", help="number of sequential actions to evaluate", default=5) 38 | parser.add_argument("--mcts_search_depth", help="mcts search depth; only applies to mcts approaches", default=8) 39 | parser.add_argument("--mcts_time", help="mcts search time budget", default=100) 40 | parser.add_argument("--pretrained_q_function", help="pre-learnt q function for heuristic or initialization", default="model_pretrained_qfn") 41 | parser.add_argument("--result_file", help="result_file_name", default="evaluation_results") 42 | parser.add_argument("--agent", help="agent type") 43 | parser.add_argument("--embedding", default="llama") 44 | parser.add_argument("--cuda_for_llm_reward", help="cuda", default="0") 45 | parser.add_argument("--cuda_for_q_embedding_transition", help="cuda", default="0") 46 | parser.add_argument("--transition_model_dir", help="directory containing transition models", default="transition_models/deterministic/") 47 | parser.add_argument("--reward_decay", default=0.9) 48 | parser.add_argument("--trials", help="trials", default=5) 49 | parser.add_argument("--reward_func", help="reward", default="harmful") 50 | parser.add_argument("--lr", default=0.0001) 51 | parser.add_argument("--use_icl", default=False, action="store_true") 52 | parser.add_argument("--use_descriptions", default=False, action="store_true") 53 | parser.add_argument("--seed", default=42, type=int) 54 | parser.add_argument("--config", default="agent/llm_config.yaml") 55 | args = vars(parser.parse_args()) 56 | print("command-line args: ", args) 57 | 58 | assert not(args["use_icl"] and args["use_descriptions"]), "Cannot use both icl and descriptions" 59 | 60 | seed = args["seed"] 61 | 62 | trials = int(args["trials"]) 63 | reward_func = args["reward_func"] 64 | evaluation_output = args["result_file"] 65 | evaluation_data = args["evaluation_data"] 66 | evaluation_action_depth = int(args["evaluation_depth"]) 67 | runtime_mcts_search_depth = int(args["mcts_search_depth"]) 68 | runtime_mcts_timeout = int(args["mcts_time"]) 69 | agent_ = args["agent"] 70 | embedding_type = args["embedding"] 71 | transition_model_dir = args["transition_model_dir"] 72 | if "mdn" in transition_model_dir.lower(): 73 | TransitionModel_type = TransitionModelMDN 74 | else: 75 | TransitionModel_type = TransitionModelMOE 76 | 77 | # Cuda devices 78 | convert = lambda s: int(s) if s.isdigit() else s 79 | cuda_q_embedding = convert(args["cuda_for_q_embedding_transition"]) 80 | cuda_transition = convert(args["cuda_for_q_embedding_transition"]) 81 | cuda_llm = convert(args["cuda_for_llm_reward"]) 82 | cuda_reward = convert(args["cuda_for_llm_reward"]) 83 | 84 | reward_decay = float(args["reward_decay"]) 85 | lr = float(args["lr"]) 86 | 87 | start_index = args["evaluation_start"] 88 | end_index = args["evaluation_end"] 89 | 90 | # get the convo starters for evaluation 91 | file_name = 'evaluation/' + str(evaluation_data) 92 | human_prompts = repeat(None) 93 | human_descriptions = repeat(None) 94 | if os.path.isfile(file_name): 95 | with open(file_name) as f: 96 | evaluation_starters = f.readlines() 97 | end_index = min(end_index, len(evaluation_starters)) if end_index > 0 else len(evaluation_starters) 98 | evaluation_starters = evaluation_starters[start_index: end_index] 99 | elif os.path.isdir(evaluation_data): 100 | tmp = load_from_disk(evaluation_data) 101 | evaluation_starters = [i[0].strip() for i in tmp["conversations"]] 102 | end_index = min(end_index, len(evaluation_starters)) if end_index > 0 else len(evaluation_starters) 103 | evaluation_starters = evaluation_starters[start_index: end_index] 104 | if args["use_icl"]: 105 | human_prompts = [[j.strip() for j in i[::2]] for i in tmp["conversations"]] 106 | human_prompts = human_prompts[start_index: end_index] 107 | if args["use_descriptions"]: 108 | human_descriptions = [[i] for i in pd.read_pickle(evaluation_data + "/descriptions.pkl")["filtered_description"]] 109 | human_descriptions = human_descriptions[start_index: end_index] 110 | 111 | # create the llm and human simulator 112 | cg = args["config"] 113 | human_sim, human_eval, llm_agent = create_human_and_llm(config=cg, cuda=cuda_llm) 114 | 115 | if reward_func == "length_human": 116 | reward_function = Embedding_Length_Reward(add_llm_length=False, device_map=cuda_reward) 117 | if reward_func == "length_both": 118 | reward_function = Embedding_Length_Reward(add_llm_length=True, device_map=cuda_reward) 119 | if reward_func == "harmful": 120 | reward_function = Llama_2_Guard_Reward(device_map=cuda_reward) 121 | 122 | agents = [] 123 | agent_type = [] 124 | 125 | if agent_ == "zero_step_greedy": 126 | greedy_agent = GreedyAgent(zero_step_greedy_reward_generator(human_sim, reward_function.get_reward), llm_agent) # infer human's next response and choose best one 127 | agent_type.append(agent_) 128 | agents.append(greedy_agent) 129 | 130 | if agent_ == "one_step_greedy": 131 | greedy_agent = GreedyAgent(one_step_greedy_reward_generator(human_sim, reward_function.get_reward), llm_agent) # infer human's next response and choose best one 132 | agent_type.append(agent_) 133 | agents.append(greedy_agent) 134 | 135 | if agent_ == "random": 136 | random_agent = RandomAgent(llm_agent) 137 | agent_type.append(agent_) 138 | agents.append(random_agent) 139 | 140 | if agent_ == "pure_offline": 141 | model = torch.load(args["pretrained_q_function"], map_location=torch.device(cuda_q_embedding)) 142 | pure_offline_agent = OfflineAgent(model, llm_agent) # use pretrained q functon, don't do any mcts 143 | agent_type.append(agent_) 144 | agents.append(pure_offline_agent) 145 | 146 | if agent_ == "pure_online": 147 | pure_online_mcts_agent = OnlineAgent(ReplayBufferDeepQFunction(alpha=lr, steps_update=50, cuda=torch.device(cuda_q_embedding)), runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function, search_space="response_space", reward_decay=reward_decay) # use a brand new q function and do mcts during runtime 148 | agent_type.append(agent_) 149 | agents.append(pure_online_mcts_agent) 150 | 151 | if agent_ == "offline_online_mixed": 152 | model = torch.load(args["pretrained_q_function"], map_location=torch.device(cuda_q_embedding)) 153 | pretrained_offline_online_mcts_agent = OnlineAgent(model, runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function) # use pretrained q function and perform mcts 154 | agent_type.append(agent_) 155 | agents.append(pretrained_offline_online_mcts_agent) 156 | 157 | if agent_ == "semantic_online": 158 | embed_model=None 159 | dim = None 160 | 161 | if embedding_type == "llama": 162 | if reward_func == "harmful": 163 | model = reward_function.model 164 | else: 165 | model = None 166 | embed_model = embedding_model_llama(model=model, cuda=torch.device(cuda_q_embedding)) 167 | dim = embed_model.output_dim 168 | 169 | if reward_func == "length_human": 170 | reward_function = Embedding_Length_Reward(add_llm_length=False) 171 | elif reward_func == "length_both": 172 | reward_function = Embedding_Length_Reward(add_llm_length=True) 173 | transition_model = TransitionModel_type(noise=0.005, cuda=cuda_transition, transition_model_dir=args["transition_model_dir"]) # need to convert to cuda. Now using CPU (does it matter?). 174 | semanticqfunction = DeepQSemanticFunction(dim=dim, alpha=lr, cuda=torch.device(cuda_q_embedding), steps_update=50) # more sophisticated Q function? 175 | pure_online_agent_semantic_agent = OnlineAgent(semanticqfunction, runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function, search_space="semantic_space", transition_model=transition_model, embedding_model=embed_model) # online SEMANTIC space agent 176 | 177 | agent_type.append(agent_) 178 | agents.append(pure_online_agent_semantic_agent) 179 | 180 | if agent_ == "semantic_exhaustive": 181 | if reward_func == "harmful": 182 | model = reward_function.model 183 | else: 184 | model = None 185 | embed_model = embedding_model_llama(model=model, cuda=torch.device(cuda_q_embedding)) 186 | transition_model = TransitionModel_type(noise=0.005, cuda=cuda_transition, transition_model_dir=args["transition_model_dir"]) # need to convert to cuda. Now using CPU (does it matter?). 187 | pure_online_agent_semantic_agent = ExhastiveOnlineAgent(runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function, search_space="semantic_space", reward_decay=reward_decay, transition_model=transition_model, embedding_model=embed_model) # online SEMANTIC space agent 188 | agent_type.append(agent_) 189 | agents.append(pure_online_agent_semantic_agent) 190 | 191 | np.random.seed(seed) 192 | torch.manual_seed(seed) 193 | random.seed(seed) 194 | 195 | # create the mdp environment for evaluation 196 | evaluation_conversation_env = conversation_environment(human_eval, llm_agent, "", max_depth=evaluation_action_depth*2, reward_function=reward_function) 197 | for agent,type in zip(agents, agent_type): 198 | best_response, possible_actions, rewards = run_evaluations_singular(agent, type, evaluation_conversation_env, evaluation_starters[0], context_list = human_prompts, human_descriptions = human_descriptions, index = list(range(start_index, end_index)), seed = seed, output_file=evaluation_output) 199 | print(possible_actions) 200 | print(rewards) -------------------------------------------------------------------------------- /evaluation/run_evaluation.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | 4 | from monte_carlo_tree_search.policy_agent import * 5 | from monte_carlo_tree_search.qtable import QTable, DeepQSemanticFunction, ReplayBufferDeepQFunction 6 | from monte_carlo_tree_search.conversation_env import * 7 | from agent.Model import create_human_and_llm 8 | from transformers import AutoTokenizer, BertModel, AutoModel 9 | from transition_models.transition_model import TransitionModel, TransitionModelMOE, TransitionModelMDN 10 | from transition_models.embedding_model import embedding_model_mistral, embedding_model_nomic, embedding_model_llama 11 | from reward.Embedding_Length_Reward import Embedding_Length_Reward 12 | from reward.Human_Length_Reward import Human_Length_Reward 13 | from reward.Llama_2_Guard_Reward import Llama_2_Guard_Reward 14 | 15 | import torch 16 | from scipy import stats 17 | import numpy as np 18 | import torch 19 | import os.path 20 | import time 21 | import random 22 | import pandas as pd 23 | 24 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 25 | 26 | from datasets import load_from_disk 27 | from itertools import repeat 28 | 29 | import logging 30 | [logging.getLogger(name).setLevel(logging.ERROR) for name in logging.root.manager.loggerDict if "transformers" in logging.getLogger(name).name.lower()] 31 | 32 | # Parse command line arguments 33 | parser = ArgumentParser() 34 | parser.add_argument("--evaluation_data", help="evaluation_data", default="evaluation_starters_simple.txt") 35 | parser.add_argument("--evaluation_start", help="start index for evaluation data", type=int, default=0) 36 | parser.add_argument("--evaluation_end", help="end index for evaluation data", type=int, default=100) 37 | parser.add_argument("--evaluation_depth", help="number of sequential actions to evaluate", default=5) 38 | parser.add_argument("--mcts_search_depth", help="mcts search depth; only applies to mcts approaches", default=8) 39 | parser.add_argument("--mcts_time", help="mcts search time budget", default=100) 40 | parser.add_argument("--pretrained_q_function", help="pre-learnt q function for heuristic or initialization", default="model_pretrained_qfn") 41 | parser.add_argument("--result_file", help="result_file_name", default="evaluation_results") 42 | parser.add_argument("--agent", help="agent type") 43 | parser.add_argument("--embedding", default="llama") 44 | parser.add_argument("--cuda_for_llm_reward", help="cuda", default="0") 45 | parser.add_argument("--cuda_for_q_embedding_transition", help="cuda", default="0") 46 | parser.add_argument("--transition_model_dir", help="directory containing transition models", default="models/deterministic/") 47 | parser.add_argument("--reward_decay", default=0.9) 48 | parser.add_argument("--trials", help="trials", default=5) 49 | parser.add_argument("--reward_func", help="reward", default="harmful") 50 | parser.add_argument("--lr", default=0.0001) 51 | parser.add_argument("--use_icl", default=False, action="store_true") 52 | parser.add_argument("--use_descriptions", default=False, action="store_true") 53 | parser.add_argument("--seed", default=42, type=int) 54 | parser.add_argument("--config", default="agent/llm_config.yaml") 55 | args = vars(parser.parse_args()) 56 | print("command-line args: ", args) 57 | 58 | assert not(args["use_icl"] and args["use_descriptions"]), "Cannot use both icl and descriptions" 59 | 60 | seed = args["seed"] 61 | 62 | trials = int(args["trials"]) 63 | reward_func = args["reward_func"] 64 | evaluation_output = args["result_file"] 65 | evaluation_data = args["evaluation_data"] 66 | evaluation_action_depth = int(args["evaluation_depth"]) 67 | runtime_mcts_search_depth = int(args["mcts_search_depth"]) 68 | runtime_mcts_timeout = int(args["mcts_time"]) 69 | agent_ = args["agent"] 70 | embedding_type = args["embedding"] 71 | transition_model_dir = args["transition_model_dir"] 72 | if "mdn" in transition_model_dir.lower(): 73 | TransitionModel_type = TransitionModelMDN 74 | else: 75 | TransitionModel_type = TransitionModelMOE 76 | 77 | # Cuda devices 78 | convert = lambda s: int(s) if s.isdigit() else s 79 | cuda_q_embedding = convert(args["cuda_for_q_embedding_transition"]) 80 | cuda_transition = convert(args["cuda_for_q_embedding_transition"]) 81 | cuda_llm = convert(args["cuda_for_llm_reward"]) 82 | cuda_reward = convert(args["cuda_for_llm_reward"]) 83 | 84 | reward_decay = float(args["reward_decay"]) 85 | lr = float(args["lr"]) 86 | 87 | start_index = args["evaluation_start"] 88 | end_index = args["evaluation_end"] 89 | 90 | # get the convo starters for evaluation 91 | file_name = 'evaluation/' + str(evaluation_data) 92 | human_prompts = repeat(None) 93 | human_descriptions = repeat(None) 94 | if os.path.isfile(file_name): 95 | with open(file_name) as f: 96 | evaluation_starters = f.readlines() 97 | end_index = min(end_index, len(evaluation_starters)) if end_index > 0 else len(evaluation_starters) 98 | evaluation_starters = evaluation_starters[start_index: end_index] 99 | elif os.path.isdir(evaluation_data): 100 | tmp = load_from_disk(evaluation_data) 101 | evaluation_starters = [i[0].strip() for i in tmp["conversations"]] 102 | end_index = min(end_index, len(evaluation_starters)) if end_index > 0 else len(evaluation_starters) 103 | evaluation_starters = evaluation_starters[start_index: end_index] 104 | if args["use_icl"]: 105 | human_prompts = [[j.strip() for j in i[::2]] for i in tmp["conversations"]] 106 | human_prompts = human_prompts[start_index: end_index] 107 | if args["use_descriptions"]: 108 | human_descriptions = [[i] for i in pd.read_pickle(evaluation_data + "/descriptions.pkl")["filtered_description"]] 109 | human_descriptions = human_descriptions[start_index: end_index] 110 | 111 | # create the llm and human simulator 112 | cg = args["config"] 113 | human_sim, human_eval, llm_agent = create_human_and_llm(config=cg, cuda=cuda_llm) 114 | 115 | if reward_func == "length_human": 116 | reward_function = Embedding_Length_Reward(add_llm_length=False, device_map=cuda_reward) 117 | if reward_func == "length_both": 118 | reward_function = Embedding_Length_Reward(add_llm_length=True, device_map=cuda_reward) 119 | if reward_func == "harmful": 120 | reward_function = Llama_2_Guard_Reward(device_map=cuda_reward) 121 | 122 | agents = [] 123 | agent_type = [] 124 | 125 | if agent_ == "zero_step_greedy": 126 | greedy_agent = GreedyAgent(zero_step_greedy_reward_generator(human_sim, reward_function.get_reward), llm_agent) # infer human's next response and choose best one 127 | agent_type.append(agent_) 128 | agents.append(greedy_agent) 129 | 130 | if agent_ == "one_step_greedy": 131 | greedy_agent = GreedyAgent(one_step_greedy_reward_generator(human_sim, reward_function.get_reward), llm_agent) # infer human's next response and choose best one 132 | agent_type.append(agent_) 133 | agents.append(greedy_agent) 134 | 135 | if agent_ == "random": 136 | random_agent = RandomAgent(llm_agent) 137 | agent_type.append(agent_) 138 | agents.append(random_agent) 139 | 140 | if agent_ == "pure_offline": 141 | model = torch.load(args["pretrained_q_function"], map_location=torch.device(cuda_q_embedding)) 142 | pure_offline_agent = OfflineAgent(model, llm_agent) # use pretrained q functon, don't do any mcts 143 | agent_type.append(agent_) 144 | agents.append(pure_offline_agent) 145 | 146 | if agent_ == "pure_online": 147 | pure_online_mcts_agent = OnlineAgent(ReplayBufferDeepQFunction(alpha=lr, steps_update=50, cuda=torch.device(cuda_q_embedding)), runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function, search_space="response_space", reward_decay=reward_decay) # use a brand new q function and do mcts during runtime 148 | agent_type.append(agent_) 149 | agents.append(pure_online_mcts_agent) 150 | 151 | if agent_ == "offline_online_mixed": 152 | model = torch.load(args["pretrained_q_function"], map_location=torch.device(cuda_q_embedding)) 153 | pretrained_offline_online_mcts_agent = OnlineAgent(model, runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function) # use pretrained q function and perform mcts 154 | agent_type.append(agent_) 155 | agents.append(pretrained_offline_online_mcts_agent) 156 | 157 | if agent_ == "semantic_online": 158 | embed_model=None 159 | dim = None 160 | 161 | if embedding_type == "llama": 162 | if reward_func == "harmful": 163 | model = reward_function.model 164 | else: 165 | model = None 166 | embed_model = embedding_model_llama(model=model, cuda=torch.device(cuda_q_embedding)) 167 | dim = embed_model.output_dim 168 | 169 | if reward_func == "length_human": 170 | reward_function = Embedding_Length_Reward(add_llm_length=False) 171 | elif reward_func == "length_both": 172 | reward_function = Embedding_Length_Reward(add_llm_length=True) 173 | transition_model = TransitionModel_type(noise=0.005, cuda=cuda_transition, transition_model_dir=args["transition_model_dir"]) # need to convert to cuda. Now using CPU (does it matter?). 174 | semanticqfunction = DeepQSemanticFunction(dim=dim, alpha=lr, cuda=torch.device(cuda_q_embedding), steps_update=50) # more sophisticated Q function? 175 | pure_online_agent_semantic_agent = OnlineAgent(semanticqfunction, runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function, search_space="semantic_space", transition_model=transition_model, embedding_model=embed_model) # online SEMANTIC space agent 176 | 177 | agent_type.append(agent_) 178 | agents.append(pure_online_agent_semantic_agent) 179 | 180 | if agent_ == "semantic_exhaustive": 181 | if reward_func == "harmful": 182 | model = reward_function.model 183 | else: 184 | model = None 185 | embed_model = embedding_model_llama(model=model, cuda=torch.device(cuda_q_embedding)) 186 | transition_model = TransitionModel_type(noise=0.005, cuda=cuda_transition, transition_model_dir=args["transition_model_dir"]) # need to convert to cuda. Now using CPU (does it matter?). 187 | pure_online_agent_semantic_agent = ExhastiveOnlineAgent(runtime_mcts_search_depth, runtime_mcts_timeout, llm_agent, human_sim, reward_function, search_space="semantic_space", reward_decay=reward_decay, transition_model=transition_model, embedding_model=embed_model) # online SEMANTIC space agent 188 | agent_type.append(agent_) 189 | agents.append(pure_online_agent_semantic_agent) 190 | 191 | np.random.seed(seed) 192 | torch.manual_seed(seed) 193 | random.seed(seed) 194 | 195 | # create the mdp environment for evaluation 196 | evaluation_conversation_env = conversation_environment(human_eval, llm_agent, "", max_depth=evaluation_action_depth*2, reward_function=reward_function) 197 | all_results = [] 198 | all_results.append(evaluation_starters) 199 | all_results_dict = [] 200 | for agent,type in zip(agents, agent_type): 201 | results = {} 202 | start = time.time() 203 | result_row, convo_generated = run_evaluations(agent, type, evaluation_conversation_env, evaluation_starters, evaluation_action_depth, trials, context_list = human_prompts, human_descriptions = human_descriptions, results=results, index = list(range(start_index, end_index)), seed = seed, output_file=evaluation_output) 204 | all_results.append(result_row) 205 | time_taken = time.time()-start 206 | results["time_taken_for_agent_type"] = time_taken 207 | print("time taken for all trials:", time_taken) 208 | for starters in convo_generated: 209 | print("input conversation starter: ", starters, "\n") 210 | print("conversation generated: ", convo_generated[starters]) 211 | all_results_dict.append(results) 212 | 213 | import lz4.frame 214 | 215 | with lz4.frame.open(evaluation_output+'_dump.pkl', 'wb') as f: 216 | pickle.dump(all_results_dict, f) 217 | 218 | all_results = [list(i) for i in zip(*all_results)] # transpose 219 | import csv 220 | 221 | with open(evaluation_output+'.csv', 'w', newline='') as f: 222 | writer = csv.writer(f) 223 | writer.writerows(all_results) 224 | 225 | import pickle 226 | with open(evaluation_output+'.pkl', 'wb') as f: 227 | pickle.dump(convo_generated, f) -------------------------------------------------------------------------------- /train/train_transition.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | 4 | from datasets import load_from_disk 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch import nn, optim 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | import wandb 11 | import argparse 12 | 13 | from mixture_of_experts import HeirarchicalMoE 14 | from transition_models.regression_wrapper import RegressionWrapper 15 | 16 | # Define the transformation function 17 | def transform_samples_wrapper(start, step, use_residuals=False): 18 | def transform_samples(batch): 19 | embeddings = batch['embeddings'] 20 | batched = isinstance(embeddings, list) or (len(embeddings.shape) == 3) 21 | if not batched: 22 | embeddings = [embeddings] 23 | inputs = [] 24 | outputs = [] 25 | for d in embeddings: 26 | inputs += [d[i-step] for i in range(start + step, len(d), 2)] 27 | if use_residuals: 28 | outputs += [d[i] - d[i-step] for i in range(start + step, len(d), 2)] 29 | else: 30 | outputs += [d[i] for i in range(start + step, len(d), 2)] 31 | transformed_samples = { 32 | 'inputs': inputs, 33 | 'outputs': outputs 34 | } 35 | 36 | return transformed_samples 37 | return transform_samples 38 | 39 | def calculate_mean(batch): 40 | input_sum = batch['inputs'].sum(dim=0) 41 | output_sum = batch['outputs'].sum(dim=0) 42 | return {'input_sum': [input_sum], 'output_sum': [output_sum]} 43 | 44 | def sum_of_squared_diff(batch, input_mean, output_mean): 45 | input_squared_diff_sum = (batch['inputs'] - input_mean).square().sum(dim=0) 46 | output_squared_diff_sum = (batch['outputs'] - output_mean).square().sum(dim=0) 47 | return {"input_squared_diff_sum": [input_squared_diff_sum], "output_squared_diff_sum": [output_squared_diff_sum]} 48 | 49 | def normalize_dataset(batch, input_mean, input_std, output_mean, output_std): 50 | inputs = (batch['inputs'] - input_mean) / input_std 51 | outputs = (batch['outputs'] - output_mean) / output_std 52 | return {'inputs': inputs, 'outputs': outputs} 53 | 54 | def load_dataset(**kwargs) -> tuple[DataLoader, DataLoader, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 55 | # Load dataset 56 | hf_dataset = load_from_disk(kwargs["dataset"]).with_format("torch") 57 | hf_dataset = hf_dataset.map( 58 | transform_samples_wrapper(kwargs['start'], kwargs['step'], use_residuals=not kwargs['not_residuals']), 59 | remove_columns=hf_dataset.column_names, batched=True, batch_size=2000, 60 | num_proc=32, 61 | ) 62 | hf_dataset = hf_dataset.train_test_split(test_size=0.1, seed=kwargs['seed'], shuffle=True) 63 | print(f"length of dataset {len(hf_dataset['train']) + len(hf_dataset['test'])}, with train length {len(hf_dataset['train'])} and test length {len(hf_dataset['test'])}") 64 | 65 | # Normalize input and output embeddings for training stability 66 | print("Calculating mean and std of inputs and outputs...") 67 | sums_dataset = hf_dataset["train"].map( 68 | calculate_mean, 69 | remove_columns=hf_dataset["train"].column_names, batched=True, batch_size=1000, 70 | ) 71 | input_mean = sum(sums_dataset["input_sum"]) / len(hf_dataset["train"]) 72 | output_mean = sum(sums_dataset["output_sum"]) / len(hf_dataset["train"]) 73 | squared_diff = hf_dataset["train"].map( 74 | sum_of_squared_diff, 75 | remove_columns=hf_dataset["train"].column_names, batched=True, batch_size=1000, 76 | fn_kwargs={'input_mean': input_mean, 'output_mean': output_mean}) 77 | input_std = torch.sqrt(sum(squared_diff["input_squared_diff_sum"]) / len(hf_dataset["train"])) 78 | output_std = torch.sqrt(sum(squared_diff["output_squared_diff_sum"]) / len(hf_dataset["train"])) 79 | 80 | normalized_train = hf_dataset["train"].map(normalize_dataset, fn_kwargs={ 81 | 'input_mean': input_mean, 'input_std': input_std, 'output_mean': output_mean, 'output_std': output_std 82 | }, batched=True, batch_size=10000) 83 | normalized_test = hf_dataset["test"].map(normalize_dataset, fn_kwargs={ 84 | 'input_mean': input_mean, 'input_std': input_std, 'output_mean': output_mean, 'output_std': output_std 85 | }, batched=True, batch_size=10000) 86 | 87 | print("Mean and std calculated.") 88 | 89 | # Convert custom dataset to DataLoader for batching 90 | train_dataset = DataLoader( 91 | normalized_train, 92 | batch_size=kwargs["batch_size"], 93 | ) 94 | val_dataset = DataLoader( 95 | normalized_test, 96 | batch_size=8192, 97 | ) 98 | 99 | return train_dataset, val_dataset, input_mean, input_std, output_mean, output_std 100 | 101 | def initialize_model(device, **kwargs): 102 | print(f"Initializing model... on device {device}") 103 | 104 | torch.manual_seed(kwargs["seed"]) 105 | model = HeirarchicalMoE(dim = 1024) 106 | 107 | model.to(device) 108 | 109 | print(model) 110 | print("Model initialized.") 111 | return model 112 | 113 | def train_transition_model(**kwargs): 114 | seed=kwargs["seed"] 115 | epochs=kwargs["epochs"] 116 | lr=kwargs["lr"] 117 | gamma=kwargs["gamma"] 118 | batch_size=kwargs["batch_size"] 119 | transition_type=kwargs["transition_type"] 120 | use_wandb = kwargs.get("use_wandb", False) 121 | 122 | device = 0 123 | 124 | outdir = f"{kwargs['out_dir']}/seed_{seed}_batch_{batch_size}/{transition_type}" 125 | Path(outdir).mkdir(parents=True, exist_ok=True) 126 | 127 | train_dataset, val_dataset, input_mean, input_std, output_mean, output_std = load_dataset(**kwargs) 128 | 129 | model = initialize_model(device = device, **kwargs) 130 | 131 | # Define loss function and optimizer 132 | criterion = nn.MSELoss() 133 | optimizer = optim.Adam(model.parameters(), lr=lr) 134 | lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 135 | 136 | # Initialize wandb 137 | if use_wandb: 138 | run = wandb.init(project=kwargs["wandb_proj"], name=f"seed_{seed}_{transition_type}", config=kwargs) 139 | run.save("train_transition_distributed.py") 140 | run.watch(model, log="all", log_graph=True, criterion=criterion) 141 | 142 | if kwargs["continue_from"] is not None: 143 | checkpoint = torch.load(kwargs["continue_from"]) 144 | start_epoch = checkpoint['epoch'] 145 | model.load_state_dict(checkpoint['model_state_dict']["model"]) 146 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 147 | lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 148 | else: 149 | start_epoch = 0 150 | 151 | # Training loop 152 | min_val_loss = float('inf') 153 | min_val_loss_epoch = 0 154 | min_train_loss = float('inf') 155 | min_train_loss_epoch = 0 156 | regression_model = RegressionWrapper(model, embedding_size = input_mean.size(0)) 157 | regression_model.set_parameters(input_mean, input_std, output_mean, output_std, use_residuals = not kwargs['not_residuals']) 158 | for epoch in tqdm(range(start_epoch, epochs), leave=True): 159 | train_loss = 0.0 160 | aux_loss = 0.0 161 | 162 | # Training loop 163 | model.train() 164 | for batch_no, batch in enumerate(tqdm(train_dataset, leave=True, mininterval=10.0)): 165 | inputs = batch["inputs"].to(device) 166 | targets = batch["outputs"].to(device) 167 | 168 | optimizer.zero_grad() # Zero the gradient buffers 169 | 170 | outputs, curr_aux_loss = model(inputs[:, None]) # Forward pass 171 | loss = criterion(outputs[:,0,:], targets) # Compute the loss 172 | 173 | (loss + curr_aux_loss).backward() # Backward pass 174 | 175 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 176 | 177 | optimizer.step() # Update the weights 178 | 179 | train_loss += loss.item() 180 | aux_loss += curr_aux_loss.item() 181 | 182 | fractional_epoch = epoch + batch_no / len(train_dataset) 183 | if batch_no % 100 == 0 and use_wandb and fractional_epoch > 0.1: 184 | run.log({ 185 | "Epoch": fractional_epoch, 186 | "Intermediate Training Loss": loss.item(), 187 | "Intermediate Aux Loss": curr_aux_loss.item(), 188 | "Intermediate Total Loss": loss.item() + curr_aux_loss.item() 189 | }) 190 | train_loss /= len(train_dataset) 191 | aux_loss /= len(train_dataset) 192 | 193 | 194 | # Validation loop 195 | model.eval() 196 | with torch.no_grad(): 197 | val_loss = 0.0 198 | for batch in val_dataset: 199 | inputs = batch['inputs'].to(device) 200 | targets = batch['outputs'].to(device) 201 | 202 | outputs, aux_loss = model(inputs[:, None]) 203 | val_loss += criterion(outputs[:,0,:], targets).item() 204 | val_loss /= len(val_dataset) 205 | tqdm.write(f'{transition_type.ljust(12)}Epoch {epoch + 1}/{epochs}, lr: {lr_scheduler.get_last_lr()[0]:.3e}, Training loss: {train_loss:.5e}, Validation loss: {val_loss:.5e}') 206 | 207 | if use_wandb: 208 | # Log metrics to wandb 209 | run.log({ 210 | f"Epoch": epoch+1, 211 | f"Training Loss": train_loss, 212 | f"Validation Loss": val_loss, 213 | f"Aux Loss": aux_loss, 214 | "lr": lr_scheduler.get_last_lr()[0] 215 | }) 216 | if True or epoch > 10: 217 | if val_loss < min_val_loss: 218 | min_val_loss_epoch = epoch 219 | min_val_loss = val_loss 220 | regression_model.model.load_state_dict(model.state_dict()) 221 | checkpoint = { 222 | 'epoch': epoch, 223 | 'model_state_dict': regression_model.state_dict(), 224 | 'optimizer_state_dict': optimizer.state_dict(), 225 | } 226 | torch.save(checkpoint, f'{outdir}/model_min_val.pth') 227 | if train_loss < min_train_loss: 228 | min_train_loss_epoch = epoch 229 | min_train_loss = train_loss 230 | regression_model.model.load_state_dict(model.state_dict()) 231 | checkpoint = { 232 | 'epoch': epoch, 233 | 'model_state_dict': regression_model.state_dict(), 234 | 'optimizer_state_dict': optimizer.state_dict(), 235 | 'scheduler_state_dict': lr_scheduler.state_dict() 236 | } 237 | torch.save(checkpoint, f'{outdir}/model_min_train.pth') 238 | lr_scheduler.step() 239 | 240 | print("Training complete.") 241 | print(f"Minimum validation loss: {min_val_loss:.5e} at epoch {min_val_loss_epoch}") 242 | print(f"Minimum training loss: {min_train_loss:.5e} at epoch {min_train_loss_epoch}") 243 | 244 | # Write to a txt file 245 | with open(f"{outdir}/results.txt", "w") as file: 246 | file.write("Training complete.\n") 247 | file.write(f"Minimum validation loss: {min_val_loss} at epoch {min_val_loss_epoch}\n") 248 | file.write(f"Minimum training loss: {min_train_loss} at epoch {min_train_loss_epoch}\n") 249 | 250 | if __name__ == "__main__": 251 | parser = argparse.ArgumentParser() 252 | parser.add_argument("--seed", type=int, help="Seed for random number generation", default=0) 253 | parser.add_argument("--epochs", type=int, help="Number of epochs for training", default=100) 254 | parser.add_argument("--lr", type=float, help="Learning rate for optimizer", default=0.001) 255 | parser.add_argument("--gamma", type=float, help="Exponential decay gamma for learning rate scheduler", default=0.9) 256 | parser.add_argument("--batch_size", type=int, help="Batch size for training", default=2048) 257 | parser.add_argument("--type_index", type=int, help="Index of the transition type to train", default=-1) 258 | parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging") 259 | parser.add_argument("--wandb_proj", type=str, help="Wandb project name", default="lm-sys_transition_moe") 260 | parser.add_argument("--dataset", type=str, help="dataset location", default="embeddings/lmsys-chat-1m_embeddings_1024") 261 | parser.add_argument("--not_residuals", action="store_true", help="Train on absolute embeddings instead of residuals") 262 | parser.add_argument("--out_dir", type=str, help="Output directory for models", default="transition_models/deterministic") 263 | parser.add_argument("--continue_from", type=str, help="Continue training from a checkpoint", default=None) 264 | args = vars(parser.parse_args()) 265 | 266 | print(args) 267 | 268 | # Conversations start with human 269 | start_steps = [ 270 | (1,1), 271 | (0,1), 272 | (1,2), 273 | (0,2) 274 | ] 275 | types = [ 276 | "llm_human", 277 | "human_llm", 278 | # "llm_llm", 279 | # "human_human" 280 | ] 281 | if args["type_index"] >= 0: 282 | start_steps = [start_steps[args["type_index"]]] 283 | types = [types[args["type_index"]]] 284 | 285 | if args["use_wandb"]: 286 | wandb.setup() 287 | for start_step, transition_type in zip(start_steps, types): 288 | args['transition_type'] = transition_type 289 | args['start'] = start_step[0] 290 | args['step'] = start_step[1] 291 | train_transition_model(**args) 292 | if args["use_wandb"]: 293 | wandb.finish() -------------------------------------------------------------------------------- /monte_carlo_tree_search/qtable.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from monte_carlo_tree_search.qfunction import QFunction 3 | import time 4 | import numpy as np 5 | 6 | def combine_encoded_inputs(input1, input2): 7 | new_encoding = {} 8 | for k in input1.keys(): 9 | # padding first 10 | i1_size = input1[k].shape[1] 11 | i2_size = input2[k].shape[1] 12 | i1 = input1[k] 13 | i2 = input2[k] 14 | if i2_size > i1_size: 15 | i1 = nn.functional.pad(input1[k], (0, i2_size-i1_size), 'constant', 0) 16 | elif i2_size < i1_size: 17 | i2 = nn.functional.pad(input2[k], (0, i1_size-i2_size), 'constant', 0) 18 | new_encoding[k] = torch.cat((i1,i2), 0) 19 | return new_encoding 20 | 21 | 22 | class QTable(QFunction): 23 | def __init__(self, default=0.0): 24 | self.qtable = defaultdict(lambda: default) 25 | 26 | def update(self, state, action, delta, visits, reward): 27 | self.qtable[(state, action)] = self.qtable[(state, action)] + delta 28 | 29 | def get_q_value(self, state, action): 30 | return self.qtable[(state, action)] 31 | 32 | import torch 33 | import torch.nn as nn 34 | from monte_carlo_tree_search.qfunction import QFunction 35 | from torch.optim import Adam 36 | from monte_carlo_tree_search.deep_agent import DeepAgent 37 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 38 | 39 | class DeepQFunction(QFunction, DeepAgent): 40 | """ A neural network to represent the Q-function. 41 | This class uses PyTorch for the neural network framework (https://pytorch.org/). 42 | """ 43 | 44 | def __init__( 45 | self, alpha=0.001, steps_update=100, cuda = torch.device('cuda:2') 46 | ) -> None: 47 | raise NotImplementedError("This class should not be used. Are you sure you are using the right QFunction?.") 48 | self.alpha = alpha 49 | self.steps_update = steps_update 50 | self.tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") 51 | self.q_network = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-uncased", 52 | num_labels = 1).to(cuda) 53 | self.optimiser = Adam(self.q_network.parameters(), lr=self.alpha) 54 | self.cuda = cuda 55 | 56 | def merge(self, state, action): 57 | # merge conversation, and LLM response together. 58 | return state.conversation + action 59 | 60 | def update(self, state, action, delta, visits, reward): 61 | optimiser = Adam(self.q_network.parameters(), lr=0.01 * (1/visits)**2) 62 | optimiser.zero_grad() # Reset gradients to zero 63 | merged_convo = self.merge(state, action) 64 | merged_convo = str(merged_convo) 65 | if len(merged_convo) > 1000: 66 | merged_convo = merged_convo[-999:] 67 | encoded_input = self.tokenizer(merged_convo, return_tensors='pt') 68 | if len(encoded_input) > 512: 69 | encoded_input = encoded_input[:512] 70 | # if len(merged_convo) > 1000: 71 | # merged_convo = merged_convo[-999:] 72 | encoded_input = self.tokenizer(merged_convo, truncation=True, max_length=512, return_tensors='pt').to(self.cuda) 73 | 74 | #print("output of network before update: ", output.logits) 75 | for x in range(self.steps_update): 76 | optimiser.zero_grad() # Reset gradients to zero 77 | output = self.q_network(**encoded_input, labels = torch.tensor(reward, dtype=torch.float).to(self.cuda)) 78 | output.loss.backward() 79 | optimiser.step() # Do a gradient descent step with the optimiser 80 | #print("output of network after update: ", output.logits) 81 | 82 | def get_q_value(self, state, action): 83 | 84 | merged_convo = self.merge(state, action) 85 | merged_convo = str(merged_convo) 86 | if len(merged_convo) > 1000: 87 | merged_convo = merged_convo[-999:] 88 | # if len(merged_convo) > 1000: 89 | # merged_convo = merged_convo[-999:] 90 | # Convert the state into a tensor 91 | encoded_input = self.tokenizer(merged_convo, truncation=True, max_length=512, return_tensors='pt').to(self.cuda) 92 | with torch.no_grad(): 93 | output = self.q_network(**encoded_input) 94 | return output.logits[0][0] 95 | 96 | def get_qs(self, state, actions): 97 | qs = [] 98 | for action in actions: 99 | merged_convo = self.merge(state, action) 100 | # if len(merged_convo) > 1000: 101 | # merged_convo = merged_convo[-999:] 102 | encoded_input = self.tokenizer(merged_convo, truncation=True, max_length=512, return_tensors='pt').to(self.cuda) 103 | with torch.no_grad(): 104 | reward_estimate = self.q_network(**encoded_input).logits[0][0].cpu() 105 | qs.append(reward_estimate) 106 | return qs 107 | 108 | def get_max_q(self, state, actions): 109 | qs = self.get_qs(state, actions) 110 | arg_max_q = np.argmax(qs) 111 | best_action = actions[arg_max_q] 112 | best_reward = qs[arg_max_q] 113 | return (best_action, best_reward) 114 | 115 | # class DeepQSemanticFunction(QFunction, DeepAgent): 116 | # """ A neural network to represent the Q-function for semantic space 117 | # This class uses PyTorch for the neural network framework (https://pytorch.org/). 118 | # """ 119 | 120 | # def __init__( 121 | # self, dim, alpha=0.001 122 | # ) -> None: 123 | # self.alpha = alpha 124 | # self.dim = dim 125 | # self.q_network = nn.Sequential( 126 | # nn.Linear(dim * 2, 128), 127 | # nn.ReLU(), 128 | # nn.Linear(128, 24), 129 | # nn.ReLU(), 130 | # nn.Linear(24, 12), 131 | # nn.ReLU(), 132 | # nn.Linear(12, 1) 133 | # ) 134 | # self.optimiser = Adam(self.q_network.parameters(), lr=self.alpha) 135 | 136 | # def merge(self, state, action): 137 | # # merge conversation, and LLM response together. 138 | # merged_convo = list(state.conversation) + list(action) 139 | # return torch.Tensor([merged_convo]) 140 | 141 | # def update(self, state, action, delta, visits, reward): 142 | # self.optimiser.lr=0.0005 * (1/visits)**2 143 | # merged_convo = self.merge(state, action) 144 | # for x in range(30): 145 | # self.optimiser.zero_grad() # Reset gradients to zero 146 | # loss_fn = nn.MSELoss() 147 | # y_pred = self.q_network(merged_convo) 148 | # loss = loss_fn(y_pred, torch.tensor([reward],requires_grad=True)) 149 | # loss.backward() 150 | # self.optimiser.step() 151 | 152 | # def get_q_value(self, state, action): 153 | # merged_convo = self.merge(state, action) 154 | # output = self.q_network(merged_convo) 155 | # return output[0][0] 156 | 157 | # def get_max_q(self, state, actions): 158 | 159 | # best_action = None 160 | # best_reward = float("-inf") 161 | # for action in actions: 162 | # merged_convo = self.merge(state, action) 163 | # reward_estimate = self.q_network(merged_convo)[0][0] 164 | # if reward_estimate > best_reward: 165 | # best_action = action 166 | # best_reward = reward_estimate 167 | # return (best_action, best_reward) 168 | 169 | 170 | class DeepQSemanticFunction(QFunction, DeepAgent): 171 | """ A neural network to represent the Q-function for semantic space 172 | This class uses PyTorch for the neural network framework (https://pytorch.org/). 173 | """ 174 | 175 | def __init__( 176 | self, dim, cuda, steps_update, alpha=0.001 177 | ) -> None: 178 | self.alpha = alpha 179 | self.dim = dim 180 | self.update_steps = steps_update 181 | self.cuda = cuda 182 | self.q_network = nn.Sequential( 183 | nn.Linear(dim * 2, dim), 184 | nn.ReLU(), 185 | nn.Linear(dim, int(dim/4)), 186 | nn.ReLU(), 187 | nn.Linear(int(dim/4), 128), 188 | nn.ReLU(), 189 | nn.Linear(128, 64), 190 | nn.ReLU(), 191 | nn.Linear(64, 1) 192 | ).to(cuda) 193 | print(f"Using {cuda} for DeepQSemanticFunction") 194 | self.reset() 195 | 196 | def reset(self): 197 | for layer in self.q_network: 198 | if isinstance(layer, nn.Linear): 199 | layer.reset_parameters() 200 | self.optimiser = Adam(self.q_network.parameters(), lr=self.alpha) 201 | self.replay_buffer = None 202 | self.past_rewards = None 203 | 204 | def merge(self, state, action): 205 | # merge conversation, and LLM response together. 206 | merged_convo = list(state.conversation) + list(action) 207 | return torch.Tensor([merged_convo]) 208 | 209 | def update_buffer(self, input, reward): 210 | if self.past_rewards is None: 211 | self.past_rewards = reward 212 | else: 213 | self.past_rewards = torch.cat((self.past_rewards, reward), 0) 214 | if self.replay_buffer is None: 215 | self.replay_buffer = input 216 | else: 217 | self.replay_buffer = torch.cat((self.replay_buffer, input), 0) 218 | 219 | def update(self, state, action, delta, visits, reward): 220 | loss_fn = nn.MSELoss() 221 | self.optimiser = Adam(self.q_network.parameters(), lr=self.alpha * (1/visits)**2) 222 | merged_convo = self.merge(state, action).to(self.cuda) 223 | reward = torch.tensor(reward,dtype=torch.float).to(self.cuda).unsqueeze(0) 224 | losses = [] 225 | for x in range(self.update_steps): 226 | self.optimiser.zero_grad() # Reset gradients to zero 227 | y_pred = self.q_network(merged_convo) 228 | loss = loss_fn(y_pred.squeeze(1), reward) 229 | losses.append(loss.item()) 230 | loss.backward() 231 | self.optimiser.step() 232 | self.update_buffer(merged_convo, reward) 233 | print(f"loss for regular q update {losses[-1:0:-10][::-1]}") 234 | 235 | if self.replay_buffer is None: 236 | return 237 | 238 | # print("past reward: ", self.past_rewards) 239 | # print("y ", y_pred) 240 | losses = [] 241 | for x in range(self.update_steps): 242 | self.optimiser.zero_grad() # Reset gradients to zero 243 | y_pred = self.q_network(self.replay_buffer) 244 | loss = loss_fn(y_pred.squeeze(), self.past_rewards.squeeze()) 245 | losses.append(loss.item()) 246 | loss.backward() 247 | self.optimiser.step() 248 | 249 | print(f"loss for replay buffer q update {losses[-1:0:-10][::-1]}") 250 | 251 | # for x in range(self.steps_update): 252 | # optimiser.zero_grad() # Reset gradients to zero 253 | # output = self.q_network(**encoded_input, labels = torch.tensor(reward, dtype=torch.float).to(self.cuda)) 254 | # if output.loss == torch.tensor(float('nan')): # if loss becomes nan, reduce LR 255 | # optimiser = Adam(self.q_network.parameters(), lr= 0.1 * self.alpha * (1/visits)**2) 256 | # continue 257 | # output.loss.backward() 258 | # optimiser.step() # Do a gradient descent step with the optimiser 259 | # print("loss in standard update: ", output.loss) 260 | 261 | def get_q_value(self, state, action): 262 | merged_convo = self.merge(state, action).to(self.cuda) 263 | with torch.no_grad(): 264 | output = self.q_network(merged_convo) 265 | return output[0][0] 266 | 267 | def get_qs(self, state, actions): 268 | qs = [] 269 | for action in actions: 270 | merged_convo = self.merge(state, action).to(self.cuda) 271 | with torch.no_grad(): 272 | reward_estimate = self.q_network(merged_convo)[0][0].cpu() 273 | qs.append(reward_estimate) 274 | print("q values estimate for actions are: ", qs) 275 | return qs 276 | 277 | def get_max_q(self, state, actions): 278 | qs = self.get_qs(state, actions) 279 | arg_max_q = np.argmax(qs) 280 | best_action = actions[arg_max_q] 281 | best_reward = qs[arg_max_q] 282 | return (best_action, best_reward) 283 | 284 | 285 | class ReplayBufferDeepQFunction(QFunction, DeepAgent): 286 | """ A neural network to represent the Q-function. 287 | This class uses PyTorch for the neural network framework (https://pytorch.org/). 288 | """ 289 | 290 | def __init__( 291 | self, alpha=0.1, steps_update=100, cuda = torch.device('cuda:2') 292 | ) -> None: 293 | self.alpha = alpha 294 | self.steps_update = steps_update 295 | self.model_name = "google-bert/bert-base-uncased" 296 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 297 | self.cuda = cuda 298 | self.reset() 299 | 300 | def reset(self): 301 | self.q_network = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels = 1).to(self.cuda) 302 | self.optimiser = Adam(self.q_network.parameters(), lr=self.alpha) 303 | self.replay_buffer = None 304 | self.past_rewards = None 305 | 306 | def merge(self, state, action): 307 | # merge conversation, and LLM response together. 308 | return state.conversation + action 309 | 310 | def update_buffer(self, input, reward): 311 | if self.past_rewards is None: 312 | self.past_rewards = reward 313 | else: 314 | self.past_rewards = torch.cat((self.past_rewards, reward), 0) 315 | if self.replay_buffer is None: 316 | self.replay_buffer = input 317 | else: 318 | self.replay_buffer = combine_encoded_inputs(self.replay_buffer, input) 319 | 320 | def update(self, state, action, delta, visits, reward): 321 | merged_convo = self.merge(state, action) 322 | merged_convo = str(merged_convo) 323 | 324 | # update replay buffer 325 | encoded_input = self.tokenizer(merged_convo, truncation=True, max_length=512, padding=True, return_tensors='pt').to(self.cuda) 326 | reward = torch.tensor(reward,dtype=torch.float).to(self.cuda).unsqueeze(0) 327 | self.update_buffer(encoded_input, reward) 328 | 329 | self.q_network.train() 330 | # update based on this specific experience 331 | start_time = time.time() 332 | self.optimiser.param_groups[0]['lr'] *= self.alpha * (1/visits)**2 333 | for x in range(self.steps_update): 334 | self.optimiser.zero_grad() # Reset gradients to zero 335 | output = self.q_network(**encoded_input, labels = reward) 336 | if torch.isnan(output.loss): # if loss becomes nan, reduce LR 337 | self.optimiser.param_groups[0]['lr'] *= 0.1 338 | continue 339 | output.loss.backward() 340 | self.optimiser.step() # Do a gradient descent step with the optimiser 341 | print("loss in standard update: ", output.loss) 342 | print("time taken for update Q", time.time()-start_time) 343 | start_time = time.time() 344 | 345 | # update based on replay buffer 346 | losses = [] 347 | self.optimiser.param_groups[0]['lr'] = 0.3* self.alpha * (1/visits)**2 348 | for x in range(self.steps_update): 349 | self.optimiser.zero_grad() # Reset gradients to zero 350 | output = self.q_network(**self.replay_buffer, labels = self.past_rewards) 351 | 352 | if torch.isnan(output.loss): # if loss becomes nan, reduce LR 353 | self.optimiser.param_groups[0]['lr'] *= 0.1 354 | continue 355 | output.loss.backward() 356 | losses.append(output.loss.item()) 357 | self.optimiser.step() # Do a gradient descent step with the optimiser 358 | print(f"loss for replay buffer q update {losses[-1:0:-10][::-1]}") 359 | print("time taken for update Q with replay buffer: ", time.time()-start_time) 360 | 361 | # def update_with_replay_buffer(self): 362 | # optimiser = Adam(self.q_network.parameters(), lr=self.alpha) 363 | 364 | # # update based on replay buffer 365 | # for x in range(self.steps_update): 366 | # optimiser.zero_grad() # Reset gradients to zero 367 | # output = self.q_network(**self.replay_buffer, labels = torch.tensor(self.past_rewards, dtype=torch.float).to(self.cuda)) 368 | # output.loss.backward() 369 | # print(output.loss) 370 | # optimiser.step() # Do a gradient descent step with the optimiser 371 | 372 | def get_q_value(self, state, action): 373 | print("getting q value of merged convo:") 374 | 375 | merged_convo = self.merge(state, action) 376 | merged_convo = str(merged_convo) 377 | print(merged_convo) 378 | encoded_input = self.tokenizer(merged_convo, truncation=True, max_length=512, return_tensors='pt').to(self.cuda) 379 | #print(encoded_input) 380 | self.q_network.eval() 381 | with torch.no_grad(): 382 | output = self.q_network(**encoded_input) 383 | q_value = output.logits[0][0].cpu() 384 | print(f"Q value is: {q_value:.8f}") 385 | return q_value 386 | 387 | def get_qs(self, state, actions): 388 | qs = [] 389 | for action in actions: 390 | reward_estimate = self.get_q_value(state, action) 391 | qs.append(reward_estimate) 392 | return qs 393 | 394 | def get_max_q(self, state, actions): 395 | qs = self.get_qs(state, actions) 396 | arg_max_q = np.argmax(qs) 397 | best_action = actions[arg_max_q] 398 | best_reward = qs[arg_max_q] 399 | return (best_action, best_reward) -------------------------------------------------------------------------------- /monte_carlo_tree_search/policy_agent.py: -------------------------------------------------------------------------------- 1 | from monte_carlo_tree_search.qtable import QTable, DeepQFunction 2 | from monte_carlo_tree_search.single_agent_mcts import SingleAgentMCTS 3 | from monte_carlo_tree_search.conversation_env import conversation_environment, conversation_state 4 | from monte_carlo_tree_search.semantic_conversation_env import semantic_conversation_environment, conversation_semantic_state 5 | from monte_carlo_tree_search.ucb import UpperConfidenceBounds 6 | 7 | from agent.Conversation import Conversation 8 | from agent.Model import Model 9 | from reward.Base_Reward import Base_Reward 10 | from reward.Llama_2_Guard_Reward import Llama_2_Guard_Reward 11 | from reward.Embedding_Length_Reward import Embedding_Length_Reward 12 | 13 | from transition_models.transition_model import TransitionModelMOE 14 | import random 15 | import copy 16 | import numpy as np 17 | from scipy import stats 18 | import torch 19 | from typing import List 20 | 21 | from abc import abstractmethod 22 | 23 | from time import time 24 | 25 | from tqdm import tqdm 26 | import itertools 27 | 28 | class LearntAgent(): 29 | def __init__(self) -> None: 30 | pass 31 | @abstractmethod 32 | def generate_action(self, state : conversation_state, results = {}, seed=None, **kwargs): 33 | pass 34 | def seed(self, seed): 35 | if seed is not None: 36 | torch.manual_seed(seed) 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | 40 | # an agent that just greedily returns the best action during runtime. Infer next response by human and choose greedily. 41 | class RandomAgent(LearntAgent): 42 | 43 | def __init__(self, action_generator : Model) -> None: 44 | self.action_generator = action_generator 45 | 46 | def generate_action(self, state : conversation_state, results = {}, seed=None, **kwargs): 47 | self.seed(seed) 48 | possible_actions = self.action_generator.sample_actions(state.conversation) 49 | print("possible actions random agent proposed: ", possible_actions) 50 | 51 | start_time = time() 52 | self.seed(seed) 53 | best_action_index = random.randint(0, len(possible_actions)-1) 54 | best_action = possible_actions[best_action_index] 55 | 56 | results["possible_actions"] = possible_actions 57 | results["selected_action_index"] = best_action_index 58 | print(f"action selected by random agent: {best_action}\ttime taken: {time()-start_time}") 59 | return best_action 60 | 61 | # an agent that just greedily returns the best action during runtime. Infer next response by human and choose greedily. 62 | class GreedyAgent(LearntAgent): 63 | 64 | def __init__(self, reward_calculator, action_generator : Model) -> None: 65 | self.reward_calculator = reward_calculator 66 | self.action_generator = action_generator 67 | 68 | def generate_action(self, state : conversation_state, results = {}, seed=None, **kwargs): 69 | self.seed(seed) 70 | possible_actions = self.action_generator.sample_actions(state.conversation) # maybe add an argument to choose number of actions 71 | start_time = time() 72 | best_action = self.reward_calculator.select(state, possible_actions, results = results) 73 | print(f"action selected by greedy agent: {best_action}\ttime taken: {time()-start_time}") 74 | 75 | return best_action 76 | # greedy reward functions to be used in GreedyAgent 77 | def len_reward_function(human_response): 78 | return len(human_response) 79 | 80 | class zero_step_greedy_reward_generator(): 81 | def __init__(self, human_agent : Model, reward_function) -> None: 82 | self.human = human_agent 83 | self.reward_function = reward_function 84 | 85 | # greedy reward: infer multiple human responses. take average reward from them. 86 | def select(self, state : conversation_state, possible_actions, results = {}): 87 | print("selecting greedy action...") 88 | convo = state.conversation 89 | print("current state: ", convo) 90 | action_reward = [] 91 | for action in possible_actions: 92 | greedy_reward = self.reward_function(convo, action, None) 93 | print("one step greedy lookahead reward: ", greedy_reward) 94 | action_reward.append(greedy_reward) 95 | best_action_idx = np.argmax(action_reward) 96 | best_action = possible_actions[best_action_idx] 97 | # print("selected greedy reward: ", best_action) 98 | 99 | results["possible_actions"] = possible_actions 100 | results["possible_actions_reward"] = action_reward 101 | results["selected_action_index"] = best_action_idx 102 | return best_action 103 | 104 | class one_step_greedy_reward_generator(): 105 | def __init__(self, human_agent : Model, reward_function) -> None: 106 | self.human = human_agent 107 | self.reward_function = reward_function 108 | 109 | # greedy reward: infer multiple human responses. take average reward from them. 110 | def select(self, state : conversation_state, possible_actions, results = {}): 111 | print("selecting greedy action...") 112 | convo = state.conversation 113 | print("current state: ", convo) 114 | action_reward = [] 115 | for action in possible_actions: 116 | print("candidate action: ", action) 117 | human_responses = self.human.sample_actions(convo + action) 118 | reward_to_be_averaged = [] 119 | for response in human_responses: 120 | print("one step greedy lookahead reward: ", self.reward_function(convo, action, response)) 121 | reward_to_be_averaged.append(self.reward_function(convo, action, response)) 122 | print("mean greedy one step reward: ", np.mean(reward_to_be_averaged)) 123 | action_reward.append(np.mean(reward_to_be_averaged)) 124 | best_action_idx = np.argmax(action_reward) 125 | best_action = possible_actions[best_action_idx] 126 | print("selected greedy reward: ", best_action) 127 | results["possible_actions"] = possible_actions 128 | results["possible_actions_reward"] = action_reward 129 | results["selected_action_index"] = best_action_idx 130 | return best_action 131 | 132 | # An agent with a pretrained Q function used to find best action during runtime. No searching is done. 133 | class OfflineAgent(LearntAgent): 134 | 135 | def __init__(self, qfunction : DeepQFunction, llm_agent : Model) -> None: 136 | self.qfunction = qfunction 137 | self.llm_agent = llm_agent 138 | 139 | def generate_action(self, state : conversation_state, results = {}, seed=None, **kwargs): 140 | self.seed(seed) 141 | possible_actions = self.llm_agent.sample_actions(state.conversation) # maybe add an argument to choose number of actions 142 | start_time = time() 143 | qs = self.qfunction.get_qs(state, possible_actions) 144 | best_action_index = np.argmax(qs) 145 | best_action = possible_actions[best_action_index] 146 | 147 | results["possible_actions"] = possible_actions 148 | results["possible_actions_reward"] = qs 149 | results["selected_action_index"] = best_action_index 150 | print(f"action selected by offline agent: {best_action}\ttime taken: {time()-start_time}") 151 | return best_action 152 | 153 | # An agent which performs MCTS during runtime. Takes in a Q functon during initialization (possibly pretrained) 154 | class OnlineAgent(LearntAgent): 155 | 156 | def __init__(self, qfunction : DeepQFunction, search_depth, mcts_time_limit, llm_agent : Model, human_simulator, reward_function_for_mcts, search_space="response_space", reward_decay=1.0, terminating_heuristic_q_function="get_last_action", transition_model=None, embedding_model=None) -> None: 157 | self.search_depth = search_depth 158 | self.mcts_time_limit = mcts_time_limit 159 | self.llm_agent = llm_agent 160 | self.human_simulator = human_simulator 161 | self.qfunction = qfunction 162 | self.original_qfunction = copy.deepcopy(qfunction) 163 | self.terminating_heuristic_q_function = terminating_heuristic_q_function 164 | self.reward_function_for_mcts = reward_function_for_mcts 165 | self.search_space = search_space 166 | self.reward_decay = reward_decay 167 | self.transition_model = transition_model 168 | self.embedding_model = embedding_model 169 | if isinstance(self.reward_function_for_mcts, Llama_2_Guard_Reward): 170 | self.reward_from_embedding = self.reward_function_for_mcts.get_safe_prob_from_embedding 171 | elif isinstance(self.reward_function_for_mcts, Embedding_Length_Reward): 172 | self.reward_from_embedding = lambda x: self.reward_function_for_mcts.model(x).detach().cpu() 173 | 174 | def generate_action(self, state : conversation_state, results = {}, seed=None, **kwargs): 175 | evaluation_conversation_env = conversation_environment(self.human_simulator, self.llm_agent, state.conversation, max_depth=self.search_depth, reward_function=self.reward_function_for_mcts) 176 | print("generating action in realtime...") 177 | if self.search_space=="response_space": 178 | conversation_env = conversation_environment(self.human_simulator, self.llm_agent, state.conversation, max_depth=self.search_depth, reward_function=self.reward_function_for_mcts) 179 | elif self.search_space=="semantic_space": 180 | conversation_env = semantic_conversation_environment(embedding_model=self.embedding_model, transition_model=self.transition_model, initial_state=state.conversation, max_depth=self.search_depth, reward_function=self.reward_function_for_mcts) 181 | conversation_env.initial_actions_asked = True 182 | # get initial action, change to semantics, and store it! 183 | self.seed(seed) 184 | possible_starting_actions = evaluation_conversation_env.get_actions(state) # generate initial actions 185 | start_time = time() 186 | starting_convo_semantics = self.embedding_model.embed(state.conversation).cpu().detach().numpy() 187 | action_semantics = [] 188 | action_rewards = [] 189 | for action in possible_starting_actions: 190 | concatenated_convo = state.conversation + action # conversation of initial state + action, in string form 191 | output = self.embedding_model.embed(concatenated_convo) # embedding form 192 | reward = self.reward_from_embedding(output) 193 | action_rewards.append(reward) 194 | action_semantic = tuple(output.cpu().detach().numpy()) # convert to tuple 195 | action_semantic = tuple([x1-x2 for x1,x2 in zip(list(action_semantic),list(starting_convo_semantics))]) # get the difference (action in semantic form) 196 | action_semantics.append(action_semantic) 197 | 198 | results["greedy_rewards"] = action_rewards 199 | results["greedy_action_index"] = np.argmax(action_rewards) 200 | 201 | 202 | conversation_env.state_to_action_map[tuple(starting_convo_semantics)] = action_semantics # store the initial action, so later we can use it. 203 | 204 | 205 | print("performing MCTS search...") 206 | self.seed(seed) 207 | mcts = SingleAgentMCTS(conversation_env, self.qfunction, UpperConfidenceBounds(), terminating_heuristic_q_function=self.terminating_heuristic_q_function) 208 | mcts.mcts(timeout=self.mcts_time_limit, seed=seed) 209 | self.qfunction = mcts.qfunction # qfunction learnt after performing mcts 210 | 211 | # get best action from learnt q function after mcts 212 | if self.search_space=="response_space": 213 | print("getting best action from Q function...") 214 | print("current state: ", state) 215 | self.seed(seed) 216 | possible_actions = mcts.initial_actions 217 | start_time = time() 218 | print("proposed actions: \n", possible_actions) 219 | qs = self.qfunction.get_qs(state, possible_actions) 220 | best_action_index = np.argmax(qs) 221 | best_action = possible_actions[best_action_index] 222 | 223 | # if semantic space used, some semantic projection is needed 224 | elif self.search_space=="semantic_space": 225 | print("getting best action from Q function...") 226 | print("current state: ", state) 227 | # get conversation semantics 228 | truncated_state = state.conversation # actual convo 229 | 230 | output = self.embedding_model.embed(truncated_state) # embedding 231 | 232 | conversation_semantics = tuple(output.cpu().detach().numpy()) 233 | semantic_state = copy.deepcopy(state) 234 | semantic_state.conversation = conversation_semantics 235 | 236 | # get action semantics 237 | action_semantics = [] 238 | # get real actions 239 | print("getting actions during evaluation..") 240 | self.seed(seed) 241 | possible_actions = evaluation_conversation_env.get_actions(state) 242 | print("possible actions generated: ", possible_actions) 243 | action_rewards = [] 244 | for action in possible_actions: 245 | concatenated_convo = truncated_state + action # conversation 246 | output = self.embedding_model.embed(concatenated_convo) # embedding 247 | # output is the semantics after combining action with state. 248 | # we deduct from the output the state semantics to obtain a directional vector which 249 | # represents the action semantics 250 | action_semantic = tuple(output.cpu().detach().numpy()) 251 | action_semantic = tuple([x1-x2 for x1,x2 in zip(list(action_semantic),list(conversation_semantics))]) 252 | action_semantics.append(action_semantic) 253 | 254 | action_rewards.append(self.reward_from_embedding(output)) 255 | 256 | # best_action_index = np.argmax(action_rewards) # greedy 257 | 258 | #filter off the worst 2 actions if there are more than 2 259 | # if len(action_rewards) > 2: 260 | # smallest_indices = sorted(range(len(action_rewards)), key=lambda i: action_rewards[i])[:(2)] 261 | # for index in sorted(smallest_indices, reverse=True): 262 | # print("deleting the following index: ", index) 263 | # del action_semantics[index] 264 | # del action_rewards[index] 265 | 266 | # use Q function to get Q value 267 | qs = self.qfunction.get_qs(semantic_state, action_semantics) 268 | best_action_index = np.argmax(qs) 269 | 270 | best_action = possible_actions[best_action_index] 271 | if results["greedy_action_index"] != best_action_index: 272 | print("different greedy action selected") 273 | # print(f"Different action selected. greedy selected: {results['greedy_action_index']} (q={qs[results['greedy_action_index']]:.3f}), actual selected: {best_action_index} (q={qs[best_action_index]:.3f})", 274 | # possible_actions[results['greedy_action_index']], possible_actions[best_action_index] 275 | # ) 276 | 277 | results["possible_actions"] = possible_actions 278 | results["possible_actions_reward"] = qs 279 | results["selected_action_index"] = best_action_index 280 | print(f"action selected by online agent: {best_action}\ttime taken: {time()-start_time}") 281 | return best_action 282 | 283 | # util function for resetting q function 284 | def reset(self): 285 | self.qfunction = copy.deepcopy(self.original_qfunction) 286 | 287 | class ExhastiveOnlineAgent(LearntAgent): 288 | def __init__(self, search_depth, mcts_time_limit, llm_agent : Model, human_simulator : Model, reward_function_for_mcts : Base_Reward, search_space="response_space", reward_decay=1.0, transition_model : TransitionModelMOE =None, embedding_model=None, **kwargs) -> None: 289 | self.search_depth = search_depth 290 | self.mcts_time_limit = mcts_time_limit 291 | self.llm_agent = llm_agent 292 | self.reward_function_for_mcts = reward_function_for_mcts 293 | self.search_space = search_space 294 | self.reward_decay = reward_decay 295 | self.transition_model = transition_model 296 | self.embedding_model = embedding_model 297 | if isinstance(self.reward_function_for_mcts, Llama_2_Guard_Reward): 298 | self.reward_from_embedding = self.reward_function_for_mcts.get_safe_prob_from_embedding 299 | elif isinstance(self.reward_function_for_mcts, Embedding_Length_Reward): 300 | self.reward_from_embedding = lambda x: self.reward_function_for_mcts.model(x).detach().cpu() 301 | 302 | def generate_action(self, state : conversation_state, results = {}, seed=None, **kwargs): 303 | assert self.search_space=="semantic_space", "Only semantic space is supported for exhastive search" 304 | print("performing exhastive search...") 305 | 306 | self.seed(seed) 307 | possible_actions = self.llm_agent.sample_actions(state.conversation) 308 | start_time = time() 309 | 310 | llm_actions = torch.stack([self.embedding_model.embed(state.conversation + i) for i in possible_actions]) 311 | rewards = [self.reward_from_embedding(llm_actions)] 312 | 313 | for depth in range(1, 1+self.search_depth): 314 | if depth % 2: 315 | start_time = time() 316 | human_actions = self.transition_model.batch_sample_human(llm_actions) 317 | human_actions_time = time() - start_time 318 | start_time = time() 319 | rewards.append(self.reward_from_embedding(human_actions)) 320 | reward_time = time() - start_time 321 | print(f"Current depth: {depth} considering {torch.prod(torch.tensor(human_actions.shape[:-1]))} actions. Human action time: {human_actions_time:.3f}, reward time: {reward_time:.3f}" ) 322 | else: 323 | start_time = time() 324 | llm_actions = self.transition_model.batch_sample_human(human_actions) 325 | llm_actions_time = time() - start_time 326 | start_time = time() 327 | rewards.append(self.reward_from_embedding(llm_actions) - rewards[-1]) 328 | reward_time = time() - start_time 329 | print(f"Current depth: {depth} considering {torch.prod(torch.tensor(llm_actions.shape[:-1]))} actions. llm action time: {llm_actions_time:.3f}, reward time: {reward_time:.3f}") 330 | 331 | reward = rewards[-1] 332 | for i in range(len(rewards)-2, -1, -1): 333 | if i % 2: 334 | reward = reward.max(dim=0).values # max over llm responses 335 | else: 336 | reward = reward.mean(dim=0) * self.reward_decay # mean over human responses 337 | reward = reward + rewards[i] 338 | # reward = (reward * self.reward_decay + rewards[i]).max(dim=0).values.mean(dim=0) 339 | # reward = reward.squeeze() 340 | 341 | best_action_index = reward.argmax() 342 | best_action, best_reward = possible_actions[best_action_index], reward[best_action_index] 343 | results["possible_actions"] = possible_actions 344 | results["possible_actions_reward"] = reward 345 | results["selected_action_index"] = best_action_index 346 | results["greedy_rewards"] = rewards[0] 347 | results["greedy_action_index"] = rewards[0].argmax() 348 | 349 | if results["greedy_action_index"] != best_action_index: 350 | print(f'Different action selected. greedy selected: {results["greedy_action_index"]} (q={rewards[0][results["greedy_action_index"]].squeeze():.3f}), actual selected: {best_action_index} (q={reward[best_action_index].squeeze():.3f})\n{possible_actions[best_action_index]}\n\n{possible_actions[results["greedy_action_index"]]}' 351 | ) 352 | 353 | print(f"action selected by exhaustive agent: \"{best_action}\" with cummulative reward {best_reward}\ttime taken: {time()-start_time}") 354 | return best_action 355 | 356 | def evaluate_agent(agent : LearntAgent, env : conversation_environment, starting_state : conversation_state, number_replies, results = {}, seed=None, **kwargs): 357 | 358 | cumulative_reward = 0.0 359 | time_taken_to_generate_action = [] 360 | time_taken_in_simulation = [] 361 | reward_for_one_step = [] 362 | results["convo_replies"] = [] 363 | for r in range(number_replies): 364 | 365 | # get best action based on starting_state 366 | if hasattr(agent, 'qfunction'): 367 | agent.qfunction.reset() 368 | 369 | start_time = time() 370 | curr_result = {} 371 | 372 | curr_seed = hash((r, seed)) % (2**32) 373 | action = agent.generate_action(starting_state, results = curr_result, seed=curr_seed, **kwargs) 374 | time_taken = time()-start_time 375 | time_taken_to_generate_action.append(time_taken) 376 | print("Time taken by agent to generate action", time_taken) 377 | start_time = time() 378 | # go to next state 379 | next_state, reward = env.execute_in_simulation(starting_state, action, results = curr_result, seed=curr_seed, **kwargs) 380 | time_taken = time()-start_time 381 | time_taken_in_simulation.append(time_taken) 382 | print("Time taken in simulation", time_taken) 383 | print("eval human response: ", next_state.response) 384 | print("reward for one step of evaluation: ", reward) 385 | starting_state = next_state 386 | cumulative_reward += reward 387 | 388 | reward_for_one_step.append(reward) 389 | results["convo_replies"].append(curr_result) 390 | 391 | # do one more action generation and do not need to generate human response 392 | # curr_seed = hash((-1, seed)) % (2**32) 393 | # action = agent.generate_action(starting_state, results = curr_result, seed=curr_seed) 394 | # last_step_reward = env.get_reward(starting_state.conversation, action, None) # None for human response. The reward function will handle this case to only get reward for action. 395 | # final_convo_including_last_actions = starting_state.conversation + action 396 | # starting_state = conversation_state(action, final_convo_including_last_actions) 397 | # starting_state.depth = starting_state.depth + 1 398 | # print("reward for last action step: ", last_step_reward) 399 | # cumulative_reward += last_step_reward 400 | # reward_for_one_step.append(reward) 401 | 402 | 403 | print("entire evaluation convo: ", starting_state.conversation) 404 | results["time_taken_to_generate_action"] = time_taken_to_generate_action 405 | results["time_taken_in_simulation"] = time_taken_in_simulation 406 | results["reward_for_one_step"] = reward_for_one_step 407 | results["entire_convo"] = starting_state.conversation.full_convo 408 | return cumulative_reward, starting_state 409 | 410 | import lz4.frame 411 | import pickle 412 | # evaluate an agent with the mdp. 413 | def run_evaluations(agent, type, env : conversation_environment, evaluation_starters : List[str], number_replies : int, trials : int, context_list = itertools.repeat(None), human_descriptions = itertools.repeat(None), results = {}, index = itertools.count(), seed = 42, output_file="tmp"): 414 | result_row = [] 415 | convo_generated = {} 416 | results["time_taken_for_trial"] = {} 417 | results["all_rewards_from_trials"] = {} 418 | results["reward_mean"] = {} 419 | results["reward_std"] = {} 420 | results["trial_results"] = {} 421 | for i, evaluation_starter, context, human_description in zip(index, tqdm(evaluation_starters), context_list, human_descriptions): 422 | 423 | initial_state = conversation_state((evaluation_starter), Conversation(evaluation_starter)) 424 | initial_state.depth = 1 425 | 426 | # repeated trials 427 | rewards = [] 428 | convo_generated_from_trials = [] 429 | time_taken_for_trial = [] 430 | trial_results = [] 431 | for x in range(trials): 432 | 433 | if hasattr(agent, 'qfunction'): 434 | agent.qfunction.reset() 435 | print("trial: ", x, " of evaluation for agent of type: ", type) 436 | start_time = time() 437 | curr_trial_result = {} 438 | 439 | curr_seed = hash((seed, evaluation_starter, x)) % (2**32) 440 | 441 | cumulative_reward, entire_convo = evaluate_agent(agent, env, initial_state, number_replies, results = curr_trial_result, context = context, human_description = human_description, seed = curr_seed) 442 | time_taken = time()-start_time 443 | print("Time taken for current trial", time_taken) 444 | convo_generated_from_trials.append(entire_convo.conversation) 445 | print("cumulative reward for this trial: ", cumulative_reward) 446 | 447 | time_taken_for_trial.append(time_taken) 448 | rewards.append(cumulative_reward) 449 | trial_results.append(curr_trial_result) 450 | 451 | reward_mean = np.mean(rewards) 452 | reward_std = stats.sem(rewards) 453 | results["time_taken_for_trial"][i] = time_taken_for_trial 454 | results["all_rewards_from_trials"][i] = rewards 455 | results["reward_mean"][i] = reward_mean 456 | results["reward_std"][i] = reward_std 457 | results["trial_results"][i] = trial_results 458 | print(evaluation_starter) 459 | print("all rewards from trials: ", rewards) 460 | print("mean: ", reward_mean) 461 | print("std error: ", reward_std) 462 | convo_generated[evaluation_starter] = convo_generated_from_trials 463 | result_row.append(((np.mean(rewards)), (stats.sem(rewards)))) 464 | 465 | if i % 20 == 19: 466 | with lz4.frame.open(output_file+'_tmp.pkl', 'wb') as f: 467 | pickle.dump(results, f) 468 | 469 | return result_row, convo_generated 470 | 471 | def run_evaluations_singular(agent : LearntAgent, type, env : conversation_environment, evaluation_starter : str, context_list = itertools.repeat(None), human_descriptions = itertools.repeat(None), results = {}, index = itertools.count(), seed = 42, output_file="tmp"): 472 | initial_state = conversation_state((evaluation_starter), Conversation(evaluation_starter)) 473 | initial_state.depth = 1 474 | best_response, possible_actions, rewards = evaluate_agent_singular(agent, env, initial_state) 475 | return best_response, possible_actions, rewards 476 | 477 | 478 | 479 | def evaluate_agent_singular(agent : LearntAgent, env : conversation_environment, starting_state : conversation_state, seed=None, **kwargs): 480 | 481 | # get best action based on starting_state 482 | if hasattr(agent, 'qfunction'): 483 | agent.qfunction.reset() 484 | 485 | start_time = time() 486 | curr_result = {} 487 | 488 | curr_seed = hash((seed, seed)) % (2**32) 489 | action = agent.generate_action(starting_state, results = curr_result, seed=curr_seed, **kwargs) 490 | 491 | return action, curr_result["possible_actions"], curr_result["possible_actions_reward"] --------------------------------------------------------------------------------