├── src ├── __init__.py ├── transitions.py ├── logger.py ├── nodes_to_csv.py ├── utils.py ├── structured_outputs.py ├── args.py ├── dataset.py ├── log_utils.py ├── deduplication.py ├── mcts_utils.py ├── mcts_viz.html ├── agents.py ├── mcts.py ├── run.py └── beliefs.py ├── artifacts ├── autods_logo.png └── autodiscovery_logo.png ├── .gitignore ├── environment.yml └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /artifacts/autods_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/autodiscovery/HEAD/artifacts/autods_logo.png -------------------------------------------------------------------------------- /artifacts/autodiscovery_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/autodiscovery/HEAD/artifacts/autodiscovery_logo.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .streamlit/ 3 | annotation_state/ 4 | .cache/ 5 | coding/ 6 | .venv/ 7 | logs/ 8 | .idea/ 9 | answer_key_real.csv 10 | .DS_Store 11 | *.err 12 | *.out 13 | debug/ 14 | work/ 15 | output/ 16 | outputs/ 17 | results/ 18 | beam_search/ 19 | discoverybench/ 20 | blade/ 21 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: autods 2 | dependencies: 3 | - python=3.11 4 | - streamlit=1.37.1 5 | - pandas=1.5.3 6 | - matplotlib=3.10 7 | - pip 8 | - pip: 9 | - IPython==8.26.0 10 | - pyautogen[openai]==0.8 11 | - scikit-learn 12 | - scipy 13 | - statsmodels 14 | - h5py 15 | - boto3 16 | 
-------------------------------------------------------------------------------- /src/transitions.py: -------------------------------------------------------------------------------- 1 | from autogen import GroupChat, Agent 2 | from typing import Optional 3 | import json 4 | 5 | 6 | class SpeakerSelector: 7 | def __init__(self): 8 | self.code_failure_count = 0 9 | self.experiment_revision_count = 0 10 | 11 | def select_next_speaker(self, last_speaker: Agent, groupchat: GroupChat) -> Optional[Agent]: 12 | """Define a customized speaker selection function for the data exploration workflow. 13 | 14 | Args: 15 | last_speaker: The previous speaker in the conversation 16 | groupchat: The GroupChat instance containing conversation history 17 | 18 | Returns: 19 | The next agent to speak or None to end the conversation 20 | """ 21 | messages = groupchat.messages 22 | 23 | if last_speaker.name == "user_proxy": 24 | return groupchat.agent_by_name("experiment_programmer") 25 | 26 | elif last_speaker.name == "experiment_programmer": 27 | return groupchat.agent_by_name("code_executor") 28 | 29 | elif last_speaker.name == "code_executor": 30 | return groupchat.agent_by_name("experiment_code_analyst") 31 | 32 | elif last_speaker.name == "experiment_code_analyst": 33 | # Check if experiment failed based on structured response 34 | content = messages[-1].get("content", "") 35 | try: 36 | response = json.loads(content) 37 | except json.JSONDecodeError: 38 | # If JSON parsing fails, treat it as an error 39 | response = {"success": False, "analysis": "Error parsing response"} 40 | if not response.get("success", False) and self.code_failure_count < 6: 41 | self.code_failure_count += 1 42 | return groupchat.agent_by_name("experiment_programmer") 43 | else: 44 | self.code_failure_count = 0 45 | return groupchat.agent_by_name("experiment_reviewer") 46 | 47 | elif last_speaker.name == "experiment_reviewer": 48 | content = messages[-1].get("content", "") 49 | try: 50 | response = 
json.loads(content) 51 | except json.JSONDecodeError: 52 | response = {"success": False, "feedback": "Error parsing reviewer response"} 53 | if not response.get("success", True) and self.experiment_revision_count < 1: 54 | self.experiment_revision_count += 1 55 | return groupchat.agent_by_name("experiment_reviser") 56 | else: 57 | self.experiment_revision_count = 0 58 | return groupchat.agent_by_name("experiment_generator") 59 | 60 | if last_speaker.name == "experiment_reviser": 61 | return groupchat.agent_by_name("experiment_programmer") 62 | 63 | elif last_speaker.name == "experiment_generator": 64 | return None 65 | 66 | return None 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![autodiscovery_logo.png](artifacts/autodiscovery_logo.png) 2 | # Open-ended Scientific Discovery via Bayesian Surprise 3 | 4 | > Link to our NeurIPS 2025 paper: [AutoDiscovery: Open-ended Scientific Discovery via Bayesian Surprise](https://openreview.net/pdf?id=kJqTkj2HhF) 5 | 6 | ## Installation 7 | 8 | Create the environment with: 9 | 10 | ```sh 11 | conda env create -f environment.yml 12 | conda activate autods 13 | ``` 14 | 15 | Set environment variables: 16 | 17 | ```sh 18 | # (for Linux/MacOS/Bash/Cygwin) 19 | export PYTHONPATH=$(pwd):$PYTHONPATH; 20 | 21 | # (for Windows CMD) 22 | set PYTHONPATH=%cd%;%PYTHONPATH% 23 | 24 | # (if OPENAI_API_KEY is not already set) 25 | export OPENAI_API_KEY= 26 | ``` 27 | 28 | ## Datasets 29 | 30 | ### DiscoveryBench 31 | 32 | ```sh 33 | git clone https://github.com/allenai/discoverybench.git temp_db 34 | cp -r temp_db/discoverybench discoverybench 35 | rm -rf temp_db 36 | ``` 37 | 38 | ### Blade 39 | 40 | ```sh 41 | git clone https://github.com/behavioral-data/BLADE.git temp_db 42 | cp -r temp_db/blade_bench/datasets blade 43 | rm -rf temp_db 44 | ``` 45 | 46 | ### BYO-Datasets! 
47 | You can also use your own datasets. To do this, pass in (via `--dataset_metadata`) a dataset metadata JSON file that lists the paths of your datasets (relative to the metadata file) along with natural-language descriptions of their columns. See the metadata files in the `discoverybench` directory downloaded above for examples. 48 | 49 | ## Run AutoDS (MCTS-based hypothesis search and verification) 50 | 51 | For example, to explore the DiscoveryBench NLS SES dataset, you can use the following command: 52 | 53 | ```sh 54 | python src/run.py \ 55 | --work_dir="work" \ 56 | --out_dir="outputs" \ 57 | --dataset_metadata="discoverybench/real/test/nls_ses/metadata_0.json" \ 58 | --n_experiments=16 \ 59 | --model="gpt-4o" \ 60 | --belief_model="gpt-4o" 61 | ``` 62 | 63 | To resume a previous exploration, use the `--continue_from_dir` flag to specify the directory containing the previous 64 | exploration logs. The script will then continue from where it left off, reusing the MCTS nodes it had generated 65 | so far. 66 | 67 | ## ✍️ Get in touch! 68 | 69 | Please reach out to us by email or open a GitHub issue if you run into any problems running the code: dagarwal@cs.umass.edu **(Dhruv Agarwal)**, bodhisattwam@allenai.org **(Bodhisattwa Prasad Majumder)**. 
70 | 71 | ## 📄 Citation 72 | If you find our work useful, please cite our paper: 73 | ``` 74 | @inproceedings{ 75 | agarwal2025autodiscovery, 76 | title={AutoDiscovery: Open-ended Scientific Discovery via Bayesian Surprise}, 77 | author={Dhruv Agarwal and Bodhisattwa Prasad Majumder and Reece Adamson and Megha Chakravorty and Satvika Reddy Gavireddy and Aditya Parashar and Harshit Surana and Bhavana Dalvi Mishra and Andrew McCallum and Ashish Sabharwal and Peter Clark}, 78 | booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems}, 79 | year={2025}, 80 | url={https://openreview.net/forum?id=kJqTkj2HhF} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from autogen import ChatCompletion 4 | 5 | class TreeLogger: 6 | 7 | level: int 8 | agent_name: str 9 | node_idx: int 10 | 11 | def __init__(self, log_dir: str): 12 | """Initialize logger that stores logs for each node in a tree exploration. 13 | 14 | Args: 15 | log_dir: Directory to store all log files 16 | """ 17 | self.log_dir = log_dir 18 | os.makedirs(log_dir, exist_ok=True) 19 | 20 | def log_node(self, level: int, node_idx: int, message: str | dict): 21 | """Log a message for a specific node. 22 | 23 | Args: 24 | level: Level of the node in the tree 25 | node_idx: Index of the node within its level 26 | message: Message to log (string or dictionary) 27 | """ 28 | filename = os.path.join(self.log_dir, f"node_{level}_{node_idx}.json") 29 | 30 | with open(filename, 'a') as f: 31 | if isinstance(message, dict): 32 | f.write(json.dumps(message, indent=2)) 33 | else: 34 | f.write(json.dumps(json.loads(message), indent=2)) 35 | 36 | def load_node(self, level: int, node_idx: int, as_json: bool = False) -> list[str | dict]: 37 | """Load the contents of a log file for a specific node. 
38 | 39 | Args: 40 | level: Level of the node in the tree 41 | node_idx: Index of the node within its level 42 | as_json: If True, attempt to parse lines as JSON. If False, return raw strings. 43 | 44 | Returns: 45 | List of messages from the log file. Each message is either a string or 46 | dictionary depending on how it was originally logged and the as_json flag. 47 | 48 | Raises: 49 | FileNotFoundError: If the log file does not exist 50 | """ 51 | filename = os.path.join(self.log_dir, f"node_{level}_{node_idx}.json") 52 | 53 | messages = None 54 | with open(filename, 'r') as f: 55 | if as_json: 56 | messages = json.load(f) 57 | else: 58 | messages = f.read() 59 | 60 | return messages 61 | 62 | def log_choices(self, chat_completion: ChatCompletion, as_json: bool = False): 63 | filename = os.path.join(self.log_dir, f"{self.agent_name}_{self.level}_{self.node_idx}.json") 64 | choice_log = [] 65 | for choice in chat_completion.choices: 66 | msg = { 67 | "message": choice.message.content, 68 | "token_logprobs": self._parse_logprobs(choice.logprobs) 69 | } 70 | choice_log.append(msg) 71 | 72 | with open(filename, 'w') as f: 73 | json.dump(choice_log, f, indent=2) 74 | 75 | 76 | 77 | # with open(filename, 'a') as f: 78 | # for choice in chat_completion.choices: 79 | # f.write(choice.message.content + '\n') 80 | 81 | def _parse_logprobs(self, choice_logprobs): 82 | logprobs = [] 83 | for token in choice_logprobs.content: 84 | token_prob = { 85 | "token": token.token, 86 | "bytes": token.bytes, 87 | "logprob": token.logprob 88 | } 89 | logprobs.append(token_prob) 90 | return logprobs -------------------------------------------------------------------------------- /src/nodes_to_csv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import argparse 4 | 5 | from src.utils import try_loading_dict 6 | 7 | 8 | class ArgParser(argparse.ArgumentParser): 9 | def __init__(self, group=None): 10 | 
super().__init__(description='Get surprising nodes from MCTS logs') 11 | self.add_argument('--in_fpath', type=str, required=True, 12 | help='mcts_nodes.json file path or directory containing mcts_node_*.json files') 13 | self.add_argument('--out_fpath', type=str, required=True, help='output CSV file path') 14 | 15 | 16 | def nodes_to_csv(nodes_or_json_path, out_fpath): 17 | from src.mcts_utils import get_nodes, get_node_level_idx # Import here to avoid circular import issues 18 | mcts_nodes = get_nodes(nodes_or_json_path) 19 | 20 | csv_list = [] 21 | for node in mcts_nodes: 22 | csv_node = {} 23 | node_level, node_idx = get_node_level_idx(node) 24 | 25 | if node_level in [0, 1]: 26 | continue 27 | 28 | try: 29 | prior_mean = round(node["prior"]["mean"], 4) 30 | posterior_mean = round(node["posterior"]["mean"], 4) 31 | belief_change = round(node["belief_change"], 4) 32 | belief_kl = round(node["kl_divergence"], 4) 33 | except: 34 | prior_mean = None 35 | posterior_mean = None 36 | belief_change = None 37 | belief_kl = None 38 | 39 | csv_node['id'] = node['id'].replace('node_', '') 40 | csv_node['success'] = node.get('success', False) 41 | 42 | csv_node['surprisal'] = node.get('surprising', None) 43 | csv_node['prior'] = prior_mean 44 | csv_node['posterior'] = posterior_mean 45 | csv_node['belief_change'] = belief_change 46 | csv_node['belief_kl'] = belief_kl 47 | csv_node['belief_dir'] = None 48 | if prior_mean is not None and posterior_mean is not None: 49 | csv_node['belief_dir'] = 'neg' if posterior_mean < prior_mean else ( 50 | 'same' if posterior_mean == prior_mean else 'pos') 51 | 52 | csv_node['hypothesis'] = node['hypothesis'] 53 | experiment_plan = node['experiment_plan'] 54 | csv_node['experiment_plan'] = f"Objective: {experiment_plan.get('objective', 'N/A')}\n\n" \ 55 | f"Steps: {experiment_plan.get('steps', 'N/A')}\n\n" \ 56 | f"Deliverables: {experiment_plan.get('deliverables', 'N/A')}" 57 | csv_node['analysis'] = node.get('analysis', 'N/A') 58 | 
csv_node['review'] = node.get('review', 'N/A') 59 | 60 | csv_list.append(csv_node) 61 | csv_list.sort(key=lambda x: x['belief_change'] if x['belief_change'] is not None else float('-inf'), 62 | reverse=True) 63 | 64 | with open(out_fpath, 'w', newline='') as csv_file: 65 | fieldnames = ['id', 'success', 'hypothesis', 'surprisal', 'prior', 'posterior', 'belief_dir', 'belief_change', 66 | 'belief_kl', 'experiment_plan', 'analysis', 'review'] 67 | writer = csv.DictWriter(csv_file, fieldnames=fieldnames) 68 | writer.writeheader() 69 | for row in csv_list: 70 | row = {k: (v if v is not None else '') for k, v in row.items()} 71 | writer.writerow(row) 72 | 73 | print(f"[CSV] MCTS nodes (n={len(csv_list)}; skipping root) saved to {out_fpath}.\n") 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = ArgParser() 78 | args = parser.parse_args() 79 | nodes_to_csv(args.in_fpath, args.out_fpath) 80 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict 4 | 5 | import numpy as np 6 | import boto3 7 | from pydantic import ValidationError 8 | from openai import OpenAI 9 | 10 | 11 | def query_llm( 12 | messages: List[Dict[str, str]], 13 | n_samples: int, 14 | model: str = "gpt-4o", 15 | temperature: float | None = None, 16 | reasoning_effort: str | None = None, 17 | response_format=None, 18 | client: OpenAI = None, 19 | ): 20 | if client is None: 21 | client = OpenAI() 22 | is_reasoning_model = any(model.startswith(prefix) for prefix in ["o", "gpt-5"]) 23 | 24 | n_samples_batch_size = 8 if is_reasoning_model else n_samples 25 | responses = [] 26 | # Sample exactly n_samples responses 27 | for i in range(0, n_samples, n_samples_batch_size): 28 | kwargs = { 29 | "model": model, 30 | "messages": messages, 31 | "n": min(n_samples_batch_size, n_samples - len(responses)), 32 | } 33 | if not 
is_reasoning_model and temperature is not None: 34 | kwargs["temperature"] = temperature 35 | if is_reasoning_model and reasoning_effort is not None: 36 | kwargs["reasoning_effort"] = reasoning_effort 37 | 38 | if response_format is not None: 39 | kwargs["response_format"] = response_format 40 | 41 | try: 42 | response = client.chat.completions.parse(**kwargs) 43 | except ValidationError: 44 | # Retry if the response format validation fails 45 | response = client.chat.completions.parse(**kwargs) 46 | 47 | for choice in response.choices: 48 | if choice.message.content is None: 49 | continue 50 | responses += [json.loads(choice.message.content)] 51 | return responses 52 | 53 | 54 | def try_loading_dict(_dict_str): 55 | try: 56 | return json.loads(_dict_str) 57 | except json.JSONDecodeError: 58 | try: 59 | return json.loads(_dict_str + '"}') # Fix case where string is truncated 60 | except json.JSONDecodeError: 61 | return {} 62 | 63 | 64 | def fuse_gaussians(means, stds, weight=1.0): 65 | """ 66 | Fuse n independent Gaussian beliefs N(mu_i, sigma_i^2) 67 | into a single Gaussian via product of Gaussians. 68 | 69 | Parameters 70 | ---------- 71 | means : array-like, shape (n,) 72 | The means μ_i of the Gaussian beliefs. 73 | stds : array-like, shape (n,) 74 | The standard deviations σ_i of the Gaussian beliefs. 75 | weight : float, optional 76 | A weight to apply to the precision of each Gaussian. Default is 1.0. 77 | 78 | Returns 79 | ------- 80 | mu_star : float 81 | The fused mean μ_*. 82 | sigma_star : float 83 | The fused standard deviation σ_*. 
84 | """ 85 | means = np.array(means, dtype=float) 86 | variances = ( 87 | np.array(stds, dtype=float) ** 2 + 1e-10 88 | ) # Add small value to avoid division by zero 89 | 90 | # Precisions 91 | precisions = weight / variances 92 | 93 | # Combined precision and variance 94 | precision_star = np.sum(precisions) 95 | variance_star = 1.0 / precision_star 96 | 97 | # Combined mean 98 | mu_star = np.sum(precisions * means) / precision_star 99 | sigma_star = np.sqrt(variance_star) 100 | 101 | return mu_star, sigma_star 102 | 103 | 104 | def fetch_from_s3(links: List[str], download_dir="_s3") -> List[str]: 105 | """ 106 | Download data from S3 URLs 107 | Attributes: 108 | links (List[str]): List of S3 URLs to download 109 | download_dir (str): Directory to save downloaded files 110 | Returns: 111 | List of local file paths where files are downloaded 112 | """ 113 | s3_client = boto3.client("s3") 114 | fpaths = [] 115 | for link in links: 116 | _, _, bucket, key = link.split("/", 3) 117 | local_file_path = os.path.join(download_dir, key) 118 | local_dir = os.path.dirname(local_file_path) 119 | os.makedirs(local_dir, exist_ok=True) 120 | byte_str = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() 121 | with open(local_file_path, "wb") as file: 122 | file.write(byte_str) 123 | fpaths.append(local_file_path) 124 | 125 | return fpaths 126 | -------------------------------------------------------------------------------- /src/structured_outputs.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, model_validator 2 | from typing import Optional 3 | from typing_extensions import Self 4 | 5 | 6 | class Relationship(BaseModel): 7 | """ 8 | Represents a relationship between two variables in a hypothesis. 
9 | 10 | Attributes: 11 | explanatory (str): The independent/explanatory variable in the relationship 12 | response (str): The dependent/response variable in the relationship 13 | relationship (str): Description of how the explanatory variable affects the response variable 14 | """ 15 | explanatory: str 16 | response: str 17 | relationship: str 18 | 19 | 20 | class HypothesisDimensions(BaseModel): 21 | """ 22 | Structured representation of the key dimensions of a hypothesis. 23 | 24 | Attributes: 25 | contexts (list[str]): List of boundary conditions and assumptions under which the hypothesis holds 26 | variables (list[str]): List of key concepts/variables involved in the hypothesis 27 | relationships (list[Relationship]): List of causal relationships between pairs of variables 28 | """ 29 | contexts: list[str] 30 | variables: list[str] 31 | relationships: list[Relationship] 32 | 33 | 34 | class Hypothesis(BaseModel): 35 | """ 36 | A declarative sentence about the state of the world whose truth value may be inferred from the given dataset(s) using an experiment. 37 | 38 | Attributes: 39 | hypothesis (str): The hypothesis statement 40 | dimensions (HypothesisDimensions): Structured dimensions of the hypothesis 41 | """ 42 | hypothesis: str 43 | dimensions: HypothesisDimensions 44 | 45 | 46 | class ExperimentPlan(BaseModel): 47 | """ 48 | Represents the experiment plan with a title, objective, steps, and deliverables. 49 | 50 | Attributes: 51 | objective (str): The main goal or objective of the experiment 52 | steps (str): List of steps to be followed to implement the experiment 53 | deliverables (str): List of expected outcomes or deliverables from the experiment 54 | """ 55 | objective: str 56 | steps: str 57 | deliverables: str 58 | 59 | 60 | class Experiment(BaseModel): 61 | """ 62 | Represents an experiment with a hypothesis and corresponding experiment plan. 
63 | Attributes: 64 | hypothesis (str): A natural-language hypothesis representing an assertion about the world 65 | experiment_plan (ExperimentPlan): The structured experiment plan to verify the hypothesis 66 | """ 67 | hypothesis: str 68 | experiment_plan: ExperimentPlan 69 | 70 | 71 | class ExperimentHypothesis(BaseModel): 72 | """ 73 | Represents an experiment with an experiment plan and a hypothesis. 74 | Attributes: 75 | experiment_plan (ExperimentPlan): A structured experiment plan to verify a hypothesis 76 | hypothesis (str): A natural-language hypothesis representing an assertion about the world that can be 77 | tested by the experiment 78 | """ 79 | experiment_plan: ExperimentPlan 80 | hypothesis: str 81 | 82 | 83 | class ExperimentList(BaseModel): 84 | """ 85 | A collection of experiments. 86 | 87 | Attributes: 88 | experiments (list[Experiment]): List of Experiment objects 89 | """ 90 | experiments: list[Experiment] 91 | 92 | 93 | class ExperimentHypothesisList(BaseModel): 94 | """ 95 | A collection of experiment hypotheses. 96 | 97 | Attributes: 98 | experiments (list[ExperimentHypothesis]): List of ExperimentHypothesis objects 99 | """ 100 | experiments: list[ExperimentHypothesis] 101 | 102 | 103 | class ExperimentCode(BaseModel): 104 | """ 105 | Contains the code implementation for an experiment. 106 | 107 | Attributes: 108 | code (str): The actual code to be executed for the experiment 109 | """ 110 | code: str 111 | 112 | 113 | class ProgramCritique(BaseModel): 114 | """ 115 | Feedback on experiment code implementation. 116 | 117 | Attributes: 118 | fixes (list[str]): List of suggested fixes or improvements for the code 119 | """ 120 | fixes: list[str] 121 | 122 | 123 | class ExperimentAnalyst(BaseModel): 124 | """ 125 | Analysis of experiment results. 
126 | 127 | Attributes: 128 | analysis (str): Detailed analysis of the experiment outcomes (a required field) 129 | success (bool): Whether the experiment was successful 130 | """ 131 | analysis: str 132 | success: bool 133 | # NOTE(review): `analysis` is a required str, so pydantic rejects None during field validation and the None check below never fires. 134 | @model_validator(mode='after') 135 | def analysis_required_on_success(self) -> Self: 136 | if self.success and self.analysis is None: 137 | raise ValueError('analysis is required when success is True') 138 | return self 139 | 140 | 141 | class ExperimentReviewer(BaseModel): 142 | """ 143 | Review of an experiment's execution and results. 144 | 145 | Attributes: 146 | feedback (str): Feedback on the experiment (a required field) 147 | success (bool): Whether the experiment was successful 148 | 149 | Raises: 150 | ValueError: If success is False and no feedback is provided 151 | """ 152 | feedback: str 153 | success: bool 154 | # NOTE(review): `feedback` is a required str, so a None value already fails field validation and the None check below never fires. 155 | @model_validator(mode='after') 156 | def feedback_required_on_failure(self) -> Self: 157 | if not self.success and self.feedback is None: 158 | raise ValueError('feedback is required when success is False') 159 | return self 160 | 161 | 162 | class ImageAnalysis(BaseModel): 163 | """ 164 | Structured representation of plot axes and related analysis information. 
165 | 166 | Attributes: 167 | title (str): The title of the plot 168 | x_axis_label (str): Label for the x-axis 169 | y_axis_label (str): Label for the y-axis 170 | x_axis_range (list[int | float]): Range of values on the x-axis 171 | y_axis_range (list[int | float]): Range of values on the y-axis 172 | data_trends (list[str]): List of observed trends in the data 173 | statistical_insights (list[str]): List of statistical observations and metrics 174 | annotations_and_legends (list[str]): List of plot annotations and legend descriptions 175 | """ 176 | plot_type: str 177 | title: str 178 | x_axis_label: str 179 | y_axis_label: str 180 | x_axis_range: list[int] | list[float] 181 | y_axis_range: list[int] | list[float] 182 | data_trends: list[str] 183 | statistical_insights: list[str] 184 | annotations_and_legends: list[str] 185 | 186 | 187 | class ExecutionResult(BaseModel): 188 | exit_code: int 189 | result: str 190 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class ArgParser(argparse.ArgumentParser): 5 | def __init__(self, group=None): 6 | super().__init__(description='Run AutoDS exploration') 7 | 8 | self.add_argument('--dataset_metadata', type=str, required=True, help='Path to dataset metadata.') 9 | self.add_argument('--out_dir', type=str, required=True, help='Output directory for logs.') 10 | self.add_argument("--model", type=str, default="o4-mini", 11 | help="LLM to use for all agents (except belief distribution agent).") 12 | self.add_argument("--belief_model", type=str, default="gpt-4o", 13 | help="LLM to use for belief distribution agent.") 14 | self.add_argument("--user_query", type=str, 15 | help="Custom user query to condition experiment generation during exploration.") 16 | self.add_argument("--temperature", type=float, default=1.0, 17 | help="Temperature setting for all agents (except the 
belief agent). Should be set to None for OpenaAI o-series models.") 18 | self.add_argument("--belief_temperature", type=float, default=1.0, 19 | help="Temperature setting for the belief agent. Should be set to None for OpenaAI o-series models.") 20 | self.add_argument("--reasoning_effort", type=str, help="Reasoning effort for OpenAI o-series models", 21 | choices=['low', 'medium', 'high'], default='medium') 22 | self.add_argument('--continue_from_dir', type=str, 23 | help='Path to logs dir from a previous run to continue exploration from') 24 | self.add_argument('--continue_from_json', type=str, 25 | help='Path to mcts_nodes.json file to continue exploration from') 26 | self.add_argument('--n_experiments', type=int, help='Number of MCTS iterations (max_iterations)', required=True) 27 | self.add_argument('--k_experiments', type=int, default=8, help='Branching factor for experiments (>= 1)') 28 | self.add_argument('--allow_generate_experiments', action=argparse.BooleanOptionalAction, default=True, 29 | help='Allow nodes to generate new experiments on-demand') 30 | self.add_argument('--n_belief_samples', type=int, default=30, 31 | help='Number of samples for belief distribution evaluation') 32 | self.add_argument('--timestamp_dir', action=argparse.BooleanOptionalAction, default=True, 33 | help='Create timestamped directory for logs') 34 | self.add_argument('--exploration_weight', type=float, help='Exploration weight for UCB1 selection method', 35 | default=2.0) 36 | self.add_argument('--dataset_metadata_type', type=str, choices=['dbench', 'blade'], default='dbench', 37 | help='Type of dataset metadata format (dbench, blade, or ai2)') 38 | self.add_argument('--work_dir', type=str, required=True, help='Working directory for agents') 39 | self.add_argument('--delete_work_dir', action=argparse.BooleanOptionalAction, default=True, 40 | help='Delete the work directory after exploration') 41 | self.add_argument('--beam_width', type=int, default=8, help='Beam width for beam 
search selection method') 42 | self.add_argument('--use_beam_search', action=argparse.BooleanOptionalAction, default=False, 43 | help='Use beam search selection method') 44 | self.add_argument("--mcts_selection", type=str, 45 | choices=['ucb1', 'beam_search', 'pw', 'pw_all', 'ucb1_recursive'], default='ucb1_recursive', 46 | help="Selection method to use in MCTS (UCB1, beam search, progressive widening, progressive widening with all nodes)") 47 | self.add_argument('--pw_k', type=float, help='Progressive widening constant k', default=1.0) 48 | self.add_argument('--pw_alpha', type=float, help='Progressive widening exponent alpha', default=0.5) 49 | self.add_argument('--k_parents', type=int, default=3, 50 | help='Number of parent levels to include in prompts (None for all)') 51 | self.add_argument('--implicit_bayes_posterior', action=argparse.BooleanOptionalAction, default=False, 52 | help='Whether to use the belief samples with evidence as the direct posterior or to use a Bayesian update that explicitly combines it with the prior.') 53 | self.add_argument('--surprisal_width', type=float, default=0.2, 54 | help='Minimum difference in mean prior and posterior probabilities required to count as a surprisal.') 55 | self.add_argument('--belief_mode', type=str, 56 | choices=['boolean', 'boolean_cat', 'categorical', 'categorical_numeric', 'gaussian'], 57 | default='boolean_cat', help='Belief elicitation mode') 58 | self.add_argument('--use_binary_reward', action=argparse.BooleanOptionalAction, default=False, 59 | help='Use binary reward for MCTS instead of a continuous reward (belief change)') 60 | self.add_argument('--dedupe', action=argparse.BooleanOptionalAction, default=True, 61 | help='Run deduplication after MCTS') 62 | self.add_argument('--only_save_results', action=argparse.BooleanOptionalAction, default=False, 63 | help='Only save results without running MCTS') 64 | self.add_argument('--experiment_first', action=argparse.BooleanOptionalAction, default=False, 65 | 
help='Generate experiments before hypotheses') 66 | self.add_argument('--code_timeout', type=int, default=30 * 60, 67 | help='Timeout for code execution in seconds') 68 | self.add_argument('--run_eda', action=argparse.BooleanOptionalAction, default=False, 69 | help='Run EDA as part of the initial experiment') 70 | self.add_argument('--n_warmstart', type=int, default=8, 71 | help='Number of initial experiments to run after data loading before using MCTS') 72 | self.add_argument('--use_online_beliefs', action=argparse.BooleanOptionalAction, default=False, 73 | help='Use online beliefs conditioned on past surprisals') 74 | self.add_argument('--evidence_weight', type=float, default=2.0, 75 | help='Weight for the experimental evidence when computing posterior beliefs') 76 | self.add_argument('--kl_scale', type=float, default=5.0, 77 | help='Normalization factor for KL divergence in the reward function') 78 | self.add_argument('--reward_mode', type=str, choices=['belief', 'kl', 'belief_and_kl'], 79 | default='kl', help='Reward mode for MCTS (belief change, KL divergence or both)') 80 | self.add_argument('--warmstart_experiments', type=str, 81 | help='Path to JSON file containing a list of warmstart experiments to run before MCTS') -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from src.utils import fetch_from_s3 5 | 6 | 7 | def load_blade_metadata(info_path: str) -> dict: 8 | with open(info_path, 'r') as file: 9 | return json.load(file) 10 | 11 | 12 | def load_info_metadata(info_path: str) -> dict: 13 | with open(info_path, 'r') as file: 14 | return json.load(file) 15 | 16 | 17 | def get_blade_description(info_path: str) -> str: 18 | """ 19 | Generate a human-readable description of the dataset based on its metadata. 
20 | 21 | Args: 22 | info_path: Path to info.json file 23 | 24 | Returns: 25 | str: Formatted description of the dataset 26 | """ 27 | metadata = load_blade_metadata(info_path) 28 | data = metadata.get('data_desc', {}) 29 | 30 | description = [] 31 | 32 | # Add header 33 | description.append("Dataset Description") 34 | 35 | # Add dataset section 36 | description.append("\nDatasets:") 37 | 38 | # Extract dataset name from second-to-last folder in data_path 39 | data_path = metadata.get("data_path", "") 40 | dataset_name = os.path.basename(os.path.dirname(data_path)) or "Unnamed Dataset" 41 | description.append(f"Dataset Name: {dataset_name}") 42 | 43 | dataset_desc = data.get("dataset_description", "No description available.") 44 | description.append(f"Dataset Description: {dataset_desc}") 45 | 46 | # Add columns 47 | description.append("\nColumns:") 48 | fields = data.get("fields", []) 49 | for field in fields: 50 | col_name = field.get("column", "Unnamed") 51 | col_desc = field.get("properties", {}).get("description", "No description available.") 52 | description.append(f"\n{col_name}:") 53 | description.append(f" {col_desc}") 54 | 55 | return "\n".join(description) 56 | 57 | 58 | def load_ai2_metadata(info_path: str) -> dict: 59 | with open(info_path, 'r') as file: 60 | return json.load(file) 61 | 62 | 63 | def get_ai2_description(ai2_metadata_path: str) -> str: 64 | """ 65 | Generate a human-readable description of the AI2 dataset based on its metadata. 
66 | 67 | Args: 68 | ai2_metadata_path: Path to the AI2-style metadata JSON file 69 | 70 | Returns: 71 | str: Formatted description of the dataset 72 | """ 73 | metadata = load_ai2_metadata(ai2_metadata_path) 74 | 75 | description = [] 76 | 77 | # Add header 78 | description.append("Dataset Description") 79 | 80 | # Add dataset section 81 | description.append("\nDatasets:") 82 | for dataset in metadata.get("datasets", []): 83 | name = dataset.get("name", "Unnamed Dataset") 84 | desc = dataset.get("description", "No description available.") 85 | description.append(f"Dataset Name: {name}") 86 | description.append(f"Dataset Description: {desc}") 87 | 88 | # Add columns 89 | description.append("\nColumns:") 90 | for col in dataset.get("columns", {}).get("raw", []): 91 | col_name = col.get("name", "Unnamed") 92 | col_desc = col.get("description", "No description available.") 93 | description.append(f"\n{col_name}:") 94 | description.append(f" {col_desc}") 95 | 96 | return "\n".join(description) 97 | 98 | 99 | def load_dataset_metadata(dataset_metadata_path: str, dataset_metadata_key: str = None) -> dict: 100 | with open(dataset_metadata_path, 'r') as file: 101 | dataset_metadata = json.load(file) 102 | if dataset_metadata_key is not None: 103 | dataset_metadata = dataset_metadata[dataset_metadata_key] 104 | return dataset_metadata 105 | 106 | 107 | def get_dataset_description(dataset_metadata_path: str) -> str: 108 | """ 109 | Generate a human-readable description of the dataset based on its metadata. 
110 | 111 | Args: 112 | dataset_metadata_path: Path to the dataset metadata JSON file 113 | 114 | Returns: 115 | str: Formatted description of the dataset 116 | """ 117 | 118 | metadata = load_dataset_metadata(dataset_metadata_path) 119 | description = [] 120 | 121 | # Add header 122 | description.append("##### DATASET DESCRIPTION #####") 123 | # Add dataset info 124 | description.append("\n### DATASETS: ###\n") 125 | for dataset in metadata['datasets']: 126 | description.append(f"Dataset Name: {dataset['name']}") 127 | description.append(f"Dataset Description: {dataset['description']}") 128 | description.append("\n### COLUMNS: ###") 129 | for col in dataset['columns']['raw']: 130 | description.append(f"\n{col['name']}:") 131 | description.append(f" {col['description']}") 132 | 133 | return "\n".join(description) 134 | 135 | 136 | def get_datasets_fpaths(dataset_metadata: str, is_blade=False) -> (list, str): 137 | is_s3 = dataset_metadata.startswith("s3") 138 | _dataset_metadata = dataset_metadata 139 | if is_s3: 140 | # Download the metadata and get the local path 141 | _dataset_metadata = fetch_from_s3([dataset_metadata])[0] 142 | # Read the json, loop through "datasets" key, then extract dataset path from "name" key 143 | with open(_dataset_metadata, 'r') as file: 144 | obj = json.load(file) 145 | # Get the dataset paths 146 | datasets = [] 147 | if not is_blade: 148 | for d in obj.get('datasets', []): 149 | datasets.append(d["name"]) 150 | else: 151 | # Blade-style metadata 152 | datasets.append("data.csv") 153 | if is_s3: 154 | # Download the datasets using the s3 path and dataset names and get the local paths 155 | paths = fetch_from_s3([os.path.join(os.path.dirname(dataset_metadata), d) for d in datasets]) 156 | else: 157 | paths = [os.path.join(os.path.dirname(_dataset_metadata), d) for d in datasets] 158 | 159 | return paths, _dataset_metadata 160 | 161 | 162 | def get_load_dataset_experiment(dataset_paths, dataset_metadata, run_eda=False, 
dataset_metadata_type="dbench"): 163 | # Set up the initial experiment to load the dataset 164 | load_dataset_objective = "Load the dataset and generate summary statistics. " 165 | load_dataset_steps = f"1. Load the dataset(s) at {[os.path.basename(dp) for dp in dataset_paths]}.\n2. Generate summary statistics for the dataset(s)." 166 | load_dataset_deliverables = "1. Dataset(s) loaded.\n2. Summary statistics generated." 167 | if run_eda: 168 | load_dataset_steps += "\n3. Perform some exploratory data analysis (EDA) on the dataset(s) to get a better understanding of the data." 169 | load_dataset_deliverables += "\n3. Exploratory data analysis (EDA) performed." 170 | if dataset_metadata_type == 'blade': 171 | load_dataset_objective += f"Here is the dataset metadata:\n\n{get_blade_description(dataset_metadata)}" 172 | else: # DiscoveryBench-style 173 | load_dataset_objective += f"Here is the dataset metadata:\n\n{get_dataset_description(dataset_metadata)}" 174 | load_dataset_experiment = { 175 | "hypothesis": None, 176 | "experiment_plan": { 177 | "objective": load_dataset_objective, 178 | "steps": load_dataset_steps, 179 | "deliverables": load_dataset_deliverables 180 | } 181 | } 182 | return load_dataset_experiment 183 | -------------------------------------------------------------------------------- /src/log_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | from beliefs import evaluate_hypothesis_distribution 5 | 6 | def load_node_logs(directory, level=None, node_idx=None): 7 | """ 8 | Loads and parses log files in a directory that match the format 9 | 'node_level_index.json', saving the contents in a dictionary. 10 | 11 | Args: 12 | directory (str): The path to the directory. 13 | level (int, optional): The level of the node in the tree. If provided, only logs for this level are loaded. 14 | node_idx (int, optional): The index of the node within its level. 
If provided, only logs for this node are loaded. 15 | 16 | Returns: 17 | dict: A dictionary where keys are tuples (index, level) and values 18 | are the contents of the log files. Returns an empty dict if 19 | no matching files are found or an error occurs. 20 | """ 21 | result = {} 22 | try: 23 | for filename in os.listdir(directory): 24 | match = re.match(r"node_(\d+)_(\d+)\.json", filename) # Updated extension 25 | if match: 26 | file_level = int(match.group(1)) 27 | file_index = int(match.group(2)) 28 | if (level is not None and file_level != level) or (node_idx is not None and file_index != node_idx): 29 | continue 30 | filepath = os.path.join(directory, filename) 31 | try: 32 | with open(filepath, "r") as f: 33 | log_content = json.load(f) # Read the entire log file 34 | result[(file_level, file_index)] = extract_node_messages(log_content) 35 | except OSError as e: 36 | print(f"Error reading {filename}: {e}") 37 | return result 38 | except FileNotFoundError: 39 | print(f"Directory not found: {directory}") 40 | return {} 41 | except OSError as e: 42 | print(f"OS error occurred: {e}") 43 | return {} 44 | 45 | 46 | def extract_node_messages(json_data): 47 | """ 48 | Extracts messages starting from the last occurrence of 49 | a message with role 'user_proxy'. 50 | 51 | Args: 52 | json_data (list): A list of dictionaries representing JSON data. 53 | 54 | Returns: 55 | list: A list of dictionaries containing messages from the last 'user_proxy' 56 | message onwards. Returns an empty list if no 'user_proxy' messages 57 | are found. 58 | """ 59 | if not isinstance(json_data, list): 60 | return [] 61 | 62 | user_proxy_indices = [i for i, msg in enumerate(json_data) if msg.get("name") == "user_proxy"] 63 | 64 | if not user_proxy_indices: 65 | return [] 66 | 67 | last_user_proxy_index = user_proxy_indices[-1] 68 | return json_data[last_user_proxy_index:] 69 | 70 | def extract_hypotheses_from_logs(messages): 71 | """ 72 | Extracts hypotheses from a list of messages. 
73 | 74 | Args: 75 | messages (list): A list of message dictionaries 76 | 77 | Returns: 78 | list: List of extracted hypotheses 79 | """ 80 | hypotheses = [] 81 | for msg in messages: 82 | if msg.get("name") == "hypothesis_generator": 83 | try: 84 | content = json.loads(msg.get("content", "{}")) 85 | if "hypothesis" in content: 86 | hypotheses.append(content["hypothesis"]) 87 | except (json.JSONDecodeError, TypeError): 88 | continue 89 | return hypotheses 90 | 91 | def save_belief_distribution(log_dirname, level, node_idx, messages, current_hypothesis, context, distribution, 92 | model="gpt-4o", n_samples=30, is_prior=False, temperature=None): 93 | """Save belief distribution with messages in proper JSON format. 94 | 95 | Args: 96 | log_dirname (str): Directory where belief logs are stored 97 | level (int): Tree level of the current node 98 | node_idx (int): Index of the current node within its level 99 | messages (list): List of message dicts containing hypotheses and evidence 100 | current_hypothesis (str): The hypothesis being evaluated 101 | context (str): Context type - one of "current", "branch", or "all" 102 | distribution (str): Distribution type - either "prior" or "posterior" 103 | """ 104 | belief_result = evaluate_hypothesis_distribution( 105 | messages=messages, 106 | hypothesis=current_hypothesis, 107 | n_samples=n_samples, 108 | temperature=temperature, 109 | is_prior=is_prior, 110 | model=model 111 | ) 112 | 113 | belief_record = { 114 | "belief_result": json.loads(belief_result.model_dump_json()), 115 | "context": context, 116 | "messages": messages, 117 | "distribution": distribution, 118 | "current_hypothesis": current_hypothesis 119 | } 120 | 121 | belief_log_filename = os.path.join(log_dirname, f"belief_{level}_{node_idx}.json") 122 | 123 | try: 124 | with open(belief_log_filename, 'r') as f: 125 | records = json.load(f) 126 | except (FileNotFoundError, json.JSONDecodeError): 127 | records = [] 128 | 129 | records.append(belief_record) 130 | 
with open(belief_log_filename, 'w') as f: 131 | json.dump(records, f, indent=2) 132 | 133 | def load_parent_hypotheses(log_dirname, level, parent_node_idx, context_type="branch"): 134 | """Load hypotheses from parent node's belief logs based on context type. 135 | 136 | Args: 137 | log_dirname (str): Directory where belief logs are stored 138 | level (int): Current tree level (parent will be level-1) 139 | parent_node_idx (int): Index of the parent node 140 | context_type (str): Context to load - one of "current", "branch", or "all" 141 | 142 | Returns: 143 | list: List of dicts containing hypotheses and their beliefs from the parent node 144 | """ 145 | if parent_node_idx is None: 146 | return [] 147 | 148 | parent_log_filename = os.path.join(log_dirname, f"belief_{level-1}_{parent_node_idx}.json") 149 | if not os.path.exists(parent_log_filename): 150 | return [] 151 | 152 | try: 153 | with open(parent_log_filename, 'r') as f: 154 | records = json.load(f) 155 | except json.JSONDecodeError: 156 | return [] 157 | 158 | # Find latest posterior belief record with matching context 159 | matching_records = [r for r in records 160 | if r["distribution"] == "posterior" and 161 | r["context"] == context_type] 162 | 163 | if not matching_records: 164 | return [] 165 | 166 | latest_record = matching_records[-1] 167 | 168 | hypotheses = [] 169 | for msg in latest_record["messages"]: 170 | if msg.get("name") == "my_hypotheses": 171 | try: 172 | content = json.loads(msg["content"]) 173 | if "hypothesis" in content: 174 | hypotheses.append({ 175 | "hypothesis": content["hypothesis"], 176 | "belief": content.get("belief") 177 | }) 178 | except (json.JSONDecodeError, KeyError): 179 | continue 180 | 181 | # Set belief for latest hypothesis from belief result 182 | if hypotheses: 183 | hypotheses[-1]["belief"] = latest_record["belief_result"]["believes_hypothesis"] 184 | 185 | return hypotheses 186 | 187 | def get_current_hypothesis_and_evidence(node_logs): 188 | """Extract current 
hypothesis and evidence from node logs. 189 | 190 | Returns: 191 | tuple: (current_hypothesis, evidence_messages) where: 192 | - current_hypothesis (str): The latest hypothesis generated 193 | - evidence_messages (list): All messages from the current node interaction 194 | """ 195 | if not node_logs: 196 | return None, [] 197 | 198 | messages = next(iter(node_logs.values())) 199 | 200 | # Find last user_proxy message by iterating in reverse 201 | start_idx = None 202 | for i in range(len(messages) - 1, -1, -1): 203 | if messages[i].get("name") == "user_proxy": 204 | start_idx = i 205 | break 206 | 207 | if start_idx is None: 208 | return None, [] 209 | 210 | node_messages = messages[start_idx:] 211 | 212 | # Find first hypothesis in node_messages 213 | current_hypothesis = None 214 | for msg in node_messages: 215 | if msg.get("name") == "hypothesis_generator": 216 | try: 217 | content = json.loads(msg["content"]) 218 | current_hypothesis = content.get("hypothesis") 219 | if current_hypothesis: 220 | break 221 | except (json.JSONDecodeError, TypeError): 222 | continue 223 | 224 | return current_hypothesis, node_messages 225 | 226 | def load_all_hypotheses(log_dirname): 227 | """Load all hypotheses and their beliefs from all belief log files. 
228 | 229 | Args: 230 | log_dirname (str): Directory where belief logs are stored 231 | 232 | Returns: 233 | list: List of dicts containing hypotheses and their beliefs from all nodes 234 | """ 235 | all_hypotheses = [] 236 | 237 | try: 238 | for filename in os.listdir(log_dirname): 239 | if not filename.startswith("belief_") or not filename.endswith(".json"): 240 | continue 241 | 242 | try: 243 | level, node_idx = map(int, filename[7:-5].split("_")) 244 | except ValueError: 245 | continue 246 | 247 | filepath = os.path.join(log_dirname, filename) 248 | try: 249 | with open(filepath, 'r') as f: 250 | records = json.load(f) 251 | except (json.JSONDecodeError, OSError): 252 | continue 253 | 254 | # Find latest posterior belief record with context "all" 255 | matching_records = [r for r in records 256 | if r["distribution"] == "posterior" and 257 | r["context"] == "all"] 258 | 259 | if not matching_records: 260 | continue 261 | 262 | latest_record = matching_records[-1] 263 | 264 | all_hypotheses.append({ 265 | "hypothesis": latest_record["current_hypothesis"], 266 | "belief": latest_record["belief_result"]["believes_hypothesis"], 267 | # "level": level, 268 | # "node_idx": node_idx 269 | }) 270 | 271 | except OSError: 272 | return [] 273 | 274 | return all_hypotheses -------------------------------------------------------------------------------- /src/deduplication.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import random 5 | import numpy as np 6 | from openai import OpenAI 7 | from scipy.cluster.hierarchy import linkage 8 | from pydantic import BaseModel, Field 9 | from tqdm import tqdm 10 | 11 | from src.utils import query_llm 12 | 13 | 14 | class ArgParser(argparse.ArgumentParser): 15 | def __init__(self, group=None): 16 | super().__init__(description='Get surprising nodes from MCTS logs') 17 | self.add_argument('--in_fpath', type=str, required=True, 18 | 
help='mcts_nodes.json file path or directory containing mcts_node_*.json files') 19 | self.add_argument('--out_fpath', type=str, help='output directory for clusters and labels') 20 | self.add_argument('--n_samples', type=int, default=30, help='Number of samples for LLM decisions') 21 | self.add_argument('--merge_threshold', type=float, default=0.7, 22 | help='Threshold for merging hypotheses based on LLM decisions') 23 | self.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility') 24 | self.add_argument('--model', type=str, default="gpt-4o", 25 | help='LLM model to use for hypothesis merging decisions. Default is gpt-4o.') 26 | self.add_argument('--n_nodes', type=int, default=None, 27 | help='Number of nodes to process. If None, all nodes are processed.') 28 | self.add_argument('--verbose', action=argparse.BooleanOptionalAction, default=False, 29 | help='Whether to print verbose output during deduplication process.') 30 | 31 | 32 | def hyp_dict_to_str(d): 33 | return (f"Hypothesis: {d.get('hypothesis', 'N/A')}\n" 34 | f"Contexts: {d.get('contexts', d.get('context', 'N/A'))}\n" 35 | f"Variables: {d.get('variables', 'N/A')}\n" 36 | f"Relationships: {d.get('relationships', 'N/A')}") 37 | 38 | 39 | def get_structured_hypothesis(node): 40 | level, node_idx = node.get("level"), node.get("node_idx") 41 | h = node.get("hypothesis", None) 42 | if h is None: 43 | return None 44 | hyp_str = h.get("hypothesis", "") 45 | dims = h.get("dimensions", { 46 | "contexts": [], 47 | "variables": [], 48 | "relationships": [] 49 | }) 50 | return { 51 | "node_id": f"node_{level}_{node_idx}", 52 | "hypothesis": hyp_str, 53 | **dims 54 | } 55 | 56 | 57 | def get_structured_hypotheses(in_nodes_list): 58 | node_list = [hyp_obj for node in in_nodes_list if (hyp_obj := get_structured_hypothesis(node)) is not None] 59 | return node_list 60 | 61 | 62 | def get_hypothesis(node): 63 | level, node_idx = node.get("level"), node.get("node_idx") 64 | h = 
node.get("hypothesis", None) 65 | if h is None: 66 | return None 67 | return { 68 | "node_id": f"node_{level}_{node_idx}", 69 | "hypothesis": h, 70 | } 71 | 72 | 73 | def get_hypotheses(in_nodes_list): 74 | node_list = [hyp_obj for node in in_nodes_list if (hyp_obj := get_hypothesis(node)) is not None] 75 | return node_list 76 | 77 | 78 | def get_embedding(texts, model="text-embedding-3-large", batch_size=128, client=None, n_attempts=1): 79 | """ 80 | Compute embeddings for a list of texts using the OpenAI Embeddings API. 81 | Args: 82 | texts (list): A list of text strings to be embedded. 83 | model (str, optional): The identifier for the embedding model to use. 84 | batch_size (int, optional): The number of texts to process in one API call. 85 | Returns: 86 | numpy.ndarray: An array of embeddings for the input texts. 87 | """ 88 | if client is None: 89 | client = OpenAI() 90 | all_embeddings = [] 91 | for attempt in range(n_attempts): 92 | try: 93 | all_embeddings = [] 94 | # Process the texts in batches 95 | for i in range(0, len(texts), batch_size): 96 | batch = texts[i: i + batch_size] 97 | # Request embeddings for the current batch from the API 98 | response = client.embeddings.create(input=batch, model=model) 99 | for item in response.data: 100 | # Convert the embedding to a NumPy array and add it to the list 101 | all_embeddings.append(np.array(item.embedding)) 102 | break # If successful, exit the loop 103 | except Exception as e: 104 | if attempt < n_attempts - 1: 105 | print(f"Embeddings: Attempt {attempt + 1} failed: {e}. 
Retrying...") 106 | else: 107 | raise RuntimeError(f"Failed to get embeddings after {n_attempts} attempts.") from e 108 | return np.array(all_embeddings) 109 | 110 | 111 | def get_llm_merge_decision(hyp1: str, hyp2: str, n_samples: int = 30, threshold: float = 0.7, model: str = "gpt-4o", 112 | temperature: float = 1.0, reasoning_effort: str = "medium"): 113 | class ResponseFormat(BaseModel): 114 | is_same: bool = Field(..., description="Whether the two hypotheses are the same or not.") 115 | 116 | system_prompt = "You are a research scientist skilled at analyzing statistical hypotheses." 117 | # prompt = ( 118 | # f"You are given two hypothesis sets. Each set describes a single hypothesis structured into a context for the " 119 | # f"hypothesis, the variables involved, and the statistical relationships between the variables under that " 120 | # f"context. Your task is to determine whether both sets represent the same hypothesis or not.\n\n" 121 | # f"Hypothesis Set 1:\n{hyp1}\n\nHypothesis Set 2:\n{hyp2}" 122 | # ) 123 | prompt = ( 124 | f"You are given two hypothesis sets. Each set describes a single hypothesis, structured into a context for the " 125 | f"hypothesis, the variables involved, and the statistical relationships between the variables under that " 126 | f"context. 
Your task is to determine whether the two hypotheses are semantically the same or not.\n\n" 127 | f"HYPOTHESIS 1:\n{hyp1}\n\nHYPOTHESIS 2:\n{hyp2}" 128 | ) 129 | all_msgs = [ 130 | {"role": "system", "content": system_prompt}, 131 | {"role": "user", "content": prompt} 132 | ] 133 | response = query_llm(all_msgs, model=model, n_samples=n_samples, 134 | temperature=temperature, reasoning_effort=reasoning_effort, 135 | response_format=ResponseFormat) 136 | true_prop = sum([1 for _res in response if _res["is_same"]]) / n_samples 137 | 138 | return true_prop >= threshold 139 | 140 | 141 | def dedupe(nodes_or_json_path, n_samples=10, merge_threshold=0.7, seed=42, rep_mode="biggest", model="gpt-4o", 142 | n_nodes=None, verbose=False, log_comparisons_fname=None): 143 | random.seed(seed) 144 | np.random.seed(seed) 145 | 146 | from src.mcts_utils import get_nodes # Importing here to avoid circular import issues 147 | nodes_list = get_nodes(nodes_or_json_path)[:n_nodes] 148 | data = get_hypotheses(nodes_list) 149 | 150 | dedup_hyp, dedup_struct_hyp, hyp_to_index, orig_to_dedup = [], [], {}, [] 151 | 152 | # Deduplicate hypotheses by exact match 153 | for d in data: 154 | hyp = d["hypothesis"] # Hypothesis string 155 | if hyp not in hyp_to_index: 156 | dedup_hyp.append(hyp) 157 | dedup_struct_hyp.append(hyp_dict_to_str(d)) 158 | hyp_to_index[hyp] = len(dedup_struct_hyp) - 1 159 | orig_to_dedup.append(hyp_to_index[hyp]) 160 | n_dedup = len(dedup_struct_hyp) 161 | 162 | # Generate embeddings for deduplicated hypotheses 163 | embeds = np.array(get_embedding(dedup_hyp, n_attempts=3)) 164 | 165 | # Initialize assignment structures 166 | clusters = {i: [i] for i in range(n_dedup)} 167 | cluster_assignment = {i: i for i in range(n_dedup)} 168 | hac_to_current = {i: i for i in range(n_dedup)} 169 | cluster_rep = {i: i for i in range(n_dedup)} 170 | 171 | # Perform HAC over LM embeddings and get the linkage matrix 172 | linkage_matrix = linkage(embeds, method='ward') 173 | 174 | # 
Iterate through the linkage matrix to additionally merge clusters based on LLM decisions 175 | pbar = tqdm(linkage_matrix, desc="Deduplicating") 176 | pbar.set_postfix({"n_clusters": len(clusters)}) 177 | llm_comparisons = [] 178 | for r, row in enumerate(pbar): 179 | hac_node_id = n_dedup + r 180 | left_hac, right_hac = int(row[0]), int(row[1]) 181 | left_current = hac_to_current.get(left_hac) 182 | right_current = hac_to_current.get(right_hac) 183 | if left_current is None or right_current is None or left_current == right_current: 184 | hac_to_current[hac_node_id] = left_current if left_current is not None else right_current 185 | continue 186 | rep_left, rep_right = cluster_rep[left_current], cluster_rep[right_current] 187 | # struct_left, struct_right = dedup_struct_hyp[rep_left], dedup_struct_hyp[rep_right] 188 | struct_left, struct_right = dedup_hyp[rep_left], dedup_hyp[rep_right] 189 | # Get the LLM merge decision 190 | llm_decision = get_llm_merge_decision( 191 | struct_left, struct_right, 192 | n_samples=n_samples, 193 | threshold=merge_threshold, 194 | model=model 195 | ) 196 | if verbose: 197 | print(f"""\n\n 198 | Cluster left (size: {len(clusters[left_current])}): 199 | {struct_left} 200 | 201 | Cluster right (size: {len(clusters[right_current])}): 202 | {struct_right} 203 | 204 | LLM Decision: {'Merge' if llm_decision else 'Do not merge'}\n\n""") 205 | if log_comparisons_fname is not None: 206 | llm_comparisons.append({ 207 | "left_size": len(clusters[left_current]), 208 | "right_size": len(clusters[right_current]), 209 | "left_hypothesis": struct_left, 210 | "right_hypothesis": struct_right, 211 | "llm_decision": llm_decision 212 | }) 213 | 214 | if llm_decision: 215 | if rep_mode == "random": 216 | new_rep = random.choice([rep_left, rep_right]) 217 | elif rep_mode == "biggest": 218 | new_rep = rep_left if len(clusters[left_current]) >= len(clusters[right_current]) else rep_right 219 | else: 220 | raise NotImplementedError 221 | merged_cluster_id = 
min(left_current, right_current) 222 | other_cluster_id = max(left_current, right_current) 223 | clusters[merged_cluster_id] += clusters[other_cluster_id] 224 | for idx in clusters[merged_cluster_id]: 225 | cluster_assignment[idx] = merged_cluster_id 226 | cluster_rep[merged_cluster_id] = new_rep 227 | del clusters[other_cluster_id] 228 | del cluster_rep[other_cluster_id] 229 | hac_to_current[hac_node_id] = merged_cluster_id 230 | else: 231 | hac_to_current[hac_node_id] = None 232 | 233 | # Update pbar with number of clusters 234 | pbar.set_postfix({"n_clusters": len(clusters)}) 235 | 236 | if log_comparisons_fname is not None: 237 | with open(log_comparisons_fname, 'w') as f: 238 | json.dump(llm_comparisons, f, indent=2) 239 | print(f"LLM comparisons logged to {log_comparisons_fname}") 240 | 241 | final_labels = [cluster_assignment[orig_to_dedup[i]] for i in range(len(orig_to_dedup))] 242 | 243 | # Dedupe nodes and update cluster information 244 | deduped_nodes = [] 245 | for cluster_id, cluster in clusters.items(): 246 | node_copy = copy.deepcopy(nodes_list[cluster_id]) 247 | node_copy['cluster'] = [nodes_list[n]['id'] for n in cluster[1:]] 248 | node_copy['cluster_nodes'] = [copy.deepcopy(nodes_list[n]) for n in cluster if n != cluster_id] 249 | deduped_nodes.append(node_copy) 250 | 251 | return deduped_nodes, final_labels, clusters 252 | 253 | 254 | if __name__ == "__main__": 255 | parser = ArgParser() 256 | args = parser.parse_args() 257 | deduped_nodes, final_labels, clusters = dedupe(nodes_or_json_path=args.in_fpath, 258 | n_samples=args.n_samples, 259 | merge_threshold=args.merge_threshold, 260 | seed=args.seed, 261 | model=args.model, 262 | n_nodes=args.n_nodes, 263 | verbose=args.verbose) 264 | print("Final Labels:", final_labels) 265 | print("Clusters:", clusters) 266 | 267 | if args.out_fpath is not None: 268 | # Save the results to the output file 269 | output_data = { 270 | "final_labels": final_labels, 271 | "clusters": clusters, 272 | 
"deduped_nodes": deduped_nodes 273 | } 274 | with open(args.out_fpath, 'w') as f: 275 | json.dump(output_data, f, indent=2) 276 | print(f"Results saved to {args.out_fpath}") 277 | -------------------------------------------------------------------------------- /src/mcts_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import regex as re 4 | from collections import defaultdict 5 | from typing import List, Dict, Literal 6 | from glob import glob 7 | 8 | from autogen import GroupChat, GroupChatManager 9 | 10 | from src.deduplication import dedupe 11 | from src.nodes_to_csv import nodes_to_csv 12 | from src.transitions import SpeakerSelector 13 | 14 | 15 | def load_mcts_from_json(json_obj_or_file_or_dir, args=None, replay_mcts=True): 16 | """Load and reconstruct MCTS nodes from a JSON object, log file, or directory. 17 | 18 | Args: 19 | json_obj_or_file_or_dir: Loaded JSON object or a path to the mcts_nodes.json file or to a directory with mcts_node_*.json files. 
20 | 21 | Returns: 22 | root: Root MCTSNode 23 | nodes_by_level: Dictionary mapping levels to lists of MCTSNodes 24 | """ 25 | from src.mcts import MCTSNode # Import here to avoid circular import issues 26 | 27 | node_data = get_nodes(json_obj_or_file_or_dir) 28 | 29 | # Initialize tree data structures 30 | nodes_by_level = defaultdict(list) 31 | node_map = {} # Map (level, idx) to node objects for linking 32 | 33 | # Iterate over the nodes in level order and build the tree 34 | node_data.sort(key=lambda x: (int(x['id'].split('_')[1]), int(x['id'].split('_')[2]))) 35 | for data in node_data: 36 | # Create an empty node and initialize from dict (parent links added in second pass) 37 | node = MCTSNode(allow_generate_experiments=args.allow_generate_experiments if args else True) 38 | node.init_from_dict(data) 39 | # Add to data structures 40 | nodes_by_level[node.level].append(node) 41 | node_map[(node.level, node.node_idx)] = node 42 | # Link to parent 43 | if node.parent_id is not None: 44 | parent_level = node.level - 1 45 | parent_idx = node.parent_idx 46 | try: 47 | node.parent = node_map[(parent_level, parent_idx)] 48 | node.parent.children.append(node) 49 | except KeyError: 50 | assert (parent_level, parent_idx) == ( 51 | 0, 0), f"Parent node ({parent_level}, {parent_idx}) not found in node_map." 
52 | 53 | # Create root node if it does not exist 54 | if (0, 0) not in node_map: 55 | node = MCTSNode(level=0, node_idx=0, creation_idx=0) 56 | nodes_by_level[0].append(node) # Figure out creation_idx use 57 | node_map[(0, 0)] = node 58 | # Link root to the tree 59 | node.children = [node_map[(1, 0)]] 60 | node_map[(1, 0)].parent = node 61 | 62 | assert len(node_map) == MCTSNode._creation_counter 63 | root = node_map[(0, 0)] 64 | 65 | # Fix tried/untried experiments 66 | for node in node_map.values(): 67 | _tried_experiments, _untried_experiments = [], [] 68 | cur_untried_experiments = set(list(map(get_query_from_experiment, node.untried_experiments))) 69 | for child in node.children: 70 | # Keep only children in tried experiments 71 | _tried_experiments.append(get_experiment_from_query(child.query)) 72 | # Remove child from untried experiments if exists 73 | if child.query in cur_untried_experiments: 74 | cur_untried_experiments.remove(child.query) 75 | _untried_experiments = list(map(get_experiment_from_query, list(cur_untried_experiments))) 76 | node.tried_experiments = _tried_experiments 77 | node.untried_experiments = _untried_experiments 78 | 79 | if replay_mcts: 80 | # Replay MCTS to assign correct visits and values in order of creation_idx 81 | _nodes = sorted(node_map.values(), key=lambda x: x.creation_idx) 82 | # Reset visits and value 83 | for _node in _nodes: 84 | _node.visits = 0 85 | _node.value = 0 86 | # Backpropagate visits and values 87 | for _node in _nodes: 88 | _node.update_counts(visits=1, reward=_node.self_value) 89 | 90 | return root, nodes_by_level 91 | 92 | 93 | def save_nodes(nodes_dict_or_list, log_dirname, run_dedupe=True, model="gpt-4o", save_csv=True, 94 | time_elapsed=None): 95 | """Save MCTS nodes to JSON and optionally to CSV. 96 | 97 | Args: 98 | nodes_dict_or_list: Dictionary or list of MCTSNode objects or dicts. 99 | log_dirname: Directory to save the JSON and CSV files. 
100 | run_dedupe: Whether to deduplicate nodes based on hypothesis. 101 | model: Model to use for deduplication. 102 | save_csv: Whether to save nodes to a CSV file. 103 | time_elapsed: Optional time elapsed for logging purposes. 104 | """ 105 | from src.mcts import MCTSNode # Import here to avoid circular import issues 106 | 107 | if type(nodes_dict_or_list) in [dict, defaultdict]: 108 | nodes_list = [] 109 | for level, nodes in nodes_dict_or_list.items(): 110 | if level == 0: 111 | continue 112 | for node in nodes: 113 | nodes_list.append(node.to_dict()) 114 | else: 115 | nodes_list = nodes_dict_or_list 116 | if type(nodes_list[0]) is MCTSNode: 117 | # Convert MCTSNode objects to dicts 118 | nodes_list = [node.to_dict() for node in nodes_list] 119 | 120 | # Save nodes to JSON 121 | nodes_list = save_nodes_to_json(nodes_list, log_dirname, run_dedupe=run_dedupe, dedupe_model=model, 122 | time_elapsed=time_elapsed) 123 | 124 | # Save nodes to CSV 125 | if save_csv: 126 | csv_output_file = os.path.join(log_dirname, "mcts_nodes.csv") 127 | nodes_to_csv(nodes_list, csv_output_file) 128 | 129 | 130 | def save_nodes_to_json(nodes_list, log_dirname, run_dedupe=True, dedupe_model="gpt-4o", log_dedupe_comparisons=False, 131 | time_elapsed=None): 132 | """Save all MCTS nodes to a JSON file. 133 | 134 | Args: 135 | nodes_list: List of MCTS node objects. 136 | log_dirname: Directory to save the JSON file 137 | run_dedupe: Whether to deduplicate nodes based on hypothesis. 138 | dedupe_model: Model to use for deduplication. 139 | log_dedupe_comparisons: Whether to log deduplication comparisons to a file. 140 | time_elapsed: Optional time elapsed for logging purposes. 
141 | """ 142 | # Optionally deduplicate nodes based on hypothesis 143 | if run_dedupe: 144 | deduped_nodes, _, _ = dedupe(nodes_list, model=dedupe_model, 145 | log_comparisons_fname=None if not log_dedupe_comparisons else os.path.join( 146 | log_dirname, "dedupe_comparisons.json")) 147 | file_to_save = deduped_nodes 148 | else: 149 | file_to_save = nodes_list 150 | 151 | output_file = os.path.join(log_dirname, "mcts_nodes.json") 152 | with open(output_file, "w") as f: 153 | json.dump(file_to_save, f, indent=2) 154 | print(f"[JSON] MCTS nodes (n={len(file_to_save)}) saved to {output_file}.\n") 155 | # Also save the original nodes list for reference 156 | original_nodes_file = os.path.join(log_dirname, "mcts_nodes_all.json") 157 | with open(original_nodes_file, "w") as f: 158 | json.dump(nodes_list, f, indent=2) 159 | print(f"[JSON] Original MCTS nodes (n={len(nodes_list)}) saved to {original_nodes_file}.\n") 160 | if time_elapsed is not None: 161 | print(f"[Exploration] Time elapsed: {time_elapsed:.2f} seconds.\n") 162 | return file_to_save 163 | 164 | 165 | def get_msgs_from_latest_query(messages): 166 | # Find last user_proxy message by iterating in reverse 167 | start_idx = None 168 | for i, message in enumerate(reversed(messages)): 169 | if message.get("name") == "user_proxy": 170 | start_idx = len(messages) - 1 - i 171 | break 172 | if start_idx is None: 173 | return [] 174 | node_messages = messages[start_idx:] 175 | return node_messages 176 | 177 | 178 | def setup_group_chat(agents, max_rounds): 179 | # Set up the group chat with agents and rules 180 | group_chat = GroupChat( 181 | agents=list(agents.values()), 182 | messages=[], 183 | max_round=max_rounds, 184 | speaker_selection_method=SpeakerSelector().select_next_speaker 185 | ) 186 | chat_manager = GroupChatManager(groupchat=group_chat, llm_config=None) 187 | return group_chat, chat_manager 188 | 189 | 190 | def get_nodes(in_fpath_or_json: str | List[Dict[str, any]]) -> List[Dict[str, any]] | None: 191 
| """ 192 | Load MCTS nodes from a file, directory, or a list of dictionaries without creating class objects. 193 | Args: 194 | in_fpath_or_json: Path to the MCTS nodes JSON file, a directory containing MCTS node files, or a list of MCTS nodes as dictionaries. 195 | 196 | Returns: 197 | List of MCTS nodes as dictionaries. 198 | """ 199 | if type(in_fpath_or_json) is list: 200 | mcts_nodes = in_fpath_or_json 201 | else: 202 | # Load the MCTS nodes from the input file 203 | if os.path.isdir(in_fpath_or_json): 204 | mcts_nodes = [] 205 | filenames = glob(os.path.join(in_fpath_or_json, 'mcts_node_*.json')) 206 | for filename in filenames: 207 | with open(filename, 'r') as f: 208 | obj = json.load(f) 209 | mcts_nodes.append(obj) 210 | else: 211 | with open(in_fpath_or_json, 'r') as f: 212 | mcts_nodes = json.load(f) 213 | return mcts_nodes 214 | 215 | 216 | def print_node_info(node): 217 | prior_mean = node.prior.get_mean_belief() 218 | posterior_mean = node.posterior.get_mean_belief(prior=node.prior) 219 | direction = "+" if posterior_mean > prior_mean else ("-" if posterior_mean < prior_mean else "=") 220 | print(f"""\n\n\ 221 | ================================================================================ 222 | 223 | NODE_LEVEL={node.level}, NODE_IDX={node.node_idx}: 224 | ------------------------- 225 | 226 | Hypothesis: {node.hypothesis} 227 | Prior: {prior_mean:.4f} 228 | Posterior: {posterior_mean:.4f} 229 | Surprisal: {node.surprising} 230 | Belief Change: {node.belief_change:.4f} ({direction}) 231 | KL Divergence: {node.kl_divergence:.4f} 232 | Reward: {node.self_value:.4f} 233 | 234 | ================================================================================\n\n""") 235 | 236 | 237 | def get_query_from_experiment(exp): 238 | hypothesis = exp['hypothesis'] 239 | exp_plan = exp['experiment_plan'] 240 | new_query = "" 241 | if hypothesis is not None: 242 | new_query += f"Hypothesis: {hypothesis}\n\n" 243 | new_query += f"""\ 244 | Experiment objective: 
{exp_plan['objective']} 245 | 246 | Steps for the programmer: 247 | {exp_plan['steps']} 248 | 249 | Deliverables: 250 | {exp_plan['deliverables']}""" 251 | return new_query 252 | 253 | 254 | def get_experiment_from_query(query): 255 | # Extract the hypothesis and experiment plan from the query 256 | hypothesis_match = re.search(r'Hypothesis:\s*(.*)', query) 257 | hypothesis = hypothesis_match.group(1).strip() if hypothesis_match else None 258 | 259 | exp_plan_match = re.search(r'Experiment objective:\s*(.*?)(?=\n\n|$)', query, re.DOTALL) 260 | exp_plan = exp_plan_match.group(1).strip() if exp_plan_match else None 261 | 262 | steps_match = re.search(r'Steps for the programmer:\s*(.*?)(?=\n\n|$)', query, re.DOTALL) 263 | steps = steps_match.group(1).strip() if steps_match else None 264 | 265 | deliverables_match = re.search(r'Deliverables:\s*(.*?)(?=\n\n|$)', query, re.DOTALL) 266 | deliverables = deliverables_match.group(1).strip() if deliverables_match else None 267 | 268 | return { 269 | "hypothesis": hypothesis, 270 | "experiment_plan": { 271 | "objective": exp_plan, 272 | "steps": steps, 273 | "deliverables": deliverables 274 | } 275 | } 276 | 277 | 278 | def get_node_level_idx(node_or_id): 279 | from src.mcts import MCTSNode 280 | 281 | # Get the level and index of a node from its ID (e.g., "node__") or MCTSNode/dict. 282 | if type(node_or_id) is MCTSNode: 283 | id = node_or_id.id 284 | elif type(node_or_id) is dict: 285 | id = node_or_id["id"] 286 | elif type(node_or_id) is str: 287 | id = node_or_id 288 | 289 | return map(int, id.split("_")[1:]) 290 | 291 | 292 | def get_context_string(hyp_exp_query, code_output=None, analysis=None, review=None, 293 | belief_mean=None, include_code_output=False): 294 | # Format the experiment to include as context in, e.g., an LLM call. 
295 | context_str = hyp_exp_query 296 | if include_code_output and code_output is not None: 297 | context_str += f"\n\nCode Output:\n{code_output}" 298 | if analysis is not None: 299 | context_str += f"\n\nAnalysis:\n{analysis}" 300 | if review is not None: 301 | context_str += f"\n\nReview:\n{review}" 302 | if belief_mean is not None: 303 | context_str += f"\n\nBelief about this hypothesis (range 0-1: definitely false -> uncertain -> definitely true): {belief_mean:.4f}" 304 | 305 | return context_str 306 | 307 | 308 | def get_self_value(belief_change, kl_divergence, binary=True, width=0.2, kl_scale=20.0, 309 | mode: Literal["belief", "kl", "belief_and_kl"] = "belief_and_kl"): 310 | """Get self value for a node based on its belief. 311 | 312 | Args: 313 | belief_change (float): Change in belief from prior to posterior. 314 | kl_divergence (float): KL divergence between prior and posterior beliefs. 315 | binary (bool): Whether the surprisal reward is binary or continuous. 316 | width (float): Surprisal width for belief change. 317 | kl_scale (float): Normalization factor for KL divergence. 318 | mode (str): Mode to use for self value calculation. Choices: "belief", "kl", "both". 319 | 320 | Returns: 321 | float: Self value based on the belief type. 322 | bool: Whether it is a surprisal. 
323 | """ 324 | if mode == "belief": 325 | if binary: 326 | return float(belief_change >= width), bool(belief_change >= width) 327 | else: 328 | # Continuous reward normalized by the surprisal width 329 | return belief_change / width, bool((belief_change / width) >= 1.0) 330 | elif mode == "kl": 331 | # KL divergence reward normalized by the KL scale 332 | if binary: 333 | return float(kl_divergence >= kl_scale), bool(kl_divergence >= kl_scale) 334 | else: 335 | return kl_divergence / kl_scale, bool((kl_divergence / kl_scale) >= 1.0) 336 | elif mode == "belief_and_kl": 337 | # Satisfy both modes 338 | belief_value, is_surprising_belief = get_self_value(belief_change, kl_divergence, binary, width, kl_scale, 339 | mode="belief") 340 | kl_value, is_surprising_kl = get_self_value(belief_change, kl_divergence, binary, width, kl_scale, 341 | mode="kl") 342 | # Combine both values 343 | combined_value = max(belief_value, kl_value) 344 | is_surprising = bool(is_surprising_belief or is_surprising_kl) 345 | return combined_value, is_surprising 346 | 347 | raise ValueError(f"Invalid mode: {mode}. Choose from 'belief', 'kl', or 'belief_and_kl'.") 348 | -------------------------------------------------------------------------------- /src/mcts_viz.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | MCTS Tree Visualization 7 | 8 | 156 | 157 | 158 |

MCTS Tree Visualization

159 | 160 |
161 | Surprising Results: 0 162 |
163 | 164 |
165 |

Drop MCTS log file here or click to upload

166 | 167 |
168 |
169 | 170 |
171 |
172 |
173 |

Node Details

174 |

Click on a node to view details

175 |
176 |
177 | 178 | 465 | 466 | 467 | -------------------------------------------------------------------------------- /src/agents.py: -------------------------------------------------------------------------------- 1 | from autogen import ConversableAgent, UserProxyAgent 2 | from src.structured_outputs import ExperimentList, ExperimentCode, ExperimentAnalyst, ExperimentReviewer, Experiment, \ 3 | ExperimentHypothesisList 4 | import os 5 | import json 6 | from autogen.coding import LocalCommandLineCodeExecutor 7 | from typing import Tuple 8 | 9 | import copy 10 | from typing import List, Dict 11 | import autogen.agentchat.contrib.capabilities.transforms as transforms 12 | from autogen.agentchat.contrib.capabilities import transform_messages 13 | 14 | IMAGE_ANALYSIS_PATCH = """\ 15 | import matplotlib.pyplot as plt 16 | import functools 17 | from io import BytesIO 18 | import base64 19 | from openai import OpenAI 20 | 21 | 22 | client = OpenAI() 23 | 24 | image_analyst_prompt = '''Please analyze the given plot image and provide the following: 25 | 26 | 1. Plot Type: Identify the type of plot (e.g., heatmap, bar plot, scatter plot) and its purpose. 27 | 2. Axes: 28 | * Titles and labels, including units. 29 | * Value ranges for both axes. 30 | 3. Data Trends: 31 | * For scatter plots: note trends, clusters, or outliers. 32 | * For bar plots: highlight the tallest and shortest bars and patterns. 33 | * For heatmaps: identify areas of high and low values. 34 | etc... 35 | 4. Annotations and Legends: Describe key annotations or legends. 36 | 5. 
Statistical Insights: Provide insights based on the information presented in the plot.''' 37 | 38 | 39 | def image_to_text(): 40 | for fig_num in plt.get_fignums(): 41 | fig = plt.figure(fig_num) # Get the current figure 42 | with BytesIO() as buf: 43 | # Save the figure to a PNG buffer 44 | fig.savefig(buf, format='png', dpi=200) 45 | buf.seek(0) 46 | # Encode image to base64 47 | base64_image = base64.b64encode(buf.read()).decode('utf-8') 48 | messages = [ 49 | { 50 | 'role': 'system', 51 | 'content': 'You are a research scientist responsible for analyzing plots and figures from running experiments and providing detailed descriptions.' 52 | }, 53 | { 54 | 'role': 'user', 55 | 'content': [ 56 | {'type': 'text', 'text': image_analyst_prompt}, 57 | { 58 | "type": "image_url", 59 | "image_url": { 60 | "url": f"data:image/png;base64,{base64_image}" 61 | } 62 | } 63 | ] 64 | } 65 | ] 66 | # Get image analysis from the LLM 67 | response = client.chat.completions.create( 68 | model="gpt-4o", 69 | messages=messages, 70 | max_tokens=1000, 71 | ) 72 | analysis = response.choices[0].message.content 73 | print(f"\\n=== Plot Analysis (fig. 
{fig_num}) ===\\n") 74 | print(analysis) 75 | print("\\n" + "="*50) 76 | 77 | plt.close(fig) 78 | 79 | 80 | def patch_matplotlib_show(): 81 | # Replace plt.show with our custom function 82 | plt.show = functools.partial(image_to_text) 83 | 84 | 85 | # Apply the patch 86 | patch_matplotlib_show() 87 | """ 88 | 89 | 90 | class CodeBlockWrapperTransform(transforms.MessageTransform): 91 | 92 | def apply_transform(self, messages: List[Dict]) -> List[Dict]: 93 | # Deep copy messages to avoid modifying the original 94 | transformed_messages = copy.deepcopy(messages) 95 | message = transformed_messages[-1] 96 | 97 | try: 98 | code = json.loads(message["content"]).get("code", "# Failed to parse code from message") 99 | except json.JSONDecodeError: 100 | code = "# Failed to parse code from message" 101 | 102 | message["content"] = f"```python\n{IMAGE_ANALYSIS_PATCH}\n\n{code}\n```" 103 | 104 | return transformed_messages 105 | 106 | def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: 107 | return "CodeBlockWrapperTransform", True 108 | 109 | 110 | def get_openai_config(api_key: str | None = None, temperature: float | None = None, 111 | reasoning_effort: str | None = None, timeout: int = 600, model_name="o4-mini"): 112 | config = { 113 | "api_type": "openai", 114 | "model": model_name, 115 | "timeout": timeout, 116 | "api_key": api_key, 117 | "max_retries": 3, 118 | "cache_seed": None # Disabling caching also addresses this bug: https://github.com/ag2ai/ag2/issues/1103 119 | } 120 | if temperature is not None: 121 | config["temperature"] = temperature 122 | 123 | # Make o-series specific changes 124 | if model_name.startswith("o"): 125 | if reasoning_effort is not None: 126 | config["reasoning_effort"] = reasoning_effort # Defaults to medium 127 | else: 128 | config["logprobs"] = True 129 | 130 | return config 131 | 132 | 133 | def get_agents(work_dir, model_name="o4-mini", temperature=None, reasoning_effort=None, 
branching_factor=3, 134 | user_query=None, experiment_first=False, code_timeout=30 * 60) -> dict[str, ConversableAgent]: 135 | llm_config = get_openai_config(api_key=os.getenv("OPENAI_API_KEY"), model_name=model_name, temperature=temperature, 136 | reasoning_effort=reasoning_effort) 137 | 138 | # Create token limit transform 139 | token_limit_capability = transform_messages.TransformMessages(transforms=[ 140 | transforms.MessageTokenLimiter(max_tokens_per_message=10_000, min_tokens=12_000) 141 | ]) 142 | 143 | # Experiment Generator 144 | _user_query_or_empty = f"{user_query}\n\n" if user_query is not None else "" 145 | 146 | experiment_generator = ConversableAgent( 147 | name="experiment_generator", 148 | llm_config={**llm_config, 149 | "response_format": ExperimentList if not experiment_first else ExperimentHypothesisList}, 150 | system_message=( 151 | 'You are a research scientist who is interested in doing open-ended, data-driven research using the provided dataset(s). ' 152 | f'{_user_query_or_empty}' 153 | f'Be creative and think of new and interesting verifiable {"experiments" if experiment_first else "hypotheses"} and corresponding {"hypotheses" if experiment_first else "experiments"}. ' 154 | 'The hypothesis should be a falsifiable statement that can be sufficiently tested by an experiment using the provided data. ' 155 | 'Explain in natural language what this experiment plan is so that a programmer can implement it (do not provide the code yourself). ' 156 | 'Remember, you are interested in open-ended research, so your proposals may be exploratory in nature and may have only an indirect connection to the previous explorations provided. ' 157 | 'Here are some instructions that you must follow:\n' 158 | '1. Strictly use only the dataset(s) provided and do not simulate dummy/synthetic data or columns that cannot be derived from the existing columns.\n' 159 | '2. 
Each hypothesis (and experiment plan) should be creative, independent, and self-contained.\n' 160 | '3. Use the prior experiments/hypotheses as inspiration to think of interesting and creative new experiments/hypotheses. However, do not repeat the same experiments/hypotheses.\n\n' 161 | 'Here is a possible approach to coming up with a new hypothesis and experiment plan:\n' 162 | '1. Find an interesting context: this could be a specific subset of the data. E.g., if the dataset has multiple categorical variables, you could split the data based on specific values of such variables, which would then allow you to validate a hypothesis in the specific contexts defined by the values of those variables.\n' 163 | '2. Find interesting variables: these could be the columns in the dataset that you find interesting or relevant to the context. You are allowed and encouraged to create composite variables derived from the existing variables.\n' 164 | '3. Find interesting relationships: these are interactions between the variables that you find interesting or relevant to the context. You are encouraged to propose experiments involving complex predictive or causal models.\n' 165 | '4. You must require that your proposed hypotheses are verifiable using robust statistical tests. Remember, your programmer can install python packages via pip which can allow it to write code for complex statistical analyses.\n' 166 | '5. 
Multiple datasets: If you are provided with more than one dataset, then try to also propose hypotheses that utilize contexts, variables, and relationships across datasets, e.g., this may involve using join or similar operations.\n\n' 167 | 'Generally, in typical data-driven research, you will need to explore and visualize the data for possible high-level insights, clean, transform, or derive new variables from the dataset to be suited for the investigation, deep-dive into specific parts of the data for fine-grained analysis, perform data modeling, and run statistical tests. ' 168 | f'Now, generate exactly {branching_factor} new hypotheses with their experiment plans.' 169 | ), 170 | human_input_mode="NEVER", 171 | ) 172 | 173 | install_snippet = ("""\nimport subprocess 174 | import sys 175 | 176 | def install(package): 177 | subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", package])\n\n\n""") 178 | 179 | # Experiment Programmer 180 | experiment_programmer = ConversableAgent( 181 | name="experiment_programmer", 182 | llm_config={**llm_config, "response_format": ExperimentCode}, 183 | system_message=( 184 | 'You are a scientific experiment programmer proficient in writing python code given an experiment plan. ' 185 | 'Your code will be included in a python file that is executed and any relevant results should be printed to standard out or presented using plt.show appropriately. ' 186 | 'Make sure you provide python code in the proper format to execute. ' 187 | 'Ensure your code is clean and concise, and include debug statements only when they are absolutely necessary. ' 188 | 'Use only the dataset given and do not assume any other files are available. The state is not preserved between code blocks, so do not assume any variables or imports from previous code blocks. ' 189 | 'Import any libraries you need to use. Always attempt to import a library before installing it (it may already be installed). 
' 190 | 'If you need to install a library, use the following code example:' 191 | f'{install_snippet}' 192 | 'When installing python packages, use the --quiet option to minimize unnecessary output.' 193 | 'Prefer using installed libraries over installing new libraries whenever possible. ' 194 | 'If possible, instead of downgrading library versions, try to adapt your code to work with a more updated version that is already installed. ' 195 | 'Never attempt to create a new environment. Always use the current environment. ' 196 | 'If the code requires generating plots, use plt.show (not plt.savefig). ' 197 | 'Avoid printing the whole data structure to the console directly if it is large; instead, print concise results that are directly relevant to the experiment. ' 198 | 'You are allowed 6 total attempts to run the code, including debugging attempts.\n\n' 199 | 'Debugging instructions:\n' 200 | '1. Only debug if you are either unsure about the executability or validity of the code (i.e., whether it satisfies the proposed experiment).\n' 201 | '2. If the code you are writing is intended for debugging, the first line of your code must be "# [debug]" only.\n' 202 | '3. DO NOT use "[debug]" anywhere else in your code.\n' 203 | '4. DO NOT combine any debug code and the actual experiment implementation code; keep them separate.\n' 204 | '5. For each experiment, you are allowed to debug at most 3 times.\n' 205 | '6. As much as possible, minimize the number of debugging steps you use.' 206 | ), 207 | human_input_mode="NEVER", 208 | ) 209 | 210 | # Experiment Analyst 211 | experiment_analyst = ConversableAgent( 212 | name="experiment_code_analyst", 213 | llm_config={**llm_config, "response_format": ExperimentAnalyst}, 214 | system_message=( 215 | 'You are a research scientist responsible for evaluating the code execution output for a scientific experiment written by a programmer. 
' 216 | 'If no code was executed, there was an error, or the code fails silently, return the success status as **false**. ' 217 | 'If the code includes a line "# [debug]" i.e "[debug]" as a comment, strictly treat this as a debugging experiment. ' 218 | 'In such cases, strictly return the success status as **false**, provide information that it was a debug code execution, ' 219 | 'give feedback and request the experiment to be retried with the new information. ' 220 | 'Otherwise, analyze the results and provide a short summary of the code output.' 221 | ), 222 | human_input_mode="NEVER", 223 | ) 224 | 225 | # Experiment Reviewer 226 | experiment_reviewer = ConversableAgent( 227 | name="experiment_reviewer", 228 | llm_config={**llm_config, "response_format": ExperimentReviewer}, 229 | system_message=( 230 | 'You are a research scientist responsible for holistically reviewing the entire experiment pipeline, i.e., the generated code, the output, and the analysis w.r.t. the original experiment plan. ' 231 | 'Assess whether the experiment was faithfully implemented, i.e., whether the implementation follows the experiment plan without significant deviation and whether the hypothesis was in fact tested sufficiently. ' 232 | 'If you find issues or inconsistencies in any part of the experiment pipeline, return the success status as **false** and provide feedback about what is wrong. ' 233 | 'Otherwise, return the success status as **true** and provide a summary of the hypothesis, experiment results, and findings.' 234 | ), 235 | human_input_mode="NEVER", 236 | ) 237 | 238 | # Experiment Reviser 239 | experiment_reviser = ConversableAgent( 240 | name="experiment_reviser", 241 | llm_config={**llm_config, "response_format": Experiment}, 242 | system_message=( 243 | 'You are a research scientist revisiting the most recent experiment, which could not be conducted correctly due to issues in the code or the formulation of the experiment plan,' 244 | 'as indicated by the reviewer. 
Your goal is to revise this failed experiment plan by addressing the issues and limitations pointed out by the reviewer. ' 245 | 'The revised experiment plan should still aim to validate the most recent hypothesis. ' 246 | 'Do not provide the code yourself but explain in natural language what the experiment should do for a programmer. ' 247 | 'Strictly use only the dataset provided and do not create synthetic data or columns that cannot be derived from the given columns. ' 248 | 'The experiment should be creative, independent, and self-contained. ' 249 | 'Generally, in typical data-driven research, you will need to explore and visualize the data for possible high-level insights, clean, transform, or derive new variables from the dataset to be suited for the investigation, deep-dive into specific parts of the data for fine-grained analysis, perform data modeling, and run statistical tests.' 250 | ), 251 | human_input_mode="NEVER", 252 | ) 253 | 254 | ## Timeout Code Executor 255 | executor = LocalCommandLineCodeExecutor( 256 | timeout=code_timeout, # Timeout in seconds 257 | work_dir=work_dir, 258 | # virtual_env_context=create_virtual_env(os.path.join(work_dir, ".venv")) # TODO: Fix virtual env creation 259 | ) 260 | # TODO: Fix docker-based execution 261 | # executor = DockerCommandLineCodeExecutor( 262 | # # image="python:3.11-alpine", 263 | # timeout=30 * 60, # Timeout in seconds 264 | # work_dir=work_dir, 265 | # # virtual_env_context=create_virtual_env(os.path.join(work_dir, ".venv")) 266 | # ) 267 | 268 | # Create an agent with code executor configuration. 
269 | code_executor = ConversableAgent( 270 | "code_executor", 271 | llm_config=False, 272 | code_execution_config={"executor": executor}, 273 | human_input_mode="NEVER", 274 | ) 275 | # Apply image analysis patch to the code executor 276 | transform_messages_capability = transform_messages.TransformMessages(transforms=[CodeBlockWrapperTransform()]) 277 | transform_messages_capability.add_to_agent(code_executor) 278 | 279 | user_proxy = UserProxyAgent( 280 | name="user_proxy", 281 | description="Responsible for providing the initial query", 282 | code_execution_config=False, 283 | human_input_mode="NEVER" 284 | ) 285 | 286 | agents = [experiment_generator, experiment_programmer, experiment_analyst, experiment_reviewer, experiment_reviser, 287 | code_executor, user_proxy] 288 | 289 | # Apply token limit to all agents 290 | for agent in agents: 291 | token_limit_capability.add_to_agent(agent) 292 | 293 | agents_dict = {agent.name: agent for agent in agents} 294 | return agents_dict 295 | -------------------------------------------------------------------------------- /src/mcts.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import Optional 4 | import random 5 | import json 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | 10 | from src.mcts_utils import get_query_from_experiment, get_experiment_from_query, get_node_level_idx, get_context_string 11 | from src.utils import try_loading_dict 12 | 13 | 14 | class MCTSNode(object): 15 | _creation_counter = 0 16 | 17 | def __init__(self, level=None, node_idx=None, hypothesis=None, query=None, parent_idx=None, parent=None, 18 | untried_experiments=None, allow_generate_experiments=False, experiment_plan=None, code=None, 19 | code_output=None, analysis=None, review=None, creation_idx=None): 20 | # Tree attributes 21 | self.creation_idx = creation_idx if creation_idx is not None else MCTSNode._creation_counter 22 | 
MCTSNode._creation_counter += 1 # Used to replay MCTS from log files 23 | self.time_elapsed = None 24 | self.level = level 25 | self.node_idx = node_idx 26 | self.id = f"node_{self.level}_{self.node_idx}" if self.level is not None and self.node_idx is not None else None 27 | self.children = [] 28 | self.parent = parent # MCTSNode 29 | if self.parent is not None: 30 | self.parent_idx = self.parent.node_idx 31 | self.parent_id = self.parent.id 32 | self.parent.children.append(self) 33 | else: 34 | self.parent_idx = parent_idx 35 | self.parent_id = f"node_{self.level - 1}_{self.parent_idx}" if self.parent_idx is not None else None 36 | self.success = None 37 | 38 | # Agent attributes 39 | self.query = query 40 | self.messages = [] 41 | self.untried_experiments = copy.deepcopy(untried_experiments) if untried_experiments is not None else [] 42 | self.tried_experiments = [] # Track all tried experiments 43 | self.allow_generate_experiments = allow_generate_experiments 44 | 45 | # Experiment attributes 46 | self.hypothesis = hypothesis 47 | self.experiment_plan = experiment_plan 48 | self.code = code 49 | self.code_output = code_output 50 | self.analysis = analysis 51 | self.review = review 52 | 53 | # MCTS attributes 54 | self.visits = 0 # Visits to this node or its children 55 | self.value = 0. # Number of surprising hypotheses 56 | self.self_value = 0. 
# Value of this node only (not aggregated from children) 57 | 58 | # Belief attributes 59 | self.surprising: Optional[bool] = None 60 | self.prior = None 61 | self.posterior = None 62 | self.belief_change: Optional[float] = None # Change in belief from prior to posterior 63 | self.kl_divergence: Optional[float] = None 64 | 65 | def init_from_dict(self, data): 66 | """Initialize node attributes from a dictionary.""" 67 | # Tree attributes 68 | self.creation_idx = data.get('creation_idx', MCTSNode._creation_counter) 69 | self.time_elapsed = data.get('time_elapsed', self.time_elapsed) 70 | self.id = data.get('id', None) 71 | if self.id is not None: 72 | self.level, self.node_idx = get_node_level_idx(self.id) 73 | else: 74 | self.level = data['level'] 75 | self.node_idx = data['node_idx'] 76 | self.id = f"node_{self.level}_{self.node_idx}" 77 | self.parent_id = data.get('parent_id', self.parent_id) 78 | if self.parent_id is not None: 79 | _, self.parent_idx = get_node_level_idx(self.parent_id) 80 | else: 81 | self.parent_idx = data.get('parent_idx', self.parent_idx) 82 | if self.parent_idx is not None: 83 | self.parent_id = f"node_{self.level - 1}_{self.parent_idx}" 84 | self.success = data.get('success', self.success) 85 | 86 | # Agent attributes 87 | self.query = data.get('query', "N/A") 88 | self.messages = data.get('messages', self.messages) 89 | self.untried_experiments = data.get('untried_experiments', self.untried_experiments) 90 | # self.tried_experiments = data.get('tried_experiments', self.tried_experiments) 91 | self.allow_generate_experiments = self.allow_generate_experiments and self.level > 0 92 | 93 | # Experiment attributes 94 | self.hypothesis = data.get('hypothesis', self.hypothesis) 95 | self.experiment_plan = data.get('experiment_plan', self.experiment_plan) 96 | self.code = data.get('code', self.code) 97 | self.code_output = data.get('code_output', self.code_output) 98 | self.analysis = data.get('analysis', self.analysis) 99 | self.review = 
data.get('review', self.review) 100 | 101 | # MCTS attributes 102 | self.visits = data.get('visits', self.visits) 103 | self.value = data.get('value', self.value) 104 | self.self_value = data.get('self_value', self.self_value) 105 | 106 | # Belief attributes 107 | self.surprising = data.get('surprising', self.surprising) 108 | from src.beliefs import BELIEF_MODE_TO_CLS # Import here to avoid circular import issues 109 | if 'prior' in data and data['prior']: 110 | belief_cls = BELIEF_MODE_TO_CLS[data['prior']['_type']] 111 | self.prior = belief_cls.DistributionFormat(**data['prior']) 112 | if 'posterior' in data and data['posterior']: 113 | belief_cls = BELIEF_MODE_TO_CLS[data['posterior']['_type']] 114 | self.posterior = belief_cls.DistributionFormat(**data['posterior']) 115 | self.belief_change = data.get('belief_change', self.belief_change) 116 | self.kl_divergence = data.get('kl_divergence', self.kl_divergence) 117 | 118 | def get_next_experiment(self, experiment_generator=None, n_retry=3): 119 | """ 120 | Returns the next untried experiment. If none left and generating experiments is allowed, generates more using 121 | the experiment generator agent. 
122 | """ 123 | new_experiment, new_query = None, None 124 | 125 | if n_retry > 0: 126 | if self.untried_experiments: 127 | idx = random.randrange(len(self.untried_experiments)) 128 | new_experiment = self.untried_experiments.pop(idx) 129 | self.tried_experiments.append(new_experiment) 130 | elif self.allow_generate_experiments and experiment_generator is not None: 131 | # Generate new experiments on-demand, providing all previous experiments as context 132 | _messages = self.messages + [{ 133 | "role": "user", 134 | "content": f"Generate new experiments given these previously attempted experiments: {json.dumps(self.tried_experiments)}" 135 | }] 136 | _reply = experiment_generator.generate_reply(messages=_messages) 137 | try: 138 | experiments = try_loading_dict(_reply).get("experiments", []) 139 | except (json.JSONDecodeError, TypeError): 140 | experiments = [] 141 | self.untried_experiments = experiments.copy() 142 | if self.untried_experiments: 143 | idx = random.randrange(len(self.untried_experiments)) 144 | new_experiment = self.untried_experiments.pop(idx) 145 | self.tried_experiments.append(new_experiment) 146 | 147 | if new_experiment is not None: 148 | try: 149 | new_query = get_query_from_experiment(new_experiment) 150 | except: 151 | pass 152 | if new_query is None: 153 | return self.get_next_experiment(experiment_generator=experiment_generator, n_retry=n_retry - 1) 154 | 155 | return new_experiment, new_query 156 | 157 | def has_untried_experiments(self): 158 | return bool(self.untried_experiments) or self.allow_generate_experiments 159 | 160 | def to_dict(self): 161 | return { 162 | "id": self.id, 163 | "success": self.success, 164 | "parent_id": self.parent_id, 165 | "creation_idx": self.creation_idx, 166 | "time_elapsed": self.time_elapsed, 167 | "visits": self.visits, 168 | "value": self.value, 169 | "self_value": self.self_value, 170 | "surprising": self.surprising, 171 | "belief_change": self.belief_change, 172 | "kl_divergence": 
self.kl_divergence, 173 | "prior": self.prior.to_dict() if self.prior else None, 174 | "posterior": self.posterior.to_dict() if self.posterior else None, 175 | "hypothesis": self.hypothesis, 176 | "experiment_plan": self.experiment_plan, 177 | "code": self.code, 178 | "code_output": self.code_output, 179 | "analysis": self.analysis, 180 | "review": self.review, 181 | "untried_experiments": self.untried_experiments, 182 | "tried_experiments": self.tried_experiments, 183 | "query": self.query, 184 | "messages": self.messages, 185 | } 186 | 187 | def read_experiment_from_messages(self, store_new_experiments=False): 188 | """Extracts experiment details from messages and updates the node's attributes.""" 189 | latest_experiment = None 190 | was_revised = False 191 | latest_programmer = None 192 | latest_code_executor = None 193 | latest_analyst = None 194 | latest_reviewer = None 195 | latest_reviewer_feedback = "N/A" 196 | latest_reviewer_success = False 197 | latest_experiment_generator = None 198 | 199 | for msg in reversed(self.messages): 200 | if not latest_experiment and msg.get("name") in ["user_proxy", "experiment_reviser"]: 201 | latest_experiment = msg.get("content") 202 | if msg.get("name") == "experiment_reviser": 203 | was_revised = True 204 | elif not latest_programmer and msg.get("name") == "experiment_programmer": 205 | latest_programmer = try_loading_dict(msg.get("content")).get("code", "N/A") 206 | elif not latest_code_executor and msg.get("name") == "code_executor": 207 | latest_code_executor = msg.get("content") 208 | elif not latest_analyst and msg.get("name") in ["experiment_analyst", "experiment_code_analyst"]: 209 | latest_analyst = try_loading_dict(msg.get("content")).get("analysis", "N/A") 210 | elif not latest_reviewer and msg.get("name") == "experiment_reviewer": 211 | latest_reviewer = try_loading_dict(msg.get("content")) 212 | latest_reviewer_feedback = latest_reviewer.get("feedback", "N/A") 213 | if latest_reviewer_feedback == "": 214 | 
latest_reviewer_feedback = "N/A" 215 | latest_reviewer_success = latest_reviewer.get("success", False) 216 | elif not latest_experiment_generator and msg.get("name") == "experiment_generator": 217 | latest_experiment_generator = try_loading_dict(msg.get("content")).get("experiments", []) 218 | 219 | if (latest_experiment and latest_programmer and 220 | latest_code_executor and latest_analyst and latest_reviewer): 221 | break 222 | 223 | if was_revised: 224 | latest_experiment_obj = try_loading_dict(latest_experiment) 225 | # Change what the query should now be based on the revised experiment 226 | self.query = get_query_from_experiment(latest_experiment_obj) 227 | else: 228 | latest_experiment_obj = get_experiment_from_query(latest_experiment) # assuming it is a query string 229 | 230 | self.hypothesis = latest_experiment_obj.get("hypothesis", "N/A") 231 | self.experiment_plan = latest_experiment_obj.get("experiment_plan", "N/A") 232 | self.code = latest_programmer 233 | self.code_output = latest_code_executor 234 | self.analysis = latest_analyst 235 | self.review = latest_reviewer_feedback 236 | self.success = latest_reviewer_success 237 | 238 | # Store new experiments into untried_experiments 239 | if store_new_experiments and latest_experiment_generator: 240 | self.untried_experiments += latest_experiment_generator 241 | 242 | def get_context(self, include_code_output=False) -> None | str: 243 | """Returns the node's hypothesis, experiment, output, analysis, and review.""" 244 | if len(self.messages) == 0: 245 | return None 246 | context_str = get_context_string(self.query, self.code_output, self.analysis, self.review, 247 | include_code_output=include_code_output) 248 | return context_str 249 | 250 | def get_path_context(self, k: Optional[int] = None, skip_root=True) -> None | list: 251 | """Returns messages from the node to the root 252 | 253 | Args: 254 | k: Optional maximum number of parent levels to include. If None, includes all parents. 
255 | skip_root: If True, skips the root node in the context. 256 | """ 257 | context = self.parent.get_context() if self.parent is not None else None 258 | if context is not None: 259 | if skip_root and self.parent.level <= 1: 260 | return [] 261 | context = [context] 262 | k_remaining = None if k is None else k - 1 263 | if context is not None and self.parent is not None and (k_remaining is None or k_remaining > 0): 264 | parent_context = self.parent.get_path_context(k=k_remaining, skip_root=skip_root) 265 | if parent_context is not None: 266 | context = parent_context + context 267 | return context 268 | 269 | def update_counts(self, visits: int = 1, reward: float = 0): 270 | """Update the visit count and value of this node and its parents.""" 271 | self.visits += visits 272 | self.value += reward 273 | if self.parent is not None: 274 | self.parent.update_counts(visits=visits, reward=reward) 275 | 276 | 277 | def default_mcts_selection(exploration_weight): 278 | def select(node, nodes_by_level): 279 | # Traverse the tree until we find a node with untried experiments 280 | while node.children and not node.has_untried_experiments(): 281 | # Select the child with the highest UCB1 value 282 | node = max(node.children, key=lambda n: ucb1(n, exploration_weight)) 283 | return node 284 | 285 | return select 286 | 287 | 288 | def progressive_widening(k, alpha, exploration_weight=1.0): 289 | """ 290 | Create a progressive widening selection function. 291 | 292 | Args: 293 | k: Progressive widening constant. 294 | alpha: Progressive widening exponent. 295 | exploration_weight: Exploration weight for UCB1 selection method. 296 | 297 | Returns: 298 | A callable function that accepts a `node` and returns the selected node. 
299 | """ 300 | 301 | def select(node, nodes_by_level): 302 | # Get the number of visits and children for the current node 303 | num_visits = node.visits 304 | num_children = len(node.children) 305 | 306 | # Check if we can add a new child based on the progressive widening condition 307 | if (num_children < k * (num_visits ** alpha)) and node.has_untried_experiments(): 308 | # Sample a new child (expand the tree) 309 | return node 310 | 311 | # Otherwise, recursively sample from the children 312 | if node.children: 313 | # Select a child node recursively using the same selection function 314 | return select(max(node.children, key=lambda n: ucb1(n, exploration_weight)), nodes_by_level) 315 | 316 | # If no children exist, return the current node 317 | return node 318 | 319 | return select 320 | 321 | 322 | def progressive_widening_all(k, alpha, exploration_weight=1.0): 323 | """ 324 | Create a progressive widening selection function with an alternative implementation that selects the node 325 | which has the highest UCB1 value from the entire set of nodes. 326 | 327 | Args: 328 | k: Progressive widening constant. 329 | alpha: Progressive widening exponent. 330 | exploration_weight: Exploration weight for UCB1 selection method. 331 | 332 | Returns: 333 | A callable function that accepts a `node` and returns the selected node. 
334 | """ 335 | 336 | def select(node, nodes_by_level): 337 | all_nodes = [n for level, nodes in nodes_by_level.items() if level > 0 for n in nodes] 338 | # Sort all nodes by UCB1 value 339 | all_nodes_sorted = sorted(all_nodes, key=lambda n: ucb1(n, exploration_weight), reverse=True) 340 | 341 | for _node in all_nodes_sorted: 342 | num_visits = _node.visits 343 | num_children = len(_node.children) 344 | # If we can add a new child based on the progressive widening condition 345 | if (num_children < k * (num_visits ** alpha)) and _node.has_untried_experiments(): 346 | # Sample a new child (expand the tree) 347 | return _node 348 | 349 | # If no children exist, return the current node 350 | return node 351 | 352 | return select 353 | 354 | 355 | def ucb1_recursive(exploration_weight=1.0): 356 | """ 357 | Create a UCB1 traversal selection function. 358 | 359 | Args: 360 | exploration_weight: Exploration weight for UCB1 selection method. 361 | 362 | Returns: 363 | A callable function that accepts a `node` and returns the selected node. 364 | """ 365 | 366 | def select(node, nodes_by_level): 367 | # Sort self and children by UCB1 value 368 | all_nodes = [node] + node.children 369 | sorted_nodes = sorted(all_nodes, key=lambda n: ucb1(n, exploration_weight), reverse=True) 370 | 371 | for best_node in sorted_nodes: 372 | if best_node.has_untried_experiments(): 373 | if best_node is node: 374 | return best_node 375 | return select(best_node, nodes_by_level) 376 | 377 | return select 378 | 379 | 380 | def beam_search(branching_factor, beam_width, log_dirname=None): 381 | """ 382 | Create a beam search selection function. 
383 | 384 | Args: 385 | branching_factor: Maximum number of children per node 386 | beam_width: Number of nodes to keep in beam 387 | 388 | Returns: 389 | A callable function that accepts a root node and returns selected node 390 | """ 391 | beam = [] # Current nodes in beam 392 | 393 | def select(root, nodes_by_level): 394 | nonlocal beam 395 | 396 | # Initialize beam with root if empty 397 | if not beam: 398 | beam = [root] 399 | # Log initial beam state 400 | if log_dirname: 401 | beam_state = [{"level": node.level, "node_idx": node.node_idx} for node in beam] 402 | with open(os.path.join(log_dirname, f"beam_level_{root.level}.json"), "w") as f: 403 | json.dump(beam_state, f, indent=2) 404 | 405 | # Try nodes in current beam 406 | for node in beam: 407 | if node.has_untried_experiments() and len(node.children) < branching_factor: 408 | return node 409 | 410 | # All nodes in beam are exhausted, select new beam 411 | all_children = [] 412 | for node in beam: 413 | all_children.extend(node.children) 414 | 415 | # Sort children by UCB1 score and select top beam_width 416 | if all_children: 417 | beam = sorted(all_children, key=lambda n: ucb1(n), reverse=True)[:beam_width] 418 | # Log new beam state 419 | if log_dirname: 420 | beam_state = [{"level": node.level, "node_idx": node.node_idx} for node in beam] 421 | level = beam[0].level if beam else 0 422 | with open(os.path.join(log_dirname, f"beam_level_{level}.json"), "w") as f: 423 | json.dump(beam_state, f, indent=2) 424 | return select(root, nodes_by_level) # Recurse with new beam 425 | 426 | return beam[0] # Default to first beam node if no children 427 | 428 | return select 429 | 430 | 431 | def ucb1(node, exploration_weight=1.0): 432 | """Upper Confidence Bound 1 calculation for node selection""" 433 | if node.visits == 0: 434 | return float('inf') 435 | exploitation_term = node.value / node.visits 436 | exploration_term = 0 437 | if node.parent: 438 | exploration_term = np.sqrt(2 * np.log(node.parent.visits) / 
node.visits) 439 | return exploitation_term + exploration_weight * exploration_term 440 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict 4 | from time import time 5 | 6 | from src.agents import get_agents 7 | from src.mcts import MCTSNode, default_mcts_selection, beam_search, progressive_widening, progressive_widening_all, \ 8 | ucb1_recursive 9 | from src.dataset import get_datasets_fpaths, get_load_dataset_experiment 10 | from src.logger import TreeLogger 11 | 12 | from src.beliefs import calculate_prior_and_posterior_beliefs 13 | from datetime import datetime 14 | import shutil 15 | 16 | from src.args import ArgParser 17 | from src.mcts_utils import load_mcts_from_json, save_nodes, get_msgs_from_latest_query, setup_group_chat, \ 18 | print_node_info, get_self_value, get_context_string 19 | 20 | 21 | def select_node(selection_method, root, nodes_by_level, n_warmstart=0): 22 | """ 23 | Select the next node to expand in MCTS using the provided selection method. 24 | 25 | Args: 26 | selection_method: Function to select nodes in MCTS. 27 | root: Root MCTSNode to select from. 28 | nodes_by_level: Dictionary of nodes by level. 29 | n_warmstart: Number of warmstart experiments to run after data loading but before MCTS selection. 30 | 31 | Returns: 32 | Selected MCTSNode for expansion. 33 | """ 34 | n_children_at_data_loader = len(nodes_by_level[2]) 35 | 36 | # If there are warmstart experiments left to run, select the data loader node to execute the next experiment. 
37 | if len(nodes_by_level[1]) > 0 and (n_warmstart - n_children_at_data_loader) > 0: 38 | return nodes_by_level[1][0] 39 | 40 | return selection_method(root, nodes_by_level) 41 | 42 | 43 | def compute_and_store_reward(node, belief_model_name, belief_temperature, reasoning_effort, 44 | n_belief_samples, implicit_bayes_posterior, surprisal_width, belief_mode, 45 | use_binary_reward, all_surprisals=None, use_online_beliefs=False, 46 | evidence_weight=1.0, kl_scale=20.0, reward_mode="belief", TEMP_LOG=None): 47 | s_conditioned_prior = None 48 | evidence_msg = [] 49 | 50 | # If there are past surprisal, computed the s-conditioned prior 51 | if all_surprisals is not None and len(all_surprisals) > 0 and use_online_beliefs: 52 | # Build evidence message for prior belief elicitation 53 | evidence_msg = [{ 54 | "role": "user", 55 | "content": "Previous study:\n\n" + get_context_string( 56 | hyp_exp_query=f"Hypothesis: {nodes_by_level[level_index[0]][level_index[1]].hypothesis}", 57 | analysis=nodes_by_level[level_index[0]][level_index[1]].analysis, 58 | review=nodes_by_level[level_index[0]][level_index[1]].review, 59 | belief_mean=nodes_by_level[level_index[0]][level_index[1]].posterior.mean, 60 | include_code_output=False 61 | ) 62 | } for level_index in all_surprisals] 63 | try: 64 | pt_prior, s_conditioned_prior, _, _ = calculate_prior_and_posterior_beliefs( 65 | node, 66 | model=belief_model_name, 67 | temperature=belief_temperature, 68 | reasoning_effort=reasoning_effort, 69 | n_samples=n_belief_samples, 70 | implicit_bayes_posterior=implicit_bayes_posterior, 71 | surprisal_width=surprisal_width, 72 | belief_mode=belief_mode, 73 | evidence_msg=evidence_msg 74 | ) 75 | except ValueError as e: 76 | print(f"Error for node {node.id}: {e}") 77 | node.success = False 78 | return 79 | 80 | # TEMPORARY LOGGING 81 | if TEMP_LOG is not None: 82 | TEMP_LOG.append({ 83 | 'node_id': node.id, 84 | 'belief_change': None, 85 | 'kl_divergence': None, 86 | 'hypothesis': node.hypothesis, 
87 | 'pt_prior': pt_prior.to_dict(), 88 | 'surprisal_evidence': [e['content'] for e in evidence_msg], 89 | 's_conditioned_prior': s_conditioned_prior.to_dict(), 90 | }) 91 | 92 | # Build the evidence message for the current node 93 | evidence_msg.append({ 94 | "role": "user", 95 | "content": "Current experiment:\n\n" + get_context_string( 96 | hyp_exp_query=node.query, 97 | code_output=node.code_output, 98 | analysis=node.analysis, 99 | review=node.review, 100 | include_code_output=False 101 | ) 102 | }) 103 | 104 | # Compute the prior and posterior beliefs for the current node 105 | try: 106 | prior, posterior, belief_change, kl_divergence = calculate_prior_and_posterior_beliefs( 107 | node, 108 | model=belief_model_name, 109 | temperature=belief_temperature, 110 | reasoning_effort=reasoning_effort, 111 | n_samples=n_belief_samples, 112 | implicit_bayes_posterior=implicit_bayes_posterior, 113 | surprisal_width=surprisal_width, 114 | belief_mode=belief_mode, 115 | prior=s_conditioned_prior, 116 | evidence_msg=evidence_msg, 117 | evidence_weight=evidence_weight 118 | ) 119 | except ValueError as e: 120 | print(f"Error for node {node.id}: {e}") 121 | node.success = False 122 | return 123 | 124 | # TEMPORARY LOGGING 125 | if TEMP_LOG is not None and len(TEMP_LOG) > 0: 126 | # Generate the posterior without surprisals 127 | _, _posterior, _belief_change, _kl_divergence = calculate_prior_and_posterior_beliefs( 128 | node, 129 | model=belief_model_name, 130 | temperature=belief_temperature, 131 | reasoning_effort=reasoning_effort, 132 | n_samples=n_belief_samples, 133 | implicit_bayes_posterior=implicit_bayes_posterior, 134 | surprisal_width=surprisal_width, 135 | belief_mode=belief_mode, 136 | prior=pt_prior, 137 | evidence_msg=evidence_msg[-1:], 138 | evidence_weight=evidence_weight 139 | ) 140 | 141 | TEMP_LOG[-1]['current_evidence'] = evidence_msg[-1]['content'] 142 | TEMP_LOG[-1]['online_posterior'] = posterior.to_dict() 143 | TEMP_LOG[-1]['belief_change'] = 
belief_change 144 | TEMP_LOG[-1]['kl_divergence'] = kl_divergence 145 | TEMP_LOG[-1]['offline_posterior'] = _posterior.to_dict() 146 | TEMP_LOG[-1]['offline_belief_change'] = _belief_change 147 | TEMP_LOG[-1]['offline_kl_divergence'] = _kl_divergence 148 | TEMP_LOG[-1]['current_surprisals'] = all_surprisals.copy() 149 | 150 | print(f"\n\n======================= SURPRISAL-CONDITION BELIEFS =======================\n") 151 | print(json.dumps({k: v for k, v in TEMP_LOG[-1].items() if 152 | k in ["pt_prior", "s_conditioned_prior", "online_posterior", "offline_posterior"]}, indent=2)) 153 | 154 | node.prior = prior 155 | node.posterior = posterior 156 | node.belief_change = belief_change 157 | node.kl_divergence = kl_divergence 158 | # Compute reward and surprisal 159 | node.self_value, node.surprising = get_self_value(belief_change=node.belief_change, 160 | kl_divergence=node.kl_divergence, 161 | binary=use_binary_reward, width=surprisal_width, 162 | kl_scale=kl_scale, mode=reward_mode) 163 | if node.surprising: 164 | # Store the surprisal 165 | all_surprisals.append((node.level, node.node_idx)) 166 | # TODO: Update all past nodes with the new surprisal set 167 | 168 | 169 | def run_mcts( 170 | root, 171 | nodes_by_level, 172 | dataset_paths, 173 | log_dirname, 174 | work_dir, 175 | model_name="gpt-4o", 176 | belief_model_name="gpt-4o", 177 | max_iterations=100, 178 | branching_factor=8, 179 | max_rounds=100000, 180 | selection_method=None, 181 | allow_generate_experiments=False, 182 | n_belief_samples=30, 183 | k_parents=3, 184 | temperature=1.0, 185 | belief_temperature=1.0, 186 | reasoning_effort="medium", 187 | implicit_bayes_posterior=False, 188 | surprisal_width=0.2, 189 | user_query=None, 190 | belief_mode="categorical", 191 | use_binary_reward=True, 192 | run_dedupe=True, 193 | experiment_first=False, 194 | code_timeout=30 * 60, 195 | n_warmstart=0, 196 | use_online_beliefs=False, 197 | evidence_weight=1.0, 198 | kl_scale=20.0, 199 | reward_mode="belief_and_kl", 
200 | warmstart_experiments=None 201 | ): 202 | """ 203 | Run AutoDS exploration. In MCTS, root node level=0 is a dummy node with no experiment, level=1 is the first real node with the dataset loading experiment, levels > 1 are the actual MCTS nodes with hypotheses and experiments. 204 | 205 | Args: 206 | root: Root MCTSNode to continue from. 207 | nodes_by_level: Dictionary to store nodes by level. 208 | dataset_paths: List of paths to dataset files. 209 | log_dirname: Directory to save logs and MCTS nodes. 210 | work_dir: Working directory for agents. 211 | model_name: LLM model name for agents. 212 | belief_model_name: LLM model name for belief distribution agent. 213 | max_iterations: Maximum number of MCTS iterations. 214 | branching_factor: Maximum number of children per node. 215 | max_rounds: Maximum number of rounds for the group chat. 216 | selection_method: Function to select nodes in MCTS (default is UCB1). 217 | allow_generate_experiments: Whether to allow nodes to generate new experiments on demand. 218 | n_belief_samples: Number of samples for belief distribution evaluation. 219 | k_parents: Number of parent levels to include in logs (None for all). 220 | temperature: Temperature setting for all agents (except belief agent). 221 | belief_temperature: Temperature setting for the belief agent. 222 | reasoning_effort: Reasoning effort for OpenAI o-series models. 223 | implicit_bayes_posterior: Whether to use the belief samples with evidence as the direct posterior or to use a Bayesian update that explicitly combines it with the prior. 224 | surprisal_width: Minimum difference in mean prior and posterior probabilities required to count as a surprisal. 225 | user_query: Custom user query to condition experiment generation during exploration. 226 | belief_mode: Belief elicitation mode (boolean, categorical, categorical_numeric, or probability). 227 | use_binary_reward: Whether to use binary reward for MCTS instead of a continuous reward (belief change). 
228 | run_dedupe: Whether to deduplicate nodes before saving to JSON and CSV. 229 | experiment_first: If True, an experiment will be generated before its hypothesis. 230 | code_timeout: Timeout for code execution in seconds (default is 30 minutes). 231 | n_warmstart: Number of warmstart experiments to run after data loading but before MCTS selection. 232 | use_online_beliefs: Whether to use online beliefs (i.e., beliefs updated with evidence from previous nodes). 233 | evidence_weight: Weight for the experimental evidence for posterior calculation. 234 | kl_scale: Normalization factor for KL divergence in reward calculation. 235 | reward_mode: Mode for reward calculation (belief, kl, or belief_and_kl). 236 | warmstart_experiments: Path to JSON file with warmstart experiments to run after data loading but before MCTS selection. 237 | """ 238 | # Setup logger 239 | logger = TreeLogger(log_dirname) 240 | 241 | # Track time 242 | start_time = time() 243 | 244 | # Create work directory if it doesn't exist 245 | os.makedirs(work_dir, exist_ok=True) 246 | 247 | # Copy the dataset file paths to the working directory (to avoid modifying the original dataset) 248 | for dataset_fpath in dataset_paths: 249 | shutil.copy(dataset_fpath, work_dir) 250 | 251 | # Get agents 252 | agent_objs = get_agents(work_dir, model_name=model_name, temperature=temperature, 253 | reasoning_effort=reasoning_effort, branching_factor=branching_factor, 254 | user_query=user_query, experiment_first=experiment_first, code_timeout=code_timeout) 255 | user_proxy = agent_objs["user_proxy"] 256 | experiment_generator = agent_objs["experiment_generator"] 257 | 258 | # Set up the group chat 259 | groupchat, chat_manager = setup_group_chat(agent_objs, max_rounds) 260 | 261 | if selection_method is None: 262 | # Default selection method is UCB1 263 | selection_method = default_mcts_selection(exploration_weight=1.0) 264 | 265 | # Store the list of (level, node_idx) tuples for surprising nodes; if resuming, 
load them from the previous run 266 | all_surprisals = [] 267 | for level in nodes_by_level: 268 | for node in nodes_by_level[level]: 269 | if node.surprising: 270 | all_surprisals.append((node.level, node.node_idx)) 271 | 272 | # Load warmstart experiments if provided 273 | _warmstart_experiments = None 274 | if warmstart_experiments is not None: 275 | with open(warmstart_experiments, "r") as f: 276 | _warmstart_experiments = json.load(f) 277 | 278 | # TEMPORARY LOGGING 279 | TEMP_LOG = [] 280 | 281 | try: 282 | for iteration_idx in range(max_iterations): 283 | # MCTS SELECTION, EXPANSION, and EXECUTION 284 | print(f"\n\n######### ITERATION {iteration_idx + 1} / {max_iterations} #########\n") 285 | 286 | # Select the next node to expand 287 | node = select_node(selection_method, root, nodes_by_level, n_warmstart) 288 | # Fetch or generate the next experiment from the selected node (retries built in) 289 | new_experiment, new_query = node.get_next_experiment(experiment_generator=experiment_generator) 290 | 291 | if new_query is not None: 292 | # Create a new node for the next experiment 293 | new_level = node.level + 1 294 | new_node_idx = len(nodes_by_level[new_level]) 295 | node = MCTSNode(level=new_level, node_idx=new_node_idx, hypothesis=new_experiment["hypothesis"], 296 | experiment_plan=new_experiment["experiment_plan"], query=new_query, parent=node, 297 | allow_generate_experiments=allow_generate_experiments and new_level > 0, 298 | untried_experiments=_warmstart_experiments if new_level == 1 else None) 299 | # Update logger state 300 | logger.level = node.level 301 | logger.node_idx = node.node_idx 302 | 303 | # Load previous explorations (make sure the root is always included) 304 | node_context = [] 305 | if node.level > 1: 306 | node_context = [root.children[0].get_context(include_code_output=True)] + node.get_path_context( 307 | k=k_parents - 1, skip_root=True) 308 | node_messages = [] 309 | if node_context is not None: 310 | node_messages += [ 311 | 
{"name": "user_proxy", "role": "user", "content": "PREVIOUS EXPLORATION:\n\n" + n} for n in 312 | node_context] 313 | node_messages += [ 314 | {"name": "user_proxy", "role": "user", "content": node.query}] 315 | _, last_message = chat_manager.resume(messages=node_messages) 316 | 317 | # Track time per node 318 | _node_start_time = time() 319 | 320 | # Execute current experiment and generate new experiments 321 | user_proxy.initiate_chat(recipient=chat_manager, message=last_message, clear_history=False) 322 | 323 | # Store the raw message logs for the node 324 | logger.log_node(node.level, node.node_idx, chat_manager.messages_to_string(groupchat.messages)) 325 | 326 | # Get messages starting from the current query and update the node 327 | node.messages = get_msgs_from_latest_query(groupchat.messages) 328 | node.read_experiment_from_messages( 329 | store_new_experiments=False if node.level == 1 and _warmstart_experiments is not None else True) 330 | # Calculate beliefs and rewards 331 | if node.success and node.level > 1: 332 | compute_and_store_reward(node, belief_model_name, belief_temperature, reasoning_effort, 333 | n_belief_samples, implicit_bayes_posterior, surprisal_width, belief_mode, 334 | use_binary_reward, all_surprisals, use_online_beliefs=use_online_beliefs, 335 | evidence_weight=evidence_weight, kl_scale=kl_scale, 336 | reward_mode=reward_mode, TEMP_LOG=TEMP_LOG) 337 | 338 | if node.success: # i.e., reward was computed successfully 339 | # Print debug information 340 | print_node_info(node) 341 | 342 | # TEMPORARY LOGGING 343 | if TEMP_LOG: 344 | temp_log_file = os.path.join(log_dirname, "temp_log.json") 345 | with open(temp_log_file, "w") as f: 346 | json.dump(TEMP_LOG, f, indent=2) 347 | print(f"Temporary log saved to {temp_log_file}") 348 | 349 | # End time tracking for the node 350 | _node_end_time = time() 351 | node.time_elapsed = round(_node_end_time - _node_start_time, 2) 352 | 353 | # Add the new node to the nodes_by_level dictionary 354 | 
nodes_by_level[node.level].append(node) 355 | 356 | # MCTS BACKPROPAGATION 357 | node.update_counts(visits=1, reward=node.self_value) 358 | 359 | # Save the current state of the node 360 | node_file = os.path.join(log_dirname, f"mcts_{node.id}.json") 361 | with open(node_file, "w") as f: 362 | json.dump(node.to_dict(), f, indent=2) 363 | else: 364 | # No new experiment was generated; don't change the state of the tree and sample again 365 | print(f"No new experiment generated for node {node.level}_{node.node_idx}. Skipping this iteration.") 366 | except KeyboardInterrupt: 367 | print("\n\n######### EXPLORATION INTERRUPTED! SAVING THE CURRENT STATE... #########\n\n") 368 | 369 | # End time tracking 370 | end_time = time() 371 | time_elapsed = end_time - start_time 372 | 373 | # Save all MCTS nodes 374 | save_nodes(nodes_by_level, log_dirname, run_dedupe, belief_model_name, time_elapsed=time_elapsed) 375 | 376 | 377 | if __name__ == "__main__": 378 | parser = ArgParser() 379 | args = parser.parse_args() 380 | print("Script arguments:") 381 | print(args.__dict__, "\n") 382 | 383 | # Validate and fix arguments 384 | if "o4-mini" in args.model and args.temperature is not None: 385 | print("Warning: Setting temperature for o4-mini is not permitted. Using default None.") 386 | args.temperature = None 387 | if "o4-mini" in args.belief_model and args.belief_temperature is not None: 388 | print("Warning: Setting temperature for o4-mini belief model is not permitted. 
Using default None.") 389 | args.belief_temperature = None 390 | 391 | # Create log directory 392 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 393 | log_dirname = os.path.join(args.out_dir, timestamp) if args.timestamp_dir else args.out_dir 394 | work_dirname = os.path.join(args.work_dir, timestamp) if args.timestamp_dir else args.work_dir 395 | 396 | # Setup logger 397 | logger = TreeLogger(log_dirname) 398 | 399 | # Save args 400 | args_file = os.path.join(log_dirname, "args.json") 401 | with open(args_file, "w") as f: 402 | json.dump(vars(args), f, indent=2) 403 | print(f"\nArguments saved to {args_file}\n") 404 | 405 | # Get dataset paths 406 | dataset_paths, dataset_metadata = get_datasets_fpaths(args.dataset_metadata, 407 | is_blade=args.dataset_metadata_type == 'blade') 408 | load_dataset_experiment = get_load_dataset_experiment(dataset_paths, dataset_metadata, run_eda=args.run_eda, 409 | dataset_metadata_type=args.dataset_metadata_type) 410 | 411 | if args.continue_from_dir or args.continue_from_json: 412 | if args.continue_from_dir is not None: 413 | # Load nodes from a directory 414 | root, nodes_by_level = load_mcts_from_json(args.continue_from_dir, args) 415 | # Copy all files except args.json from continue_from_dir to the new log directory 416 | for filename in os.listdir(args.continue_from_dir): 417 | if filename != "args.json": 418 | shutil.copy(os.path.join(args.continue_from_dir, filename), os.path.join(log_dirname, filename)) 419 | else: 420 | # Load from a JSON file that contains all the nodes (not de-duplicated) 421 | root, nodes_by_level = load_mcts_from_json(args.continue_from_json, args) 422 | 423 | if args.only_save_results: 424 | # Save nodes to JSON and exit 425 | save_nodes(nodes_by_level, log_dirname, run_dedupe=args.dedupe, model=args.belief_model) 426 | exit(0) 427 | 428 | if args.continue_from_dir is not None: 429 | # Copy all files except args.json from continue_from_dir to the new log directory 430 | for filename in 
os.listdir(args.continue_from_dir): 431 | if filename != "args.json": 432 | shutil.copy(os.path.join(args.continue_from_dir, filename), os.path.join(log_dirname, filename)) 433 | else: 434 | # Create the individual node files in the log directory 435 | for node in nodes_by_level.values(): 436 | for n in node: 437 | node_file = os.path.join(log_dirname, f"mcts_{n.id}.json") 438 | with open(node_file, "w") as f: 439 | json.dump(n.to_dict(), f, indent=2) 440 | 441 | # Calculate remaining iterations to reach n_experiments 442 | total_nodes = sum(len(nodes) for nodes in nodes_by_level.values()) 443 | remaining_iters = (args.n_experiments + 1) - total_nodes # + 1 to account for root node 444 | if remaining_iters <= 0: 445 | print(f"Already reached or exceeded target of {args.n_experiments} experiments") 446 | exit(0) 447 | print( 448 | f"RESUMING: Running {remaining_iters} more experiments to reach the target experiment count of {args.n_experiments}.\n") 449 | else: 450 | root = MCTSNode(level=0, node_idx=0, hypothesis=None, query=None, 451 | allow_generate_experiments=False, untried_experiments=[load_dataset_experiment]) 452 | nodes_by_level = defaultdict(list) 453 | nodes_by_level[0].append(root) 454 | remaining_iters = args.n_experiments + 1 # + 1 to account for root node 455 | 456 | # Set up selection method based on args 457 | if args.mcts_selection == "pw": 458 | # Progressive Widening 459 | assert args.pw_k is not None and args.pw_alpha is not None 460 | selection_method = progressive_widening(args.pw_k, args.pw_alpha, args.exploration_weight) 461 | elif args.mcts_selection == "pw_all": 462 | # Progressive Widening 463 | assert args.pw_k is not None and args.pw_alpha is not None 464 | selection_method = progressive_widening_all(args.pw_k, args.pw_alpha, args.exploration_weight) 465 | elif args.mcts_selection == "beam_search": 466 | # Beam Search 467 | selection_method = beam_search(args.k_experiments, args.beam_width, args.out_dir) 468 | elif args.mcts_selection 
== "ucb1": 469 | # UCB1 470 | selection_method = default_mcts_selection(args.exploration_weight) 471 | elif args.mcts_selection == "ucb1_recursive": 472 | # UCB1 recursive 473 | selection_method = ucb1_recursive(args.exploration_weight) 474 | else: 475 | raise ValueError(f"Unknown MCTS selection method: {args.mcts_selection}") 476 | print(f"MCTS selection method: {args.mcts_selection}\n") 477 | 478 | # Run exploration 479 | run_mcts( 480 | root=root, 481 | nodes_by_level=nodes_by_level, 482 | dataset_paths=dataset_paths, 483 | log_dirname=log_dirname, 484 | work_dir=work_dirname, 485 | max_iterations=remaining_iters, 486 | branching_factor=args.k_experiments, 487 | selection_method=selection_method, 488 | allow_generate_experiments=args.allow_generate_experiments, 489 | n_belief_samples=args.n_belief_samples, 490 | k_parents=args.k_parents, 491 | model_name=args.model, 492 | belief_model_name=args.belief_model, 493 | temperature=args.temperature, 494 | belief_temperature=args.belief_temperature, 495 | reasoning_effort=args.reasoning_effort, 496 | implicit_bayes_posterior=args.implicit_bayes_posterior, 497 | surprisal_width=args.surprisal_width, 498 | user_query=args.user_query, 499 | belief_mode=args.belief_mode, 500 | use_binary_reward=args.use_binary_reward, 501 | run_dedupe=args.dedupe, 502 | experiment_first=args.experiment_first, 503 | code_timeout=args.code_timeout, 504 | n_warmstart=args.n_warmstart, 505 | use_online_beliefs=args.use_online_beliefs, 506 | evidence_weight=args.evidence_weight, 507 | kl_scale=args.kl_scale, 508 | reward_mode=args.reward_mode, 509 | warmstart_experiments=args.warmstart_experiments, 510 | ) 511 | 512 | if args.delete_work_dir: 513 | shutil.rmtree(args.work_dir) 514 | print(f"\nDELETED WORKING DIRECTORY: {args.work_dir}") 515 | 516 | print(f"\nRUN FINISHED!\n\nLOGS: {log_dirname}") 517 | -------------------------------------------------------------------------------- /src/beliefs.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple, Optional 2 | 3 | import numpy as np 4 | from pydantic import BaseModel, Field 5 | 6 | from src.mcts import MCTSNode 7 | from src.mcts_utils import get_context_string 8 | from src.utils import query_llm, fuse_gaussians 9 | 10 | from scipy.special import betaln, gammaln, psi # betaln = log Beta function, psi = digamma 11 | 12 | 13 | class BeliefTrueFalse: 14 | class DistributionFormat: 15 | """ 16 | A distribution of beliefs about the hypothesis using true/false responses (Bernoulli). 17 | 18 | Attributes: 19 | n: Number of samples used to compute the distribution 20 | n_true: Number of "true" responses 21 | n_false: Number of "false" responses 22 | mean: Mean belief probability (optional, computed if not provided) 23 | prior_params: Parameters for the prior Beta distribution (alpha, beta) 24 | """ 25 | 26 | def __init__(self, 27 | n: float = Field(..., description="Number of samples used to compute the distribution"), 28 | n_true: float = Field(..., description='Number of "true" responses'), 29 | n_false: float = Field(..., description='Number of "false" responses'), 30 | mean: float | None = None, 31 | prior_params: Tuple[float, float] = (0.5, 0.5), 32 | **kwargs): 33 | self.n = n 34 | self.n_true = n_true 35 | self.n_false = n_false 36 | self.mean = mean 37 | self._empirical_mean = 0.5 38 | self.prior_params = prior_params 39 | 40 | def __repr__(self): 41 | return f"BeliefTrueFalse.DistributionFormat(n={self.n}, n_true={self.n_true}, n_false={self.n_false})" 42 | 43 | def to_dict(self): 44 | return { 45 | "_type": "boolean", 46 | "prior_params": self.prior_params, 47 | "n": self.n, 48 | "n_true": self.n_true, 49 | "n_false": self.n_false, 50 | "_empirical_mean": self._empirical_mean, 51 | "mean": self.mean, 52 | } 53 | 54 | def get_mean_belief(self, prior=None, recompute=False) -> float: 55 | """ 56 | Get the mean of the prior/posterior belief 
distribution. 57 | Args: 58 | prior (BeliefTrueFalse.DistributionFormat): Prior distribution format object. 59 | recompute (bool): Whether to recompute the mean even if it is already set. 60 | Returns: 61 | float: The mean belief probability. 62 | """ 63 | if self.mean is None or recompute: 64 | # Compute the mean belief using the Beta distribution 65 | if self.n > 0: 66 | self._empirical_mean = self.n_true / self.n 67 | self.mean = (self.prior_params[0] + self.n_true) / (self.n + sum(self.prior_params)) 68 | 69 | if prior is not None: 70 | # Bayesian update: Beta(n_true + a, n_false + b) where a and b are prior parameters 71 | post_alpha = prior.n_true + prior.prior_params[0] 72 | # post_beta = prior.n_false + prior.prior_params[1] 73 | self.mean = (self.n_true + post_alpha) / (self.n + prior.n + sum(prior.prior_params)) 74 | return self.mean 75 | 76 | def update(self, n_true: int = 0, n_false: int = 0, distr=None, normalize: bool = False): 77 | """ 78 | Update the distribution with new counts. 79 | """ 80 | if distr is not None: 81 | self.n_true += distr.n_true 82 | self.n_false += distr.n_false 83 | else: 84 | self.n_true += n_true 85 | self.n_false += n_false 86 | n = distr.n if distr is not None else (n_true + n_false) 87 | if normalize: 88 | total = self.n + n 89 | self.n_true /= (total / self.n) 90 | self.n_false /= (total / self.n) 91 | else: 92 | self.n += n 93 | # Reset mean 94 | _ = self.get_mean_belief(recompute=True) 95 | 96 | def get_params(self) -> Tuple[float, float]: 97 | """ 98 | Get the parameters of the Beta distribution. 99 | Returns: 100 | Tuple[float, float]: Parameters (alpha, beta) of the Beta distribution. 101 | """ 102 | return self.prior_params[0] + self.n_true, self.prior_params[1] + self.n_false 103 | 104 | class ResponseFormat(BaseModel): 105 | """ 106 | Belief about the support for the hypothesis. 107 | 108 | Attributes: 109 | belief (bool | None): Whether the hypothesis is true or false. 
If you do not have enough information to 110 | comment on the hypothesis, return None. 111 | """ 112 | belief: bool | None = Field(..., description="Whether the hypothesis is true") 113 | 114 | @staticmethod 115 | def parse_response(response: List[dict], 116 | prior_params: Tuple[float, float] = (0.5, 0.5), 117 | weight: float = 1.0) -> 'BeliefTrueFalse.DistributionFormat': 118 | """ 119 | Parse the response from the LLM into a DistributionFormat. 120 | 121 | Args: 122 | response (dict): The response from the LLM containing belief counts. 123 | prior_params (Tuple[float, float]): Parameters for the prior Beta distribution (alpha, beta). 124 | weight (float): Weight to apply to the counts (default is 1.0). 125 | 126 | Returns: 127 | BeliefTrueFalse.DistributionFormat: Parsed distribution format. 128 | """ 129 | n, n_true, n_false = 0, 0, 0 130 | for _res in response: 131 | if _res["belief"] is not None: 132 | n += 1 # Count only responses that provide a belief 133 | n_true += int(_res["belief"]) 134 | n_false += int(not _res["belief"]) 135 | n *= weight 136 | n_true *= weight 137 | n_false *= weight 138 | 139 | return BeliefTrueFalse.DistributionFormat(n=n, n_true=n_true, n_false=n_false, prior_params=prior_params) 140 | 141 | @staticmethod 142 | def kl_divergence(distr1: 'BeliefTrueFalse.DistributionFormat', 143 | distr2: 'BeliefTrueFalse.DistributionFormat') -> float: 144 | """ 145 | Compute the KL divergence between two distributions. 146 | Args: 147 | distr1 (BeliefTrueFalse.DistributionFormat): First distribution. 148 | distr2 (BeliefTrueFalse.DistributionFormat): Second distribution. 149 | Returns: 150 | float: KL divergence between the two distributions. 
151 | """ 152 | alpha1, beta1 = distr1.get_params() 153 | alpha2, beta2 = distr2.get_params() 154 | term1 = betaln(alpha2, beta2) - betaln(alpha1, beta1) 155 | term2 = (alpha1 - alpha2) * psi(alpha1) 156 | term3 = (beta1 - beta2) * psi(beta1) 157 | term4 = (alpha2 - alpha1 + beta2 - beta1) * psi(alpha1 + beta1) 158 | return term1 + term2 + term3 + term4 159 | 160 | 161 | class BeliefCategorical: 162 | score_per_category = { 163 | "definitely_false": 0.1, 164 | "maybe_false": 0.3, 165 | "uncertain": 0.5, 166 | "maybe_true": 0.7, 167 | "definitely_true": 0.9 168 | } 169 | 170 | class DistributionFormat: 171 | """ 172 | A distribution of beliefs about the hypothesis using categorical buckets (Categorical). 173 | Attributes: 174 | n: Number of samples used to compute the distribution 175 | definitely_true: Number of "definitely true" responses 176 | maybe_true: Number of "maybe true" responses 177 | uncertain: Number of "uncertain" responses 178 | maybe_false: Number of "maybe false" responses 179 | definitely_false: Number of "definitely false" responses 180 | mean: Mean belief probability (optional, computed if not provided) 181 | prior_params: Parameters for the prior Dirichlet distribution (alpha1, alpha2, alpha3, alpha4, alpha5) 182 | """ 183 | 184 | def __init__(self, 185 | n: float = Field(..., description="Number of samples used to compute the distribution"), 186 | definitely_true: float = Field(..., description='Number of "definitely true" responses'), 187 | maybe_true: float = Field(..., description='Number of "maybe true" responses'), 188 | uncertain: float = Field(..., description='Number of "uncertain" responses'), 189 | maybe_false: float = Field(..., description='Number of "maybe false" responses'), 190 | definitely_false: float = Field(..., description='Number of "definitely false" responses'), 191 | mean: float | None = None, 192 | prior_params: Tuple[float, float, float, float, float] = (0.2, 0.2, 0.2, 0.2, 0.2), 193 | **kwargs): 194 | self.n = n 195 
| self.definitely_true = definitely_true 196 | self.maybe_true = maybe_true 197 | self.uncertain = uncertain 198 | self.maybe_false = maybe_false 199 | self.definitely_false = definitely_false 200 | self.mean = mean 201 | self._empirical_mean = 0.5 202 | self.prior_params = prior_params # Parameters for the prior Dirichlet distribution 203 | 204 | def __repr__(self): 205 | return (f"BeliefCategorical.DistributionFormat(n={self.n}, definitely_true={self.definitely_true}, " 206 | f"maybe_true={self.maybe_true}, uncertain={self.uncertain}, " 207 | f"maybe_false={self.maybe_false}, definitely_false={self.definitely_false})") 208 | 209 | def to_dict(self): 210 | return { 211 | "_type": "categorical", 212 | "prior_params": self.prior_params, 213 | "n": self.n, 214 | "definitely_true": self.definitely_true, 215 | "maybe_true": self.maybe_true, 216 | "uncertain": self.uncertain, 217 | "maybe_false": self.maybe_false, 218 | "definitely_false": self.definitely_false, 219 | "_empirical_mean": self._empirical_mean, 220 | "mean": self.mean, 221 | } 222 | 223 | def get_mean_belief(self, prior=None, recompute=False) -> float: 224 | """ 225 | Get the mean of the prior/posterior belief distribution. 226 | Args: 227 | prior (BeliefCategorical.DistributionFormat): Prior distribution format object. 228 | recompute (bool): Whether to recompute the mean even if it is already set. 229 | Returns: 230 | float: The mean belief probability. 
231 | """ 232 | if self.mean is None or recompute: 233 | # Compute the mean belief using the Dirichlet distribution 234 | if self.n > 0: 235 | mean_per_category = { 236 | "definitely_true": self.definitely_true / self.n, 237 | "maybe_true": self.maybe_true / self.n, 238 | "uncertain": self.uncertain / self.n, 239 | "maybe_false": self.maybe_false / self.n, 240 | "definitely_false": self.definitely_false / self.n 241 | } 242 | self._empirical_mean = sum( 243 | mean_per_category[cat] * BeliefCategorical.score_per_category[cat] for cat in mean_per_category) 244 | 245 | mean_per_category = { 246 | "definitely_true": (self.definitely_true + self.prior_params[0]) / ( 247 | self.n + sum(self.prior_params)), 248 | "maybe_true": (self.maybe_true + self.prior_params[1]) / ( 249 | self.n + sum(self.prior_params)), 250 | "uncertain": (self.uncertain + self.prior_params[2]) / (self.n + sum(self.prior_params)), 251 | "maybe_false": (self.maybe_false + self.prior_params[3]) / ( 252 | self.n + sum(self.prior_params)), 253 | "definitely_false": (self.definitely_false + self.prior_params[4]) / ( 254 | self.n + sum(self.prior_params)) 255 | } 256 | self.mean = sum( 257 | mean_per_category[cat] * BeliefCategorical.score_per_category[cat] for cat in mean_per_category) 258 | 259 | if prior is not None: 260 | # Bayesian update 261 | mean_per_category = { 262 | "definitely_true": (self.definitely_true + prior.definitely_true + prior.prior_params[0]) / ( 263 | self.n + prior.n + sum(prior.prior_params)), 264 | "maybe_true": (self.maybe_true + prior.maybe_true + prior.prior_params[1]) / ( 265 | self.n + prior.n + sum(prior.prior_params)), 266 | "uncertain": (self.uncertain + prior.uncertain + prior.prior_params[2]) / ( 267 | self.n + prior.n + sum(prior.prior_params)), 268 | "maybe_false": (self.maybe_false + prior.maybe_false + prior.prior_params[3]) / ( 269 | self.n + prior.n + sum(prior.prior_params)), 270 | "definitely_false": (self.definitely_false + prior.definitely_false + 
prior.prior_params[4]) / ( 271 | self.n + prior.n + sum(prior.prior_params)) 272 | } 273 | self.mean = sum( 274 | mean_per_category[cat] * BeliefCategorical.score_per_category[cat] for cat in mean_per_category) 275 | return self.mean 276 | 277 | def update(self, 278 | definitely_true: int = 0, 279 | maybe_true: int = 0, 280 | uncertain: int = 0, 281 | maybe_false: int = 0, 282 | definitely_false: int = 0, 283 | distr=None, 284 | normalize: bool = False): 285 | """ 286 | Update the distribution with new counts. 287 | """ 288 | if distr is not None: 289 | self.definitely_true += distr.definitely_true 290 | self.maybe_true += distr.maybe_true 291 | self.uncertain += distr.uncertain 292 | self.maybe_false += distr.maybe_false 293 | self.definitely_false += distr.definitely_false 294 | else: 295 | self.definitely_true += definitely_true 296 | self.maybe_true += maybe_true 297 | self.uncertain += uncertain 298 | self.maybe_false += maybe_false 299 | self.definitely_false += definitely_false 300 | n = distr.n if distr is not None else ( 301 | definitely_true + maybe_true + uncertain + maybe_false + definitely_false 302 | ) 303 | if normalize: 304 | total = self.n + n 305 | self.definitely_true /= (total / self.n) 306 | self.maybe_true /= (total / self.n) 307 | self.uncertain /= (total / self.n) 308 | self.maybe_false /= (total / self.n) 309 | self.definitely_false /= (total / self.n) 310 | else: 311 | self.n += n 312 | # Reset mean 313 | _ = self.get_mean_belief(recompute=True) 314 | 315 | def get_params(self) -> Tuple[float, float, float, float, float]: 316 | """ 317 | Get the parameters of the Dirichlet distribution. 318 | Returns: 319 | Tuple[float, float, float, float, float]: Parameters (alpha1, alpha2, alpha3, alpha4, alpha5) of the Dirichlet distribution. 
320 | """ 321 | return (self.prior_params[0] + self.definitely_true, 322 | self.prior_params[1] + self.maybe_true, 323 | self.prior_params[2] + self.uncertain, 324 | self.prior_params[3] + self.maybe_false, 325 | self.prior_params[4] + self.definitely_false) 326 | 327 | class ResponseFormat(BaseModel): 328 | """ 329 | Belief about the support for the hypothesis. 330 | 331 | Attributes: 332 | belief (str): Belief about the support for the hypothesis. Choices are: 333 | "definitely true": Hypothesis is definitely true. 334 | "maybe true": Hypothesis may be true. 335 | "uncertain": Hypothesis is equally likely to be true or false (e.g., because of relevant but contradictory evidence). 336 | "maybe false": Hypothesis may be false. 337 | "definitely false": Hypothesis is definitely false. 338 | "cannot comment": Not enough information to comment on the hypothesis (e.g., due to lack of domain knowledge or lack of relevant evidence). 339 | """ 340 | belief: str = Field(..., description="Belief about the hypothesis", 341 | choices=["definitely true", "maybe true", "uncertain", 342 | "maybe false", "definitely false", "cannot comment"]) 343 | 344 | @staticmethod 345 | def parse_response(response: List[dict], prior_params: Tuple[float, float, float, float, float] = ( 346 | 0.2, 0.2, 0.2, 0.2, 0.2), weight=1.0) -> 'BeliefCategorical.DistributionFormat': 347 | """ 348 | Parse the response from the LLM into a DistributionFormat. 349 | 350 | Args: 351 | response (dict): The response from the LLM containing belief counts. 352 | prior_params (Tuple[float, float, float, float, float]): Parameters for the prior Dirichlet distribution. 353 | weight (float): Weight to apply to the counts (default is 1.0). 354 | 355 | Returns: 356 | BeliefCategorical.DistributionFormat: Parsed distribution format. 
357 | """ 358 | cannot_comment = sum(1 for _res in response if _res["belief"] == "cannot comment") 359 | definitely_true = weight * sum(1 for _res in response if _res["belief"] == "definitely true") 360 | maybe_true = weight * sum(1 for _res in response if _res["belief"] == "maybe true") 361 | uncertain = weight * sum(1 for _res in response if _res["belief"] == "uncertain") 362 | maybe_false = weight * sum(1 for _res in response if _res["belief"] == "maybe false") 363 | definitely_false = weight * sum(1 for _res in response if _res["belief"] == "definitely false") 364 | n = weight * (len(response) - cannot_comment) # Exclude responses with "cannot comment" 365 | 366 | return BeliefCategorical.DistributionFormat( 367 | n=n, 368 | definitely_true=definitely_true, 369 | maybe_true=maybe_true, 370 | uncertain=uncertain, 371 | maybe_false=maybe_false, 372 | definitely_false=definitely_false, 373 | prior_params=prior_params 374 | ) 375 | 376 | @staticmethod 377 | def kl_divergence(distr1: 'BeliefCategorical.DistributionFormat', 378 | distr2: 'BeliefCategorical.DistributionFormat') -> float: 379 | """ 380 | Compute the KL divergence between two distributions. 381 | Args: 382 | distr1 (BeliefCategorical.DistributionFormat): First distribution. 383 | distr2 (BeliefCategorical.DistributionFormat): Second distribution. 384 | Returns: 385 | float: KL divergence between the two distributions. 
386 | """ 387 | alpha = np.array(distr1.get_params()) 388 | beta = np.array(distr2.get_params()) 389 | 390 | sum_alpha = np.sum(alpha) 391 | sum_beta = np.sum(beta) 392 | 393 | term1 = gammaln(sum_alpha) - np.sum(gammaln(alpha)) 394 | term2 = -gammaln(sum_beta) + np.sum(gammaln(beta)) 395 | term3 = np.sum((alpha - beta) * (psi(alpha) - psi(sum_alpha))) 396 | 397 | return term1 + term2 + term3 398 | 399 | 400 | class BeliefCategoricalNumeric: 401 | score_per_category = { 402 | "0-0.2": 0.1, 403 | "0.2-0.4": 0.3, 404 | "0.4-0.6": 0.5, 405 | "0.6-0.8": 0.7, 406 | "0.8-1.0": 0.9 407 | } 408 | 409 | class DistributionFormat: 410 | """ 411 | A distribution of beliefs about the hypothesis using numerical buckets (Categorical). 412 | Attributes: 413 | n: Number of samples used to compute the distribution 414 | bucket_02: Number of responses that fall in the range [0.0, 0.2) 415 | bucket_24: Number of responses that fall in the range [0.2, 0.4) 416 | bucket_46: Number of responses that fall in the range [0.4, 0.6) 417 | bucket_68: Number of responses that fall in the range [0.6, 0.8) 418 | bucket_810: Number of responses that fall in the range [0.8, 1.0) 419 | mean: Mean belief probability (optional, computed if not provided) 420 | prior_params: Parameters for the prior Dirichlet distribution (alpha1, alpha2, alpha3, alpha4, alpha5) 421 | """ 422 | 423 | def __init__(self, 424 | n: float = Field(..., description="Number of samples used to compute the distribution"), 425 | bucket_02: float = Field(..., description='Number of responses that fall in the range [0.0, 0.2)'), 426 | bucket_24: float = Field(..., description='Number of responses that fall in the range [0.2, 0.4)'), 427 | bucket_46: float = Field(..., description='Number of responses that fall in the range [0.4, 0.6)'), 428 | bucket_68: float = Field(..., description='Number of responses that fall in the range [0.6, 0.8)'), 429 | bucket_810: float = Field(..., 430 | description='Number of responses that fall in the 
range [0.8, 1.0)'), 431 | mean: float | None = None, 432 | prior_params: Tuple[float, float, float, float, float] = (0.2, 0.2, 0.2, 0.2, 0.2), 433 | **kwargs): 434 | self.n = n 435 | self.bucket_02 = bucket_02 436 | self.bucket_24 = bucket_24 437 | self.bucket_46 = bucket_46 438 | self.bucket_68 = bucket_68 439 | self.bucket_810 = bucket_810 440 | self.mean = mean 441 | self._empirical_mean = 0.5 442 | self.prior_params = prior_params # Parameters for the prior Dirichlet distribution 443 | 444 | def __repr__(self): 445 | return (f"BeliefCategoricalNumeric.DistributionFormat(n={self.n}, bucket_02={self.bucket_02}, " 446 | f"bucket_24={self.bucket_24}, bucket_46={self.bucket_46}, " 447 | f"bucket_68={self.bucket_68}, bucket_810={self.bucket_810})") 448 | 449 | def to_dict(self): 450 | return { 451 | "_type": "categorical_numeric", 452 | "prior_params": self.prior_params, 453 | "n": self.n, 454 | "bucket_02": self.bucket_02, 455 | "bucket_24": self.bucket_24, 456 | "bucket_46": self.bucket_46, 457 | "bucket_68": self.bucket_68, 458 | "bucket_810": self.bucket_810, 459 | "_empirical_mean": self._empirical_mean, 460 | "mean": self.mean, 461 | } 462 | 463 | def get_mean_belief(self, prior=None, recompute=False) -> float: 464 | """ 465 | Get the mean of the prior/posterior belief distribution. 466 | Args: 467 | prior (BeliefCategoricalNumeric.DistributionFormat): Prior distribution format object. 468 | recompute (bool): Whether to recompute the mean even if it is already set. 469 | Returns: 470 | float: The mean belief probability. 
471 | """ 472 | if self.mean is None or recompute: 473 | # Compute the mean belief using the Dirichlet distribution 474 | if self.n > 0: 475 | mean_per_category = { 476 | "0-0.2": self.bucket_02 / self.n, 477 | "0.2-0.4": self.bucket_24 / self.n, 478 | "0.4-0.6": self.bucket_46 / self.n, 479 | "0.6-0.8": self.bucket_68 / self.n, 480 | "0.8-1.0": self.bucket_810 / self.n 481 | } 482 | self._empirical_mean = sum( 483 | mean_per_category[cat] * BeliefCategoricalNumeric.score_per_category[cat] for cat in 484 | mean_per_category) 485 | 486 | mean_per_category = { 487 | "0-0.2": (self.bucket_02 + self.prior_params[0]) / (self.n + sum(self.prior_params)), 488 | "0.2-0.4": (self.bucket_24 + self.prior_params[1]) / (self.n + sum(self.prior_params)), 489 | "0.4-0.6": (self.bucket_46 + self.prior_params[2]) / (self.n + sum(self.prior_params)), 490 | "0.6-0.8": (self.bucket_68 + self.prior_params[3]) / (self.n + sum(self.prior_params)), 491 | "0.8-1.0": (self.bucket_810 + self.prior_params[4]) / (self.n + sum(self.prior_params)) 492 | } 493 | self.mean = sum( 494 | mean_per_category[cat] * BeliefCategoricalNumeric.score_per_category[cat] for cat in 495 | mean_per_category) 496 | 497 | if prior is not None: 498 | # Bayesian update 499 | mean_per_category = { 500 | "0-0.2": (self.bucket_02 + prior.bucket_02 + prior.prior_params[0]) / ( 501 | self.n + prior.n + sum(prior.prior_params)), 502 | "0.2-0.4": (self.bucket_24 + prior.bucket_24 + prior.prior_params[1]) / ( 503 | self.n + prior.n + sum(prior.prior_params)), 504 | "0.4-0.6": (self.bucket_46 + prior.bucket_46 + prior.prior_params[2]) / ( 505 | self.n + prior.n + sum(prior.prior_params)), 506 | "0.6-0.8": (self.bucket_68 + prior.bucket_68 + prior.prior_params[3]) / ( 507 | self.n + prior.n + sum(prior.prior_params)), 508 | "0.8-1.0": (self.bucket_810 + prior.bucket_810 + prior.prior_params[4]) / ( 509 | self.n + prior.n + sum(prior.prior_params)) 510 | } 511 | self.mean = sum(mean_per_category[cat] * 
BeliefCategoricalNumeric.score_per_category[cat] for cat in 512 | mean_per_category) 513 | 514 | return self.mean 515 | 516 | def update(self, 517 | bucket_02: int = 0, 518 | bucket_24: int = 0, 519 | bucket_46: int = 0, 520 | bucket_68: int = 0, 521 | bucket_810: int = 0, 522 | distr=None, 523 | normalize: bool = False): 524 | """ 525 | Update the distribution with new counts. 526 | """ 527 | if distr is not None: 528 | self.bucket_02 += distr.bucket_02 529 | self.bucket_24 += distr.bucket_24 530 | self.bucket_46 += distr.bucket_46 531 | self.bucket_68 += distr.bucket_68 532 | self.bucket_810 += distr.bucket_810 533 | else: 534 | self.bucket_02 += bucket_02 535 | self.bucket_24 += bucket_24 536 | self.bucket_46 += bucket_46 537 | self.bucket_68 += bucket_68 538 | self.bucket_810 += bucket_810 539 | n = distr.n if distr is not None else ( 540 | bucket_02 + bucket_24 + bucket_46 + bucket_68 + bucket_810 541 | ) 542 | if normalize: 543 | total = self.n + n 544 | self.bucket_02 /= (total / self.n) 545 | self.bucket_24 /= (total / self.n) 546 | self.bucket_46 /= (total / self.n) 547 | self.bucket_68 /= (total / self.n) 548 | self.bucket_810 /= (total / self.n) 549 | else: 550 | self.n += n 551 | # Reset mean 552 | _ = self.get_mean_belief(recompute=True) 553 | 554 | def get_params(self) -> Tuple[float, float, float, float, float]: 555 | """ 556 | Get the parameters of the Dirichlet distribution. 557 | Returns: 558 | Tuple[float, float, float, float, float]: Parameters (alpha1, alpha2, alpha3, alpha4, alpha5) of the Dirichlet distribution. 559 | """ 560 | return (self.prior_params[0] + self.bucket_02, 561 | self.prior_params[1] + self.bucket_24, 562 | self.prior_params[2] + self.bucket_46, 563 | self.prior_params[3] + self.bucket_68, 564 | self.prior_params[4] + self.bucket_810) 565 | 566 | class ResponseFormat(BaseModel): 567 | """ 568 | Belief about the support for the hypothesis. 569 | 570 | Attributes: 571 | belief (str): Belief about the support for the hypothesis. 
Choices are: 572 | "0-0.2": Hypothesis is definitely false. 573 | "0.2-0.4": Hypothesis may be false. 574 | "0.4-0.6": Hypothesis is equally likely to be true or false (e.g., because of relevant but contradictory evidence). 575 | "0.6-0.8": Hypothesis may be true. 576 | "0.8-1.0": Hypothesis is definitely true. 577 | "cannot comment": Not enough information to comment on the hypothesis (e.g., due to lack of domain knowledge or lack of relevant evidence). 578 | """ 579 | belief: str = Field(..., description="Belief about the hypothesis being true", 580 | choices=["0-0.2", "0.2-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0", 581 | "cannot comment"]) 582 | 583 | @staticmethod 584 | def parse_response(response: List[dict], prior_params: Tuple[float, float, float, float, float] = ( 585 | 0.2, 0.2, 0.2, 0.2, 0.2), weight: float = 1.0) -> 'BeliefCategoricalNumeric.DistributionFormat': 586 | """ 587 | Parse the response from the LLM into a DistributionFormat. 588 | 589 | Args: 590 | response (dict): The response from the LLM containing belief counts. 591 | prior_params (Tuple[float, float, float, float, float]): Parameters for the prior Dirichlet distribution. 592 | weight (float): Weight to apply to the counts (default is 1.0). 593 | 594 | Returns: 595 | BeliefCategoricalNumeric.DistributionFormat: Parsed distribution format. 
596 | """ 597 | cannot_comment = sum(1 for _res in response if _res["belief"] == "cannot comment") 598 | bucket_02 = weight * sum(1 for _res in response if _res["belief"] == "0-0.2") 599 | bucket_24 = weight * sum(1 for _res in response if _res["belief"] == "0.2-0.4") 600 | bucket_46 = weight * sum(1 for _res in response if _res["belief"] == "0.4-0.6") 601 | bucket_68 = weight * sum(1 for _res in response if _res["belief"] == "0.6-0.8") 602 | bucket_810 = weight * sum(1 for _res in response if _res["belief"] == "0.8-1.0") 603 | n = weight * (len(response) - cannot_comment) # Exclude responses with "cannot comment" 604 | 605 | return BeliefCategoricalNumeric.DistributionFormat( 606 | n=n, 607 | bucket_02=bucket_02, 608 | bucket_24=bucket_24, 609 | bucket_46=bucket_46, 610 | bucket_68=bucket_68, 611 | bucket_810=bucket_810, 612 | prior_params=prior_params 613 | ) 614 | 615 | @staticmethod 616 | def kl_divergence(distr1: 'BeliefCategoricalNumeric.DistributionFormat', 617 | distr2: 'BeliefCategoricalNumeric.DistributionFormat') -> float: 618 | """ 619 | Compute the KL divergence between two distributions. 620 | Args: 621 | distr1 (BeliefCategoricalNumeric.DistributionFormat): First distribution. 622 | distr2 (BeliefCategoricalNumeric.DistributionFormat): Second distribution. 623 | Returns: 624 | float: KL divergence between the two distributions. 625 | """ 626 | alpha = np.array(distr1.get_params()) 627 | beta = np.array(distr2.get_params()) 628 | 629 | sum_alpha = np.sum(alpha) 630 | sum_beta = np.sum(beta) 631 | 632 | term1 = gammaln(sum_alpha) - np.sum(gammaln(alpha)) 633 | term2 = -gammaln(sum_beta) + np.sum(gammaln(beta)) 634 | term3 = np.sum((alpha - beta) * (psi(alpha) - psi(sum_alpha))) 635 | 636 | return term1 + term2 + term3 637 | 638 | 639 | class BeliefGauss: 640 | """ 641 | A distribution of beliefs about the hypothesis using Gaussian mean and standard deviation samples. 
642 | 643 | Attributes: 644 | n: Number of samples used to compute the distribution 645 | mean: Mean probability of the hypothesis being true 646 | stddev: Standard deviation of the probabilities 647 | prior_params: Parameters for the prior Gaussian distribution (mean, stddev) 648 | samples: Dictionary containing means and standard deviations of the samples 649 | weight: Weight to apply to the counts (default is 1.0) 650 | """ 651 | 652 | class DistributionFormat: 653 | def __init__(self, 654 | n: float = Field(..., description="Number of samples used to compute the distribution"), 655 | mean: float = Field(..., description="Mean probability of the hypothesis being true"), 656 | stddev: float = Field(..., description="Standard deviation of the probabilities"), 657 | prior_params: Tuple[float, float] = (0.5, 5), 658 | samples=None, 659 | weight=1.0, 660 | **kwargs): 661 | self.n = n 662 | self.samples = samples 663 | self.weight = weight 664 | self._empirical_mean, self._empirical_stddev = fuse_gaussians(self.samples["means"], 665 | self.samples["stddevs"]) 666 | self.mean = mean 667 | self.stddev = stddev 668 | self.prior_params = prior_params # Parameters for the prior Gaussian distribution (mean, stddev) 669 | 670 | def __repr__(self): 671 | return f"BeliefGauss.DistributionFormat(n={self.n}, mean={self.mean}, stddev={self.stddev})" 672 | 673 | def to_dict(self): 674 | return { 675 | "_type": "gaussian", 676 | "prior_params": self.prior_params, 677 | "samples": self.samples, 678 | "n": self.n, 679 | "_empirical_mean": self._empirical_mean, 680 | "_empirical_stddev": self._empirical_stddev, 681 | "mean": self.mean, 682 | "stddev": self.stddev 683 | } 684 | 685 | def get_mean_belief(self, prior=None, recompute=False) -> float: 686 | """ 687 | Get the mean of the prior/posterior belief distribution. 688 | Args: 689 | prior (BeliefGauss.DistributionFormat): Prior distribution format object. 
690 | recompute (bool): Whether to recompute the mean even if it is already set. 691 | Returns: 692 | float: The mean belief probability. 693 | """ 694 | if recompute: 695 | # Compute the mean belief using the Gaussian distribution 696 | self._empirical_mean, self._empirical_stddev = fuse_gaussians(self.samples["means"], 697 | self.samples["stddevs"]) 698 | self.mean, self.stddev = fuse_gaussians([self.prior_params[0]] + self.samples["means"], 699 | [self.prior_params[1]] + self.samples["stddevs"], 700 | weight=self.weight) 701 | 702 | if prior is not None: 703 | # Bayesian update 704 | self.mean, self.stddev = fuse_gaussians( 705 | [prior.mean, self.mean], [prior.stddev, self.stddev] 706 | ) 707 | 708 | return self.mean 709 | 710 | def update(self, means: List[float] = None, stddevs: List[float] = None, 711 | distr=None, normalize: bool = False): 712 | """ 713 | Update the distribution with new samples or another distribution. 714 | """ 715 | if distr is not None: 716 | self.n += distr.n 717 | # TOFIX: samples don't take into account the original weight of the distribution 718 | self.samples["means"].extend(distr.samples["means"]) 719 | self.samples["stddevs"].extend(distr.samples["stddevs"]) 720 | self._empirical_mean, self._empirical_stddev = fuse_gaussians( 721 | [self._empirical_mean, distr._empirical_mean], [self._empirical_stddev, distr._empirical_stddev] 722 | ) 723 | self.mean, self.stddev = fuse_gaussians( 724 | [self.mean, distr.mean], [self.stddev, distr.stddev] 725 | ) 726 | else: 727 | self.n += len(means) 728 | self.samples["means"].extend(means) 729 | self.samples["stddevs"].extend(stddevs) 730 | self._empirical_mean, self._empirical_stddev = fuse_gaussians([self._empirical_mean] + means, 731 | [self._empirical_stddev] + stddevs) 732 | self.mean, self.stddev = fuse_gaussians([self.mean] + means, [self.stddev] + stddevs) 733 | 734 | def get_params(self) -> Tuple[float, float]: 735 | """ 736 | Get the parameters of the Gaussian distribution. 
737 | Returns: 738 | Tuple[float, float]: Parameters (mean, stddev) of the Gaussian distribution. 739 | """ 740 | return self.mean, self.stddev 741 | 742 | class ResponseFormat(BaseModel): 743 | """ 744 | Belief (Gaussian) distribution about the support for the hypothesis. 745 | 746 | Attributes: 747 | belief_mean (float): Mean of the belief distribution (0.0 to 1.0). <0.2: the hypothesis is most likely 748 | false; 0.2-0.4: the hypothesis may be false; 0.4-0.6: the hypothesis is equally likely to be true or false; 749 | 0.6-0.8: the hypothesis may be true; >0.8: the hypothesis is most likely true. 750 | belief_stddev (float): Standard deviation of the belief distribution (0.0 to infinity). Smaller values 751 | indicate more confidence in the belief. If you are not confident in your belief, set the standard deviation 752 | to a large value (e.g., 1000.0). 753 | """ 754 | belief_mean: float = Field(..., description="Mean of the belief distribution", 755 | ge=0.0, le=1.0) 756 | belief_stddev: float = Field(..., description="Standard deviation of the belief distribution", 757 | ge=0.0, le=float('inf')) 758 | 759 | @staticmethod 760 | def parse_response(response: List[dict], prior_params: Tuple[float, float] = (0.5, 5), 761 | weight: float = 1.0) -> 'BeliefGauss.DistributionFormat': 762 | """ 763 | Parse the response from the LLM into a DistributionFormat. 764 | 765 | Args: 766 | response (dict): The response from the LLM containing belief probabilities. 767 | prior_params (Tuple[float, float]): Parameters for the prior Gaussian distribution (mean, stddev). 768 | weight (float): Weight to apply to the counts (default is 1.0). 769 | 770 | Returns: 771 | BeliefGauss.DistributionFormat: Parsed distribution format. 
772 | """ 773 | n = weight * len(response) 774 | 775 | means = [_res["belief_mean"] for _res in response] 776 | stddevs = [_res["belief_stddev"] for _res in response] 777 | 778 | mean, stddev = fuse_gaussians([prior_params[0]] + means, [prior_params[1]] + stddevs, weight=weight) 779 | 780 | return BeliefGauss.DistributionFormat(n=n, mean=mean, stddev=stddev, prior_params=prior_params, 781 | samples={"means": means, "stddevs": stddevs}, weight=weight) 782 | 783 | @staticmethod 784 | def kl_divergence(distr1: 'BeliefGauss.DistributionFormat', 785 | distr2: 'BeliefGauss.DistributionFormat') -> float: 786 | """ 787 | Compute the KL divergence between two Gaussian distributions. 788 | Args: 789 | distr1 (BeliefGauss.DistributionFormat): First distribution. 790 | distr2 (BeliefGauss.DistributionFormat): Second distribution. 791 | Returns: 792 | float: KL divergence between the two distributions. 793 | """ 794 | mean1, stddev1 = distr1.get_params() 795 | mean2, stddev2 = distr2.get_params() 796 | 797 | return 0.5 * (np.log(stddev2 ** 2 / stddev1 ** 2) + (stddev1 ** 2 + (mean1 - mean2) ** 2) / stddev2 ** 2 - 1) 798 | 799 | 800 | class BeliefTrueFalseCat: 801 | score_per_category = { 802 | "definitely_false": 0.0, 803 | "maybe_false": 0.25, 804 | "uncertain": 0.5, 805 | "maybe_true": 0.75, 806 | "definitely_true": 1.0 807 | } 808 | 809 | class DistributionFormat: 810 | """ 811 | A distribution of beliefs about the hypothesis using categorical buckets (Categorical). 
812 | Attributes: 813 | n: Number of samples used to compute the distribution 814 | definitely_true: Number of "definitely true" responses 815 | maybe_true: Number of "maybe true" responses 816 | uncertain: Number of "uncertain" responses 817 | maybe_false: Number of "maybe false" responses 818 | definitely_false: Number of "definitely false" responses 819 | mean: Mean belief probability (optional, computed if not provided) 820 | prior_params: Parameters for the prior Beta distribution (alpha, beta) 821 | """ 822 | 823 | def __init__(self, 824 | n: float = Field(..., description="Number of samples used to compute the distribution"), 825 | definitely_true: float = Field(..., description='Number of "definitely true" responses'), 826 | maybe_true: float = Field(..., description='Number of "maybe true" responses'), 827 | uncertain: float = Field(..., description='Number of "uncertain" responses'), 828 | maybe_false: float = Field(..., description='Number of "maybe false" responses'), 829 | definitely_false: float = Field(..., description='Number of "definitely false" responses'), 830 | mean: float | None = None, 831 | prior_params: Tuple[float, float] = (0.5, 0.5), 832 | **kwargs): 833 | self.n = n 834 | self.definitely_true = definitely_true 835 | self.maybe_true = maybe_true 836 | self.uncertain = uncertain 837 | self.maybe_false = maybe_false 838 | self.definitely_false = definitely_false 839 | self.mean = mean 840 | self._empirical_mean = 0.5 841 | self.prior_params = prior_params # Parameters for the prior Beta distribution 842 | 843 | def __repr__(self): 844 | return (f"BeliefTrueFalseCat.DistributionFormat(n={self.n}, definitely_true={self.definitely_true}, " 845 | f"maybe_true={self.maybe_true}, uncertain={self.uncertain}, " 846 | f"maybe_false={self.maybe_false}, definitely_false={self.definitely_false})") 847 | 848 | def to_dict(self): 849 | return { 850 | "_type": "boolean_cat", 851 | "prior_params": self.prior_params, 852 | "n": self.n, 853 | 
"definitely_true": self.definitely_true, 854 | "maybe_true": self.maybe_true, 855 | "uncertain": self.uncertain, 856 | "maybe_false": self.maybe_false, 857 | "definitely_false": self.definitely_false, 858 | "_empirical_mean": self._empirical_mean, 859 | "mean": self.mean, 860 | } 861 | 862 | def get_mean_belief(self, prior=None, recompute=False) -> float: 863 | """ 864 | Get the mean of the prior/posterior belief distribution. 865 | Args: 866 | prior (BeliefTrueFalseCat.DistributionFormat): Prior distribution format object. 867 | recompute (bool): Whether to recompute the mean even if it is already set. 868 | Returns: 869 | float: The mean belief probability. 870 | """ 871 | if self.mean is None or recompute: 872 | # Compute the mean belief using the Beta distribution 873 | alpha1, alpha2 = BeliefTrueFalseCat.get_beta_params_from_cat_samples( 874 | self.definitely_true, self.maybe_true, self.uncertain, self.maybe_false, self.definitely_false 875 | ) 876 | if self.n > 0: 877 | self._empirical_mean = alpha1 / self.n 878 | self.mean = (self.prior_params[0] + alpha1) / (self.n + sum(self.prior_params)) 879 | 880 | if prior is not None: 881 | # Bayesian update: Beta(n_true + a, n_false + b) where a and b are prior parameters 882 | prior_alpha1, prior_alpha2 = BeliefTrueFalseCat.get_beta_params_from_cat_samples( 883 | prior.definitely_true, prior.maybe_true, prior.uncertain, prior.maybe_false, 884 | prior.definitely_false 885 | ) 886 | post_alpha = prior_alpha1 + prior.prior_params[0] 887 | # post_beta = prior_alpha2 + prior.prior_params[1] 888 | self.mean = (alpha1 + post_alpha) / (self.n + prior.n + sum(prior.prior_params)) 889 | return self.mean 890 | 891 | def update(self, 892 | definitely_true: int = 0, 893 | maybe_true: int = 0, 894 | uncertain: int = 0, 895 | maybe_false: int = 0, 896 | definitely_false: int = 0, 897 | distr=None, 898 | normalize: bool = False): 899 | """ 900 | Update the distribution with new counts. 
901 | """ 902 | if distr is not None: 903 | self.definitely_true += distr.definitely_true 904 | self.maybe_true += distr.maybe_true 905 | self.uncertain += distr.uncertain 906 | self.maybe_false += distr.maybe_false 907 | self.definitely_false += distr.definitely_false 908 | else: 909 | self.definitely_true += definitely_true 910 | self.maybe_true += maybe_true 911 | self.uncertain += uncertain 912 | self.maybe_false += maybe_false 913 | self.definitely_false += definitely_false 914 | n = distr.n if distr is not None else ( 915 | definitely_true + maybe_true + uncertain + maybe_false + definitely_false 916 | ) 917 | if normalize: 918 | total = self.n + n 919 | self.definitely_true /= (total / self.n) 920 | self.maybe_true /= (total / self.n) 921 | self.uncertain /= (total / self.n) 922 | self.maybe_false /= (total / self.n) 923 | self.definitely_false /= (total / self.n) 924 | else: 925 | self.n += n 926 | # Reset mean 927 | _ = self.get_mean_belief(recompute=True) 928 | 929 | def get_params(self) -> Tuple[float, float]: 930 | """ 931 | Get the parameters of the Beta distribution. 932 | Returns: 933 | Tuple[float, float]: Parameters (alpha, beta) of the Beta distribution. 934 | """ 935 | alpha1, alpha2 = BeliefTrueFalseCat.get_beta_params_from_cat_samples( 936 | self.definitely_true, self.maybe_true, self.uncertain, self.maybe_false, self.definitely_false 937 | ) 938 | return self.prior_params[0] + alpha1, self.prior_params[1] + alpha2 939 | 940 | class ResponseFormat(BaseModel): 941 | """ 942 | Belief about the support for the hypothesis. 943 | 944 | Attributes: 945 | belief (str): Belief about the support for the hypothesis. Choices are: 946 | "definitely true": Hypothesis is definitely true. 947 | "maybe true": Hypothesis may be true. 948 | "uncertain": Hypothesis is equally likely to be true or false (e.g., because of relevant but contradictory evidence). 949 | "maybe false": Hypothesis may be false. 950 | "definitely false": Hypothesis is definitely false. 
951 | "cannot comment": Not enough information to comment on the hypothesis (e.g., due to lack of domain knowledge or lack of relevant evidence). 952 | """ 953 | belief: str = Field(..., description="Belief about the hypothesis", 954 | choices=["definitely true", "maybe true", "uncertain", 955 | "maybe false", "definitely false", "cannot comment"]) 956 | 957 | @staticmethod 958 | def parse_response(response: List[dict], 959 | prior_params: Tuple[float, float] = (0.5, 0.5), 960 | weight: float = 1.0) -> 'BeliefTrueFalseCat.DistributionFormat': 961 | """ 962 | Parse the response from the LLM into a DistributionFormat. 963 | 964 | Args: 965 | response (dict): The response from the LLM containing belief counts. 966 | prior_params (Tuple[float, float]): Parameters for the prior Beta distribution (alpha, beta). 967 | weight (float): Weight to apply to the counts (default is 1.0). 968 | 969 | Returns: 970 | BeliefTrueFalseCat.DistributionFormat: Parsed distribution format. 971 | """ 972 | cannot_comment = sum(1 for _res in response if _res["belief"] == "cannot comment") 973 | definitely_true = weight * sum(1 for _res in response if _res["belief"] == "definitely true") 974 | maybe_true = weight * sum(1 for _res in response if _res["belief"] == "maybe true") 975 | uncertain = weight * sum(1 for _res in response if _res["belief"] == "uncertain") 976 | maybe_false = weight * sum(1 for _res in response if _res["belief"] == "maybe false") 977 | definitely_false = weight * sum(1 for _res in response if _res["belief"] == "definitely false") 978 | n = weight * (len(response) - cannot_comment) # Exclude responses with "cannot comment" 979 | 980 | return BeliefTrueFalseCat.DistributionFormat( 981 | n=n, 982 | definitely_true=definitely_true, 983 | maybe_true=maybe_true, 984 | uncertain=uncertain, 985 | maybe_false=maybe_false, 986 | definitely_false=definitely_false, 987 | prior_params=prior_params 988 | ) 989 | 990 | @staticmethod 991 | def kl_divergence(distr1: 
'BeliefTrueFalseCat.DistributionFormat', 992 | distr2: 'BeliefTrueFalseCat.DistributionFormat') -> float: 993 | """ 994 | Compute the KL divergence between two distributions. 995 | Args: 996 | distr1 (BeliefTrueFalseCat.DistributionFormat): First distribution. 997 | distr2 (BeliefTrueFalseCat.DistributionFormat): Second distribution. 998 | Returns: 999 | float: KL divergence between the two distributions. 1000 | """ 1001 | alpha1, beta1 = distr1.get_params() 1002 | alpha2, beta2 = distr2.get_params() 1003 | term1 = betaln(alpha2, beta2) - betaln(alpha1, beta1) 1004 | term2 = (alpha1 - alpha2) * psi(alpha1) 1005 | term3 = (beta1 - beta2) * psi(beta1) 1006 | term4 = (alpha2 - alpha1 + beta2 - beta1) * psi(alpha1 + beta1) 1007 | return term1 + term2 + term3 + term4 1008 | 1009 | @staticmethod 1010 | def get_beta_params_from_cat_samples(definitely_true: float, maybe_true: float, uncertain: float, 1011 | maybe_false: float, definitely_false: float) -> Tuple[float, float]: 1012 | """ 1013 | Convert categorical counts into parameters for a Beta distribution. 1014 | 1015 | Args: 1016 | definitely_true: Count of "definitely true" responses. 1017 | maybe_true: Count of "maybe true" responses. 1018 | uncertain: Count of "uncertain" responses. 1019 | maybe_false: Count of "maybe false" responses. 1020 | definitely_false: Count of "definitely false" responses. 1021 | 1022 | Returns: 1023 | Tuple[float, float]: Parameters (alpha, beta) for the Beta distribution. 
1024 | """ 1025 | total = definitely_true + maybe_true + uncertain + maybe_false + definitely_false 1026 | alpha = definitely_true * BeliefTrueFalseCat.score_per_category["definitely_true"] + \ 1027 | maybe_true * BeliefTrueFalseCat.score_per_category["maybe_true"] + \ 1028 | uncertain * BeliefTrueFalseCat.score_per_category["uncertain"] + \ 1029 | maybe_false * BeliefTrueFalseCat.score_per_category["maybe_false"] + \ 1030 | definitely_false * BeliefTrueFalseCat.score_per_category["definitely_false"] 1031 | beta = total - alpha 1032 | return alpha, beta 1033 | 1034 | 1035 | BELIEF_MODE_TO_CLS = { 1036 | "boolean": BeliefTrueFalse, 1037 | "boolean_cat": BeliefTrueFalseCat, 1038 | "categorical": BeliefCategorical, 1039 | "categorical_numeric": BeliefCategoricalNumeric, 1040 | "gaussian": BeliefGauss 1041 | } 1042 | 1043 | 1044 | def get_belief( 1045 | hypothesis: str, 1046 | evidence: Optional[List[Dict[str, str]]] = None, 1047 | model: str = "gpt-4o", 1048 | belief_mode: str = "boolean", 1049 | n_samples: int = 5, 1050 | temperature: float | None = None, 1051 | reasoning_effort: str | None = None, 1052 | use_llm_prior: bool = False, 1053 | explicit_prior=None, 1054 | n_retries=3, 1055 | weight: float = 1.0 1056 | ): 1057 | """ 1058 | Get belief distribution for a hypothesis with optional evidence. 
1059 | 1060 | Args: 1061 | hypothesis: The hypothesis to evaluate 1062 | evidence: Optional evidence messages to condition the belief 1063 | model: The LLM model to use 1064 | belief_mode: The belief mode to use for parsing responses (e.g., BeliefTrueFalse, BeliefCategorical) 1065 | n_samples: Number of samples to draw from the LLM 1066 | temperature: Temperature for sampling 1067 | reasoning_effort: Reasoning effort for o-series models 1068 | use_llm_prior: Whether to use implicit Bayesian posterior 1069 | explicit_prior: Optional prior distribution to use for Bayesian updates 1070 | n_retries: Number of retries for querying the LLM in case of errors 1071 | weight: Weight to apply to the empirical distribution 1072 | """ 1073 | belief_cls = BELIEF_MODE_TO_CLS.get(belief_mode) 1074 | if belief_cls is None: 1075 | raise ValueError(f"Unknown belief_mode '{belief_mode}'; expected one of {list(BELIEF_MODE_TO_CLS.keys())}") 1076 | 1077 | # Construct the system prompt based on whether we are eliciting prior, implicit posterior, or explicit posterior beliefs 1078 | _system_msgs = [ 1079 | "You are a research scientist skilled at analyzing scientific hypotheses. Your task is to provide your belief about the given hypothesis." 1080 | ] 1081 | if evidence is not None: 1082 | # posterior belief 1083 | _system_msgs.append( 1084 | "Use the provided evidence collected from running experiments to help make your decision. Carefully consider each piece of evidence and decide whether and how any of them affects your belief about the current hypothesis. Note that evidence from previous studies may have an indirect bearing on the hypothesis, so think about how they might relate to the hypothesis even if they do not directly test it." 1085 | ) 1086 | # else: # prior belief 1087 | if use_llm_prior: 1088 | # implicit posterior 1089 | _system_msgs.append( 1090 | "Use your prior knowledge of the research domain to help in your assessment of the hypothesis." 
1091 | ) 1092 | else: 1093 | # explicit posterior 1094 | assert evidence is not None 1095 | _system_msgs.append( 1096 | "Disregard any prior beliefs you have about the hypothesis and focus only on the provided evidence." 1097 | ) 1098 | system_prompt = { 1099 | "role": "system", 1100 | "content": " ".join(_system_msgs) 1101 | } 1102 | 1103 | hypothesis_msg = { 1104 | "role": "user", 1105 | "content": f"Hypothesis: {hypothesis}\n\nCarefully reason before making your assessment." 1106 | } 1107 | 1108 | all_msgs = [system_prompt] 1109 | if evidence is not None: 1110 | all_msgs += evidence 1111 | all_msgs.append(hypothesis_msg) 1112 | 1113 | distribution, mean_belief = None, None 1114 | for attempt in range(n_retries): 1115 | try: 1116 | response = query_llm(all_msgs, model=model, n_samples=n_samples, 1117 | temperature=temperature, reasoning_effort=reasoning_effort, 1118 | response_format=belief_cls.ResponseFormat) 1119 | prior_params_or_none = {} 1120 | if explicit_prior is not None: 1121 | prior_params_or_none["prior_params"] = explicit_prior.get_params() 1122 | distribution = belief_cls.parse_response(response, weight=weight, **prior_params_or_none) 1123 | # Compute and store the mean belief 1124 | mean_belief = distribution.get_mean_belief() 1125 | except Exception as e: 1126 | if attempt == n_retries - 1: 1127 | print(f"Querying LLM: ERROR: {e}\nMax retries reached. Returning empty distribution.") 1128 | return None, None 1129 | else: 1130 | print(f"Querying LLM: ERROR: {e}\nRetrying ({attempt + 1}/{n_retries})...") 1131 | 1132 | return distribution, mean_belief 1133 | 1134 | 1135 | def calculate_prior_and_posterior_beliefs(node, n_samples=4, model="gpt-4o", temperature=None, 1136 | reasoning_effort=None, implicit_bayes_posterior=False, surprisal_width=0.2, 1137 | belief_mode="boolean", evidence_msg=None, prior=None, evidence_weight=1.0): 1138 | """ 1139 | Calculate prior and posterior belief distributions for a hypothesis. 
1140 | 1141 | Args: 1142 | node: MCTSNode instance containing node information and messages or a dictionary with node data 1143 | n_samples: Number of samples to draw from the LLM 1144 | model: The LLM model to use for querying 1145 | temperature: Temperature for sampling 1146 | reasoning_effort: Reasoning effort for o-series models 1147 | implicit_bayes_posterior: Whether to use implicit Bayesian posterior 1148 | surprisal_width: Minimum difference in mean prior and posterior probabilities required to count as a surprisal 1149 | belief_mode: The belief mode to use for parsing responses (e.g., "boolean", "categorical") 1150 | evidence_msg: Optional evidence messages to condition the posterior belief 1151 | prior: Optional pre-computed prior distribution to use for posterior calculation 1152 | evidence_weight: Weight to apply to the evidence when calculating the posterior belief 1153 | """ 1154 | 1155 | # MODEL_CTXT_LIMITS = { 1156 | # "o4-mini": 200_000, 1157 | # "gpt-4o": 128_000, 1158 | # } 1159 | belief_cls = BELIEF_MODE_TO_CLS.get(belief_mode) 1160 | if belief_cls is None: 1161 | raise ValueError(f"Unknown belief_mode '{belief_mode}'; expected one of {list(BELIEF_MODE_TO_CLS.keys())}") 1162 | 1163 | if type(node) is MCTSNode: 1164 | hypothesis = node.hypothesis 1165 | query = node.query # Contains the hypothesis and experiment plan 1166 | code_output = node.code_output 1167 | analysis = node.analysis 1168 | review = node.review 1169 | else: 1170 | hypothesis = node["hypothesis"] 1171 | query = node.get("query", "N/A") 1172 | code_output = node.get("code_output", "N/A") 1173 | analysis = node.get("analysis", "N/A") 1174 | review = node.get("review", "N/A") 1175 | 1176 | if hypothesis is None: 1177 | return None, None, None, None 1178 | 1179 | if evidence_msg is None: 1180 | evidence_msg = [{ 1181 | "role": "user", 1182 | "content": get_context_string(query, code_output, analysis, review, include_code_output=True) 1183 | }] 1184 | 1185 | if prior is None: 1186 | 
prior, _ = get_belief( 1187 | hypothesis=hypothesis, 1188 | evidence=None, 1189 | model=model, 1190 | belief_mode=belief_mode, 1191 | n_samples=n_samples, 1192 | temperature=temperature, 1193 | reasoning_effort=reasoning_effort, 1194 | use_llm_prior=True, 1195 | ) 1196 | 1197 | posterior, _ = get_belief( 1198 | hypothesis=hypothesis, 1199 | evidence=evidence_msg, 1200 | model=model, 1201 | belief_mode=belief_mode, 1202 | n_samples=n_samples, 1203 | temperature=temperature, 1204 | reasoning_effort=reasoning_effort, 1205 | use_llm_prior=implicit_bayes_posterior, 1206 | explicit_prior=prior, 1207 | weight=evidence_weight 1208 | ) 1209 | 1210 | if prior is None or posterior is None: 1211 | raise ValueError("Belief distribution could not be computed.") 1212 | 1213 | belief_change = abs(posterior.mean - prior.mean) 1214 | kl_divergence = belief_cls.kl_divergence(posterior, prior) 1215 | 1216 | return prior, posterior, belief_change, kl_divergence 1217 | 1218 | 1219 | if __name__ == "__main__": 1220 | 1221 | # Unit test 1222 | from mcts_utils import load_mcts_from_json 1223 | 1224 | path = "/Users/dagarwal/code/github/allenai/autods-humans/debug/20250826-181702" 1225 | root, nodes_by_level = load_mcts_from_json(path) 1226 | belief_kl = [] 1227 | prior_posterior = [] 1228 | for level, nodes in nodes_by_level.items(): 1229 | for node in nodes: 1230 | if node.prior is not None: 1231 | belief_cls = BELIEF_MODE_TO_CLS[node.prior.to_dict()["_type"]] 1232 | prior = node.prior 1233 | posterior = node.posterior 1234 | belief_change = round(posterior.mean - prior.mean, 2) 1235 | kl_div = round(belief_cls.kl_divergence(posterior, prior), 2) 1236 | belief_kl.append((belief_change, kl_div)) 1237 | prior_posterior.append((prior.get_params(), posterior.get_params())) 1238 | print(f"Total nodes: {len(belief_kl)}\n\n") 1239 | # Print statistics (percentiles, mean, std) of belief change and KL divergence 1240 | belief_changes = [abs(bc[0]) for bc in belief_kl] 1241 | kl_divergences = 
[bc[1] for bc in belief_kl] 1242 | print(f"Belief Change - Mean: {np.mean(belief_changes):.2f}, Std: {np.std(belief_changes):.2f}") 1243 | print(f"Belief Change - Min: {np.min(belief_changes):.2f}, Max: {np.max(belief_changes):.2f}") 1244 | print(f"Belief Change - 25th Percentile: {np.percentile(belief_changes, 25):.2f}") 1245 | print(f"Belief Change - 50th Percentile: {np.percentile(belief_changes, 50):.2f}") 1246 | print(f"Belief Change - 75th Percentile: {np.percentile(belief_changes, 75):.2f}") 1247 | 1248 | print(f"KL Divergence - Mean: {np.mean(kl_divergences):.2f}, Std: {np.std(kl_divergences):.2f}") 1249 | print(f"KL Divergence - Min: {np.min(kl_divergences):.2f}, Max: {np.max(kl_divergences):.2f}") 1250 | print(f"KL Divergence - 25th Percentile: {np.percentile(kl_divergences, 25):.2f}") 1251 | print(f"KL Divergence - 50th Percentile: {np.percentile(kl_divergences, 50):.2f}") 1252 | print(f"KL Divergence - 75th Percentile: {np.percentile(kl_divergences, 75):.2f}\n\n") 1253 | 1254 | # Print a table of belief change, KL divergence, and prior/posterior parameters in sorted order of KL divergence 1255 | print(f"{'Belief Change':<20} {'KL Divergence':<20} {'Prior Params':<50} {'Posterior Params':<50}") 1256 | sorted_tuples = sorted(zip(belief_kl, prior_posterior), key=lambda x: x[0][1]) 1257 | for (belief_change, kl_div), (prior_params, posterior_params) in sorted_tuples: 1258 | prior_params_str = ", ".join(f"{p:.2f}" for p in prior_params) 1259 | posterior_params_str = ", ".join(f"{p:.2f}" for p in posterior_params) 1260 | print( 1261 | f"{belief_change:<20} {round(kl_div / np.mean(kl_divergences), 2):<20} {prior_params_str:<50} {posterior_params_str:<50}") 1262 | 1263 | print("\n\n") 1264 | 1265 | sorted_tuples = sorted(zip(belief_kl, prior_posterior), key=lambda x: abs(x[0][0])) 1266 | for (belief_change, kl_div), (prior_params, posterior_params) in sorted_tuples: 1267 | prior_params_str = ", ".join(f"{p:.2f}" for p in prior_params) 1268 | 
posterior_params_str = ", ".join(f"{p:.2f}" for p in posterior_params) 1269 | print( 1270 | f"{belief_change:<20} {round(kl_div / np.mean(kl_divergences), 2):<20} {prior_params_str:<50} {posterior_params_str:<50}") 1271 | --------------------------------------------------------------------------------