├── popper
│   ├── llm
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── prompt_utils.py
│   │   └── custom_model.py
│   ├── __init__.py
│   ├── version.py
│   ├── react_agent.py
│   ├── benchmark.py
│   ├── popper.py
│   ├── react_utils.py
│   ├── utils.py
│   └── prompt_utils.py
├── baseline_agents
│   ├── __init__.py
│   ├── utils
│   │   └── dv_log.py
│   ├── react_agent.py
│   ├── coder_agent.py
│   ├── coder_utils.py
│   ├── react_utils.py
│   └── self_refine_agent.py
├── .pre-commit-config.yaml
├── benchmark_scripts
│   ├── __init__.py
│   ├── run_discoverybench.sh
│   ├── run_discoverybench_docker.sh
│   ├── run_discovery_bench_baseline.py
│   ├── run_targetval.sh
│   ├── run_targetval_baseline.py
│   ├── run_targetval_benchmark.py
│   └── run_discovery_bench.py
├── MANIFEST.in
├── setup.cfg
├── figs
│   └── popper_agent_illustration.png
├── requirements.txt
├── Dockerfile
├── setup.py
├── .gitignore
└── README.md
/popper/llm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/baseline_agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos: []
--------------------------------------------------------------------------------
/benchmark_scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/popper/__init__.py:
--------------------------------------------------------------------------------
1 | from .popper import Popper
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include requirements.txt
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
--------------------------------------------------------------------------------
/figs/popper_agent_illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/POPPER/HEAD/figs/popper_agent_illustration.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn
2 | tqdm
3 | pandas
4 | gradio
5 | pydantic==2.9.2
6 | openai==1.61.0
7 | langchain_core==0.3.22
8 | langchain_openai==0.2.11
9 | langchain_anthropic==0.2.3
10 | numpy==1.26.4
11 | scipy==1.13.1
12 | langchain==0.3.7
13 | langgraph==0.2.39
14 | langchain_experimental==0.3.3
15 | ipython
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.9-slim
2 |
3 | WORKDIR /app
4 |
5 | COPY requirements.txt .
6 |
7 | RUN pip install --no-cache-dir -r requirements.txt
8 |
9 | COPY benchmark_scripts /app/benchmark_scripts
10 |
11 | COPY popper /app/popper
12 |
13 | COPY baseline_agents /app/baseline_agents
14 |
15 | # Create directories for storing output and logs
16 | RUN mkdir -p /app/data /app/.logs && chmod -R 777 /app/data /app/.logs
17 |
18 | RUN useradd -m nonrootuser
19 | USER nonrootuser
20 |
21 | ENTRYPOINT ["python"]
22 |
--------------------------------------------------------------------------------
/benchmark_scripts/run_discoverybench.sh:
--------------------------------------------------------------------------------
1 | python benchmark_scripts/run_discovery_bench.py \
2 | --exp_name discovery_bench --num_tests 5 --samples 300 --permute --e_value --react --relevance_checker \
3 | --path /dfs/scratch0/lansong/discoverybench &> .logs/discovery_bench.log
4 |
5 | python benchmark_scripts/run_discovery_bench_baseline.py \
6 | --exp_name discovery_bench_baseline_react --agent_type react --samples 50 --log_file .logs/discovery_bench_baseline_react.log --permute \
7 | --path /dfs/scratch0/lansong/discoverybench &> .logs/discovery_bench_baseline_react_stdout.log
8 |
9 |
--------------------------------------------------------------------------------
/popper/version.py:
--------------------------------------------------------------------------------
1 |
2 | # Based on NiLearn package
3 | # License: simplified BSD
4 |
5 | # PEP0440 compatible formatted version, see:
6 | # https://www.python.org/dev/peps/pep-0440/
7 | #
8 | # Generic release markers:
9 | # X.Y
10 | # X.Y.Z # For bug fix releases
11 | #
12 | # Admissible pre-release markers:
13 | # X.YaN # Alpha release
14 | # X.YbN # Beta release
15 | # X.YrcN # Release Candidate
16 | # X.Y # Final release
17 | #
18 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
19 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
20 | #
21 | __version__ = '0.0.5' # pragma: no cover
22 |
--------------------------------------------------------------------------------
/benchmark_scripts/run_discoverybench_docker.sh:
--------------------------------------------------------------------------------
1 | # docker build --no-cache -t discovery_benchmark .
2 | docker build -t discovery_benchmark .
3 |
4 |
5 | docker run \
6 | -v /dfs/scratch0/lansong/discoverybench:/dfs/scratch0/lansong/discoverybench \
7 | -v /dfs/scratch0/lansong/data:/app/data \
8 | --name discovery_bench \
9 | --env-file .env \
10 | discovery_benchmark benchmark_scripts/run_discovery_bench.py \
11 | --exp_name discovery_bench_v3 --num_tests 5 --samples 100 --permute --e_value --react --relevance_checker \
12 | --path /dfs/scratch0/lansong/discoverybench &> .logs/discovery_bench_evalue_react_v3.log
13 |
14 |
15 | docker wait discovery_bench
16 |
17 | docker rm discovery_bench
18 |
--------------------------------------------------------------------------------
/baseline_agents/utils/dv_log.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 |
4 |
5 | class JSONFormatter(logging.Formatter):
6 | def format(self, record):
7 | log_data = {
8 | 'timestamp': self.formatTime(record),
9 | 'level': record.levelname,
10 | 'logger_name': record.name,
11 | 'module': record.module,
12 | 'function': record.funcName,
13 | 'message': record.getMessage()
14 | }
15 | return json.dumps(log_data)
16 |
17 |
18 | class DVLogger():
19 | def __init__(self, logger_name, log_filename):
20 | self.file_handler = logging.FileHandler(log_filename)
21 | self.file_handler.setLevel(logging.INFO)
22 | self.json_formatter = JSONFormatter()
23 | self.file_handler.setFormatter(self.json_formatter)
24 | self.logger = logging.getLogger(logger_name)
25 | self.logger.setLevel(logging.INFO)
26 | self.logger.addHandler(self.file_handler)
27 |
28 | # To log string
29 | def log(self, message):
30 | self.logger.info(message)
31 |
32 |     # To log dictionary (agent response)
33 | def log_json(self, message):
34 | self.logger.info(json.dumps(message))
35 |
36 | def close(self):
37 | self.file_handler.close()
38 | self.logger.removeHandler(self.file_handler)
--------------------------------------------------------------------------------
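
Usage note: a minimal sketch of how DVLogger is meant to be driven (not part of the repository); the logger name and log file path below are hypothetical.

from baseline_agents.utils.dv_log import DVLogger

logger = DVLogger("example_run", "example_run.log")  # records are written as one JSON object per line
logger.log("starting evaluation")                    # log a plain string message
logger.log_json({"step": 1, "status": "ok"})         # log a dict (serialized with json.dumps)
logger.close()                                       # detach and close the file handler when done
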
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | # read the contents of README file
4 | from os import path
5 | from io import open # for Python 2 and 3 compatibility
6 |
7 | # get __version__ from version.py
8 | ver_file = path.join('popper', 'version.py')
9 | with open(ver_file) as f:
10 | exec(f.read())
11 |
12 | this_directory = path.abspath(path.dirname(__file__))
13 |
14 |
15 | # read the contents of README.md
16 | def readme():
17 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
18 | return f.read()
19 |
20 |
21 | # read the contents of requirements.txt
22 | with open(path.join(this_directory, 'requirements.txt'),
23 | encoding='utf-8') as f:
24 | requirements = f.read().splitlines()
25 |
26 | setup(name='popper_agent',
27 | version=__version__,
28 | license='MIT',
29 | description='POPPER',
30 | long_description=readme(),
31 | long_description_content_type='text/markdown',
32 | url='https://github.com/snap-stanford/POPPER',
33 | author='POPPER Team',
34 | author_email='kexinh@cs.stanford.edu',
35 | packages=find_packages(exclude=['test']),
36 | zip_safe=False,
37 | include_package_data=True,
38 | install_requires=requirements,
39 | setup_requires=['setuptools>=38.6.0']
40 | )
41 |
--------------------------------------------------------------------------------
/popper/llm/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import traceback
3 | import logging
4 | import uuid
5 | import re
6 |
7 | def clean_json_string(json_string: str) -> str:
8 | # Remove trailing commas before closing braces/brackets in JSON structures
9 | cleaned_json_string = re.sub(r',\s*(\}|\])', r'\1', json_string)
10 | return cleaned_json_string
11 |
12 | def parse_llm_output(llm_output: str):
13 | # Regular expressions to match the JSON structures
14 | tool_calls_pattern = r'\{\s*"type":\s*"tool_calls",\s*"content":\s*\[.*?\]\s*\}'
15 | text_message_pattern = r'\{\s*"type":\s*"text_message",\s*"content":\s*".*?"\s*\}'
16 |
17 | # Initial text that appears before the JSON
18 | final_text_before_json = ""
19 |
20 | # Try to find tool calls JSON
21 | tool_calls_match = re.search(tool_calls_pattern, llm_output, re.DOTALL)
22 | if tool_calls_match:
23 | tool_calls_json = tool_calls_match.group(0)
24 | final_text_before_json = llm_output[:tool_calls_match.start()]
25 | # Clean up the JSON string (remove trailing commas)
26 | tool_calls_json = clean_json_string(tool_calls_json)
27 | print(tool_calls_json)
28 | return final_text_before_json, json.loads(tool_calls_json)
29 |
30 | # Try to find text message JSON
31 | text_message_match = re.search(text_message_pattern, llm_output)
32 | if text_message_match:
33 | text_message_json = text_message_match.group(0)
34 | final_text_before_json = llm_output[:text_message_match.start()]
35 | return final_text_before_json, json.loads(text_message_json)
36 |
37 | # If neither pattern is found, treat the entire output as text message
38 | return "", {
39 | "type": "text_message",
40 | "content": llm_output,
41 | }
--------------------------------------------------------------------------------
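
Usage note: a minimal sketch (not from the repository) of what parse_llm_output returns; the raw model output string below is hypothetical.

from popper.llm.utils import parse_llm_output

raw = (
    'Let me inspect the data first. '
    '{"type": "tool_calls", "content": '
    '[{"name": "run_code", "id": "call-1", "arguments": {"code": "df.head()",}}]}'
)
thoughts, payload = parse_llm_output(raw)
# thoughts -> 'Let me inspect the data first. '
# payload  -> the parsed tool_calls dict; the trailing comma inside "arguments" is removed
#             by clean_json_string before json.loads is applied
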
/popper/llm/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | def bind_tools_to_system_prompt(system_prompt, tools, tool_choice = None):
4 | if not tool_choice or tool_choice == 'none':
5 | return f'''You are an intelligent agent capable of calling tools to complete user-assigned tasks.
6 | Here are the instructions specified by the user:
7 | """{system_prompt}"""
8 |
9 | In addition, you have access to the following tools:
10 | {json.dumps(tools, indent=4)}
11 |
12 | You may output any intermediate thoughts or reasoning before delivering your final response.
13 | Your final response must be either one or more tool calls or a response message to the user.
14 |
15 | To make one or more tool calls, wrap your final response in the following JSON format:
16 | {{
17 | "type": "tool_calls",
18 | "content": [
19 | {{
20 | "name": "name of the function to call",
21 |             "id": "a unique id for this tool call",
22 | "arguments": {{
23 | "argument1": value1,
24 | "argument2": value2,
25 | ...
26 | }}
27 | }},
28 | ...
29 | ]
30 | }}
31 |
32 | To send a direct response message to the user, wrap your final response in the following JSON format:
33 | {{
34 | "type": "text_message",
35 | "content": "content of the message according to the user instructions"
36 | }}
37 |
38 | You must choose either to send tool calls or a direct response message. Be sure to format the final response properly according to the given JSON specs.
39 |
40 | DO NOT put anything after the final response JSON object.
41 | '''
42 |
43 | # tool choice!
44 | system_prompt = f'''You are an intelligent agent capable of calling tools to complete user-assigned tasks.
45 | Here are the instructions specified by the user:
46 | """{system_prompt}"""
47 |
48 | In addition, you have access to the following tools:
49 | {json.dumps(tools, indent=4)}
50 |
51 | You may output any intermediate thoughts or reasoning before delivering your final response.
52 | Your final response MUST BE one or more tool calls.
53 |
54 | To make one or more tool calls, wrap your final response in the following JSON format:
55 | {{
56 | "type": "tool_calls",
57 | "content": [
58 | {{
59 | "name": "name of the function to call",
60 |             "id": "a unique id for this tool call",
61 | "arguments": {{
62 | "argument1": value1,
63 | "argument2": value2,
64 | ...
65 | }}
66 | }},
67 | ...
68 | ]
69 | }}
70 |
71 |
72 | You MUST wrap your response as a tool call formatted in the above JSON schema.
73 |
74 | DO NOT put anything after the tool call.
75 | '''
76 | return system_prompt
--------------------------------------------------------------------------------
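
Usage note: a minimal sketch (not from the repository) of binding a tool schema into the system prompt; the run_python tool definition is hypothetical.

from popper.llm.prompt_utils import bind_tools_to_system_prompt

tools = [{
    "name": "run_python",
    "description": "Execute a snippet of Python code and return its stdout.",
    "parameters": {
        "type": "object",
        "properties": {"code": {"type": "string"}},
        "required": ["code"],
    },
}]

# With tool_choice left as None (or 'none'), the model may reply with either tool calls or a
# text_message; any other truthy tool_choice value selects the prompt that forces tool calls.
open_prompt = bind_tools_to_system_prompt("Validate the hypothesis on the loaded dataframes.", tools)
forced_prompt = bind_tools_to_system_prompt("Validate the hypothesis on the loaded dataframes.", tools, tool_choice="required")
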
/benchmark_scripts/run_discovery_bench_baseline.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append('../')
4 |
5 | from baseline_agents.coder_agent import BaseAgent
6 | from baseline_agents.react_agent import ReactAgent
7 | from baseline_agents.self_refine_agent import SelfRefineAgent
8 | from baseline_agents.coder_utils import load_data_to_coder_globals
9 | from baseline_agents.react_utils import load_data_to_react_globals
10 | from popper.benchmark import discovery_bench_hypothesis
11 |
12 | from sklearn.metrics import accuracy_score, average_precision_score
13 |
14 | from tqdm import tqdm
15 | import argparse
16 | import traceback
17 | import time
18 | import json
19 |
20 | argparser = argparse.ArgumentParser()
21 | argparser.add_argument("--exp_name", type=str, default="discovery_bench_baseline")
22 | argparser.add_argument("--model_name", type=str, default="claude-3-5-sonnet")
23 | argparser.add_argument("--agent_type", type=str, choices=['coder', 'react', 'self_refine'], default="coder")
24 | argparser.add_argument('--samples', type=int, default=50)
25 | argparser.add_argument('--starts_from', type=int, default=0)
26 | argparser.add_argument("--log_file", type=str, default=".logs/baseline_log.log")
27 | argparser.add_argument('--permute', action='store_true', default=False)
28 | argparser.add_argument('--path', type=str, required = True)
29 |
30 | args = argparser.parse_args()
31 |
32 | exp_name = args.exp_name
33 | exp_name += f"_{args.agent_type}"
34 | if args.permute:
35 | # data_loader.permute_selected_columns()
36 | exp_name+='_permuted'
37 |
38 | samples = args.samples
39 | bm = discovery_bench_hypothesis(num_samples = samples, path = args.path)
40 | predictions = []
41 | targets = []
42 |
43 | agent = None
44 | load_to_globals = None
45 | if args.agent_type == 'coder' or args.agent_type == 'react':
46 | load_to_globals = load_data_to_coder_globals if args.agent_type == 'coder' else load_data_to_react_globals
47 |
48 | start = time.time()
49 |
50 | for i, example in tqdm(enumerate(bm.get_iterator()), total=samples, desc="Processing"):
51 | if i < args.starts_from:
52 | print(f"Skipping example {i}")
53 | continue
54 |
55 | if args.agent_type == 'coder':
56 | agent = BaseAgent(
57 | model_name=args.model_name,
58 | api_config="baseline_agents/config/api_config.json",
59 | model_config="baseline_agents/config/model_config.json",
60 | log_file=args.log_file
61 | )
62 | elif args.agent_type == 'react':
63 | agent = ReactAgent(
64 | model_name=args.model_name,
65 | api_config="baseline_agents/config/api_config.json",
66 | model_config="baseline_agents/config/model_config.json",
67 | log_file=args.log_file
68 | )
69 | elif args.agent_type == 'self_refine':
70 | agent = SelfRefineAgent(
71 | llm="claude-3-5-sonnet-20241022"
72 | )
73 | else:
74 | raise ValueError(f"Agent {args.agent_type} not found")
75 |
76 | try:
77 | data_loader = example["data_loader"]
78 | if load_to_globals is not None:
79 | load_to_globals(data_loader)
80 |
81 | query = example['prompt']
82 | output = agent.generate(data_loader=data_loader, query=query)
83 | if output is not None:
84 | predictions.append((0.0, output["output"]))
85 | else:
86 | predictions.append((0.0, False))
87 | targets.append(example['answer'])
88 | except:
89 | print(traceback.format_exc())
90 |
91 |
92 | print("------------------------------------")
93 |
94 | end = time.time()
95 | print(f"Total elapsed time: {end - start}")
96 |
97 | print("Benchmark Results:")
98 | print(f"Predictions: {predictions}")
99 | print(f"Targets: {targets}")
100 | eval_results = bm.evaluate(predictions, targets)
101 |
102 | for metric in eval_results:
103 | print(f"{metric}: {eval_results[metric]}")
--------------------------------------------------------------------------------
/benchmark_scripts/run_targetval.sh:
--------------------------------------------------------------------------------
1 | ## TargetVal benchmarks
2 | data_path="/dfs/user/kexinh/popper_data_processed"
3 | for seed in 1 2 3 4 5
4 | do
5 | datasets=("IFNG" "IL2")
6 | for dataset in "${datasets[@]}"
7 | do
8 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path
9 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path
10 | done
11 | done
12 |
13 | ### For Popper-codegen:
14 | python run_targetval_benchmark.py --e_value --relevance_checker --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path
15 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path
16 |
17 |
18 | ### For Popper-NoRelevanceChecker:
19 | python run_targetval_benchmark.py --e_value --react --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path
20 | python run_targetval_benchmark.py --e_value --permute --react --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path
21 |
22 | ### For Popper-Fisher Combined Test:
23 | python run_targetval_benchmark.py --relevance_checker --react --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path
24 | python run_targetval_benchmark.py --relevance_checker --permute --react --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path
25 |
26 | ### For Popper-LLM estimated likelihood ratio:
27 | python run_targetval_benchmark.py --llm_approx --relevance_checker --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path
28 | python run_targetval_benchmark.py --llm_approx --relevance_checker --permute --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path
29 |
30 | ### For other LLMs:
31 | # GPT-4o
32 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --model gpt-4o --samples 20 --dataset $dataset --path $data_path
33 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --model gpt-4o --samples 50 --dataset $dataset --path $data_path
34 |
35 | # o1
36 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --model o1-2024-12-17 --samples 20 --dataset $dataset --path $data_path
37 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --model o1-2024-12-17 --samples 50 --dataset $dataset --path $data_path
38 |
39 | # Haiku
40 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --model claude-3-5-haiku-20241022 --samples 20 --dataset $dataset --path $data_path
41 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --model claude-3-5-haiku-20241022 --samples 50 --dataset $dataset --path $data_path
42 |
43 | ### For user study type I error genes:
44 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --user_study_neg_genes --dataset IL2 --path $data_path
45 |
46 |
47 | ### Baselines
48 | # Coder-o1
49 | python run_targetval_baseline.py --agent_type coder --model o1-2024-12-17 --permute --samples 50 --seed $seed --dataset IL2 --path $data_path
50 | python run_targetval_baseline.py --agent_type coder --model o1-2024-12-17 --samples 20 --seed $seed --dataset IL2 --path $data_path
51 |
52 | # Coder
53 | python run_targetval_baseline.py --agent_type coder --permute --samples 50 --seed $seed --dataset IL2 --path $data_path
54 | python run_targetval_baseline.py --agent_type coder --samples 20 --seed $seed --dataset IL2 --path $data_path
55 |
56 | # Self-refine
57 | python run_targetval_baseline.py --agent_type self_refine --use_full_data --permute --samples 50 --seed $seed --dataset IL2 --path $data_path
58 | python run_targetval_baseline.py --agent_type self_refine --use_full_data --samples 20 --seed $seed --dataset IL2 --path $data_path
59 |
60 | # React
61 | python run_targetval_baseline.py --use_other_claude_api --use_simple_template --agent_type react --permute --samples 50 --seed $seed --dataset IL2 --path $data_path
62 | python run_targetval_baseline.py --use_other_claude_api --use_simple_template --agent_type react --samples 20 --seed $seed --dataset IL2 --path $data_path
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | demo_test.ipynb
3 | .gradio/
4 | test.ipynb
5 | *xlsx
6 | notebooks/figures/*
7 | notebooks/tmp_dir/*
8 | notebooks/tmp_directory/*
9 | notebooks/save_folder/*
10 | notebooks/cache
11 | notebooks/\[draft*
12 | notebooks/\[scratch*
13 | scripts/save_folder/*
14 | scripts/\[draft*
15 | scripts/tmp_dir/*
16 | scripts/tmp_directory/*
17 | bioagentos/agent/popper_agent.py
18 | bioagentos/tool/data_tool/old_data_tool/*
19 | examples/*.png
20 | examples/*.csv
21 | examples/*.fa
22 | scratch/
23 | packages/
24 | data
25 | data_lake
26 | __pycache__/
27 | *.py[cod]
28 | *$py.class
29 | res/
30 | examples/tmp*
31 |
32 | # C extensions
33 | *.so
34 |
35 | # Distribution / packaging
36 | .Python
37 | build/
38 | develop-eggs/
39 | dist/
40 | downloads/
41 | eggs/
42 | .eggs/
43 | lib/
44 | lib64/
45 | parts/
46 | sdist/
47 | var/
48 | wheels/
49 | share/python-wheels/
50 | *.egg-info/
51 | .installed.cfg
52 | *.egg
53 | MANIFEST
54 |
55 | # PyInstaller
56 | # Usually these files are written by a python script from a template
57 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
58 | *.manifest
59 | *.spec
60 |
61 | # Installer logs
62 | pip-log.txt
63 | pip-delete-this-directory.txt
64 |
65 | # Unit test / coverage reports
66 | htmlcov/
67 | .tox/
68 | .nox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *.cover
75 | *.py,cover
76 | .hypothesis/
77 | .pytest_cache/
78 | cover/
79 |
80 | # Translations
81 | *.mo
82 | *.pot
83 |
84 | # Django stuff:
85 | *.log
86 | local_settings.py
87 | db.sqlite3
88 | db.sqlite3-journal
89 |
90 | # Flask stuff:
91 | instance/
92 | .webassets-cache
93 |
94 | # Scrapy stuff:
95 | .scrapy
96 |
97 | # Sphinx documentation
98 | docs/_build/
99 |
100 | # PyBuilder
101 | .pybuilder/
102 | target/
103 |
104 | # Jupyter Notebook
105 | .ipynb_checkpoints
106 |
107 | # IPython
108 | profile_default/
109 | ipython_config.py
110 |
111 | # pyenv
112 | # For a library or package, you might want to ignore these files since the code is
113 | # intended to run in multiple environments; otherwise, check them in:
114 | # .python-version
115 |
116 | # pipenv
117 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
118 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
119 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
120 | # install all needed dependencies.
121 | #Pipfile.lock
122 |
123 | # poetry
124 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
125 | # This is especially recommended for binary packages to ensure reproducibility, and is more
126 | # commonly ignored for libraries.
127 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
128 | #poetry.lock
129 |
130 | # pdm
131 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
132 | #pdm.lock
133 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
134 | # in version control.
135 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
136 | .pdm.toml
137 | .pdm-python
138 | .pdm-build/
139 |
140 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
141 | __pypackages__/
142 |
143 | # Celery stuff
144 | celerybeat-schedule
145 | celerybeat.pid
146 |
147 | # SageMath parsed files
148 | *.sage.py
149 |
150 | # Environments
151 | .env
152 | .venv
153 | env/
154 | venv/
155 | ENV/
156 | env.bak/
157 | venv.bak/
158 |
159 | # Spyder project settings
160 | .spyderproject
161 | .spyproject
162 |
163 | # Rope project settings
164 | .ropeproject
165 |
166 | # mkdocs documentation
167 | /site
168 |
169 | # mypy
170 | .mypy_cache/
171 | .dmypy.json
172 | dmypy.json
173 |
174 | # Pyre type checker
175 | .pyre/
176 |
177 | # pytype static type analyzer
178 | .pytype/
179 |
180 | # Cython debug symbols
181 | cython_debug/
182 |
183 | # PyCharm
184 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
185 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
186 | # and can be added to the global gitignore or merged into this file. For a more nuclear
187 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
188 | #.idea/
189 |
--------------------------------------------------------------------------------
/benchmark_scripts/run_targetval_baseline.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append('../')
4 |
5 | from baseline_agents.coder_agent import BaseAgent
6 | from baseline_agents.react_agent import ReactAgent
7 | from baseline_agents.self_refine_agent import SelfRefineAgent
8 | from baseline_agents.coder_utils import load_data_to_coder_globals
9 | from baseline_agents.react_utils import load_data_to_react_globals
10 | from popper.benchmark import gene_perturb_hypothesis
11 | from popper.utils import ExperimentalDataLoader
12 | from langchain_core.prompts import ChatPromptTemplate
13 | from popper.utils import get_llm
14 | from sklearn.metrics import accuracy_score, average_precision_score
15 | from pydantic import BaseModel, Field
16 | from typing import (
17 | Optional, List, Tuple, Union, Literal, Dict, TypedDict, Annotated
18 | )
19 | from tqdm import tqdm
20 | import argparse
21 | import traceback
22 | import time
23 | import json
24 |
25 | argparser = argparse.ArgumentParser()
26 | argparser.add_argument("--exp_name", type=str, default="gene_baseline")
27 | argparser.add_argument("--model_name", type=str, default="claude-3-5-sonnet")
28 | argparser.add_argument("--agent_type", type=str, choices=['coder', 'react', 'self_refine'], default="coder")
29 | argparser.add_argument('--samples', type=int, default=25)
30 | argparser.add_argument('--starts_from', type=int, default=0)
31 | argparser.add_argument("--log_file", type=str, default=".logs/baseline_log.log")
32 | argparser.add_argument('--permute', action='store_true', default=False)
33 | argparser.add_argument('--use_full_data', action='store_true', default=False)
34 | argparser.add_argument("--dataset", type=str, default="IL2")
35 | argparser.add_argument('--seed', type=int, default=-1)
36 | argparser.add_argument('--use_simple_template', action='store_true', default=False)
37 | argparser.add_argument('--path', type=str, required = True)
38 |
39 | args = argparser.parse_args()
40 |
41 | data_path = args.path
42 | if args.use_full_data:
43 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'all_bio')
44 | else:
45 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'default')
46 |
47 | exp_name = args.dataset
48 | exp_name += f"_{args.agent_type}"
49 | if args.permute:
50 | data_loader.permute_selected_columns()
51 | exp_name+='_permuted'
52 |
53 | if args.use_full_data:
54 | exp_name+='_full_data'
55 |
56 | if args.seed != -1:
57 | exp_name+=f'_seed_{args.seed}'
58 |
59 | if args.use_simple_template:
60 | exp_name+='_simple_template'
61 |
62 | samples = args.samples
63 | bm = gene_perturb_hypothesis(num_of_samples = samples,
64 | permuted=args.permute, dataset = args.dataset, path = args.path)
65 | predictions = []
66 | targets = []
67 |
68 | agent = None
69 | load_to_globals = None
70 | if args.agent_type == 'coder' or args.agent_type == 'react':
71 | load_to_globals = load_data_to_coder_globals if args.agent_type == 'coder' else load_data_to_react_globals
72 |
73 | start = time.time()
74 | for i, example in tqdm(enumerate(bm.get_iterator()), total=samples, desc="Processing"):
75 | if i < args.starts_from:
76 | print(f"Skipping example {i}")
77 | continue
78 |
79 | if args.agent_type == 'coder':
80 | agent = BaseAgent(
81 | model_name=args.model_name,
82 | api_config="baseline_agents/config/api_config.json",
83 | model_config="baseline_agents/config/model_config.json",
84 | log_file=args.log_file,
85 | simple_prompt=args.use_simple_template
86 | )
87 | elif args.agent_type == 'react':
88 | agent = ReactAgent(
89 | model_name=args.model_name,
90 | api_config="baseline_agents/config/api_config.json",
91 | model_config="baseline_agents/config/model_config.json",
92 | log_file=args.log_file,
93 | simple_prompt=args.use_simple_template
94 | )
95 | elif args.agent_type == 'self_refine':
96 | agent = SelfRefineAgent(
97 | llm="claude-3-5-sonnet-20241022"
98 | )
99 | else:
100 | raise ValueError(f"Agent {args.agent_type} not found")
101 |
102 | try:
103 | if load_to_globals is not None:
104 | load_to_globals(data_loader)
105 |
106 | query = example['prompt']
107 | datasets = data_loader.data_desc
108 | if args.agent_type == 'coder':
109 | output = agent.generate(data_loader=data_loader, query=query)
110 | elif args.agent_type in ['react', 'self_refine']:
111 | output = agent.generate(data_loader=data_loader, query=query)
112 |
113 | print(output)
114 | if output is not None:
115 | predictions.append((0.0, output["output"], example['gene']))
116 | else:
117 | predictions.append((0.0, False, example['gene']))
118 | targets.append(example['binary_answer'])
119 | except:
120 | print(traceback.format_exc())
121 |
122 |
123 | print("------------------------------------")
124 |
125 | end = time.time()
126 | print(f"Total elapsed time: {end - start}")
127 |
128 | print("Benchmark Results:")
129 | print(f"Predictions: {predictions}")
130 | print(f"Targets: {targets}")
131 | eval_results = bm.evaluate(predictions, targets)
132 |
133 | for metric in eval_results:
134 | print(f"{metric}: {eval_results[metric]}")
135 |
136 |
137 | import pickle
138 |
139 | # gather what was computed above so the saved pickle mirrors run_targetval_benchmark.py
140 | res = {'predictions': predictions, 'targets': targets, 'eval_results': eval_results}
141 | os.makedirs(args.path + '/res', exist_ok=True)
142 | with open(args.path + '/res/' + exp_name + '_res_final.pkl', 'wb') as f:
143 |     pickle.dump(res, f)
--------------------------------------------------------------------------------
/benchmark_scripts/run_targetval_benchmark.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append('../')
4 |
5 | from popper.benchmark import gene_perturb_hypothesis
6 | from popper.agent import SequentialFalsificationTest
7 | from popper.utils import ExperimentalDataLoader
8 |
9 | from tqdm import tqdm
10 | import argparse
11 |
12 | argparser = argparse.ArgumentParser()
13 | argparser.add_argument('--permute', action='store_true', default=False)
14 | argparser.add_argument('--samples', type=int, default=25)
15 | argparser.add_argument('--llm_approx', action='store_true', default=False)
16 | argparser.add_argument('--e_value', action='store_true', default=False)
17 | argparser.add_argument('--relevance_checker', action='store_true', default=False)
18 | argparser.add_argument('--react', action='store_true', default=False)
19 | argparser.add_argument('--use_full_data', action='store_true', default=False)
20 | argparser.add_argument('--max_num_of_tests', type=int, default=5)
21 | argparser.add_argument('--alpha', type=float, default=0.1)
22 | argparser.add_argument('--seed', type=int, default=-1)
23 | argparser.add_argument("--model", type=str, default="claude-3-5-sonnet-20241022")
24 | argparser.add_argument("--dataset", type=str, default="IL2")
25 | argparser.add_argument('--data_sampling', type=int, default=-1)
26 | argparser.add_argument('--user_study_neg_genes', action='store_true', default=False)
27 | argparser.add_argument('--is_locally_served', action='store_true', default=False)
28 | argparser.add_argument('--server_port', type=int, required=False)
29 | argparser.add_argument("--api_key", type=str, default="EMPTY")
30 | argparser.add_argument('--path', type=str, required = True)
31 |
32 | args = argparser.parse_args()
33 |
34 | data_path = args.path
35 |
36 | if args.data_sampling != -1:
37 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'all_bio', data_sampling = args.data_sampling)
38 | elif args.use_full_data:
39 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'all_bio')
40 | else:
41 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'default')
42 |
43 | exp_name = args.dataset
44 | if args.permute:
45 | data_loader.permute_selected_columns()
46 | exp_name+='_permuted'
47 |
48 | if args.max_num_of_tests != 5:
49 | exp_name+=f'_max_{args.max_num_of_tests}'
50 |
51 | if args.alpha != 0.1:
52 | exp_name+=f'_alpha_{args.alpha}'
53 |
54 | if args.llm_approx:
55 | exp_name+='_llm_approx'
56 |
57 | if args.use_full_data:
58 | exp_name+='_full_data'
59 |
60 | if args.react:
61 | exp_name+="_react_v2"
62 | #args.e_value = True
63 |
64 | if args.e_value:
65 | exp_name+='_e_value'
66 |
67 | if args.relevance_checker:
68 | exp_name+='_relevance_checker'
69 |
70 | if args.seed != -1:
71 | exp_name+=f'_seed_{args.seed}'
72 |
73 | if args.model[:6] != 'claude':
74 | exp_name+=f'_{args.model}'
75 | else:
76 | if 'haiku' in args.model:
77 | exp_name+='_haiku'
78 |
79 | if args.data_sampling != -1:
80 | exp_name+=f'_sampling_{args.data_sampling}'
81 |
82 | if args.user_study_neg_genes:
83 | exp_name+='_user_study_neg_genes'
84 |
85 | print('Experiment name: ' + exp_name)
86 | res = {}
87 | samples = args.samples
88 | bm = gene_perturb_hypothesis(num_of_samples = samples, permuted=args.permute,
89 | dataset = args.dataset, user_study_neg_genes= args.user_study_neg_genes, path = args.path)
90 | #response = []
91 | for example in tqdm(bm.get_iterator(), total=samples, desc="Processing"):
92 | import traceback
93 | try:
94 | agent = SequentialFalsificationTest(llm = args.model, is_local=args.is_locally_served, port=args.server_port, api_key=args.api_key)
95 | if args.llm_approx:
96 | agent.configure(data = data_loader,
97 | alpha = args.alpha, beta = 0.1,
98 | aggregate_test = 'LLM_approx',
99 | max_num_of_tests = 5,
100 | max_retry = args.max_num_of_tests, time_limit = 2,
101 | llm_approx = True,
102 | relevance_checker = args.relevance_checker)
103 | else:
104 | if args.e_value:
105 | agent.configure(data = data_loader, alpha = args.alpha,
106 | beta = 0.1, aggregate_test = 'E-value',
107 | max_num_of_tests = args.max_num_of_tests, max_retry = 5, time_limit = 2,
108 | relevance_checker = args.relevance_checker, use_react_agent=args.react)
109 | else:
110 | agent.configure(data = data_loader, alpha = args.alpha, beta = 0.1,
111 | aggregate_test = 'Fisher', max_num_of_tests = args.max_num_of_tests,
112 | max_retry = args.max_num_of_tests, time_limit = 2,
113 | relevance_checker = args.relevance_checker, use_react_agent=args.react)
114 |
115 | log, last_message, parsed_result = agent.go(example['prompt'])
116 | res[example['gene']] = (log, last_message, parsed_result, agent.res_stat)
117 | except Exception as e:
118 | print(f"Error for prompt '{example['prompt']}': {e}")
119 | print(traceback.format_exc()) # Print the full traceback for debugging
120 | res[example['gene']] = ('Error', traceback.format_exc())
121 | continue
122 |
123 | import pickle
124 | import os
125 | os.makedirs(args.path + '/res', exist_ok=True)
126 | with open(args.path + '/res/' + exp_name + '_res_final.pkl', 'wb') as f:
127 | pickle.dump(res, f)
--------------------------------------------------------------------------------
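
Usage note: distilled from the script above, a minimal sketch (not itself part of the repository) of a single TargetVal run; the data path and hypothesis string are hypothetical, and an API key for the chosen model is assumed to be configured.

from popper.agent import SequentialFalsificationTest
from popper.utils import ExperimentalDataLoader

loader = ExperimentalDataLoader("/path/to/popper_data_processed", table_dict_selection="default")
agent = SequentialFalsificationTest(llm="claude-3-5-sonnet-20241022",
                                    is_local=False, port=None, api_key="EMPTY")
agent.configure(data=loader, alpha=0.1, beta=0.1, aggregate_test="E-value",
                max_num_of_tests=5, max_retry=5, time_limit=2,
                relevance_checker=True, use_react_agent=True)
# go() returns the run log, the final message, and the parsed verdict
log, last_message, parsed_result = agent.go("Perturbing STAT5B regulates IL2 production.")
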
/baseline_agents/react_agent.py:
--------------------------------------------------------------------------------
1 | from baseline_agents.react_utils import create_agent
2 | from langchain_anthropic import ChatAnthropic
3 | from langchain_together import Together
4 | from langchain_openai import ChatOpenAI
5 | from langchain_google_genai import ChatGoogleGenerativeAI
6 | import os
7 | import json
8 | import langchain
9 | from langchain_core.callbacks import FileCallbackHandler, StdOutCallbackHandler
10 | from baseline_agents.utils.dv_log import DVLogger
11 | import uuid
12 | import traceback
13 |
14 |
15 | # uncomment the following line to enable debug mode
16 | # langchain.debug = True
17 |
18 | def get_prompt_data(
19 | prompt_config: str = None
20 | ):
21 | if prompt_config is None and os.environ.get("PROMPT_CONFIG") is None:
22 | raise ValueError("PROMPT_CONFIG not set and prompt_config not provided")
23 | else:
24 | prompt_config = prompt_config or os.environ.get("PROMPT_CONFIG")
25 |
26 | with open(prompt_config, "r") as file:
27 | return json.load(file)
28 |
29 |
30 | class ReactAgent():
31 | def __init__(
32 | self,
33 | model_config: str = None,
34 | api_config: str = None,
35 | model_name: str = "claude-3-5-sonnet",
36 | log_file: str = "base_react_agent.log",
37 | max_iterations: int = 25,
38 | simple_prompt: bool = False
39 | ):
40 | self.logfile = log_file
41 | self.logger = DVLogger(f"{model_name}_{uuid.uuid4()}", log_file)
42 | # logger.add(log_file, format="{time} {level} {message}", level="INFO")
43 | self.file_handler = FileCallbackHandler(self.logfile)
44 | self.stdout_handler = StdOutCallbackHandler()
45 |
46 | # set max iterations
47 | self.max_iterations = max_iterations
48 |
49 | # check if model config is provided
50 | if model_config is None and os.environ.get("MODEL_CONFIG") is None:
51 | raise ValueError("MODEL_CONFIG not set and model_config not provided")
52 | else:
53 | # override environment variable config path if model_config is provided
54 | model_config = model_config or os.environ.get("MODEL_CONFIG")
55 |
56 | # do a similar check for api_config
57 | if api_config is None and os.environ.get("API_CONFIG") is None:
58 | raise ValueError("API_CONFIG not set and api_config not provided")
59 | else:
60 | api_config = api_config or os.environ.get("API_CONFIG")
61 |
62 |
63 | # load model config
64 | self.model_name = model_name
65 | with open(model_config, "r") as file:
66 | self.model_config = json.load(file)
67 |
68 | # load api config
69 | with open(api_config, "r") as file:
70 | self.api_config = json.load(file)
71 |
72 | try:
73 | # get full model name and type
74 | self.full_model_name = self.model_config['models'][self.model_name]['model_name']
75 | self.model_type = self.model_config['models'][self.model_name]['model_type']
76 | except KeyError:
77 | raise ValueError(f"Model {model_name} not found in model config")
78 |
79 | try:
80 | # get api key using model type
81 | self.api_key = self.api_config[self.model_type]
82 | except KeyError:
83 | raise ValueError(f"API key not found for {self.model_type}")
84 |
85 | # get the model
86 | self.llm = self.get_model(
87 | api=self.model_type,
88 | model=self.full_model_name,
89 | api_key=self.api_key
90 | )
91 |
92 | # create agent
93 | self.agent = create_agent(
94 | llm=self.llm,
95 | handlers=[self.file_handler, self.stdout_handler],
96 | max_iterations=self.max_iterations,
97 | simple_template=simple_prompt
98 | )
99 |
100 | def get_model(
101 | self,
102 | api,
103 | api_key,
104 | model,
105 | **kwargs
106 | ):
107 | llm = None
108 | if (api == "together"):
109 | llm = Together(
110 | model=model,
111 | together_api_key=api_key,
112 | **kwargs
113 | )
114 | elif (api == "anthropic"):
115 | llm = ChatAnthropic(
116 | model=model,
117 | api_key=api_key,
118 | **kwargs
119 | )
120 | elif (api == "openai"):
121 | llm = ChatOpenAI(
122 | model=model,
123 | api_key=api_key,
124 | **kwargs
125 | )
126 | elif (api == "google"):
127 | llm = ChatGoogleGenerativeAI(
128 | model=model,
129 | google_api_key=api_key,
130 | **kwargs
131 | )
132 | else:
133 | raise ValueError(f"Invalid API: {api}")
134 | return llm
135 |
136 | def generate(self, data_loader, query):
137 | try:
138 | table_dict = data_loader.table_dict
139 | dataset_desc = data_loader.data_desc
140 | self.agent.tools[0]._set_globals(table_dict)
141 | output = self.agent.invoke(input={
142 | "system_prompt": "You are a scientific agent who can plan and execute python code multiple times to answer a query based on one or more datasets. All datasets have already been loaded into the global namespace as pandas dataframes.",
143 | "input": f"""Question: Is the following hypothesis true or false? {query}
144 | Datasets: {dataset_desc}"""
145 | })
146 | self.logger.log_json(output)
147 | return output
148 | except Exception as e:
149 | print("Execution Stopped due to : ", e)
150 | print(traceback.format_exc())
151 | self.logger.logger.error(f"Execution Stopped due to : {e}")
152 | self.logger.close()
--------------------------------------------------------------------------------
/benchmark_scripts/run_discovery_bench.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.getcwd())
4 |
5 | from popper.benchmark import discovery_bench_hypothesis
6 | from popper.agent import SequentialFalsificationTest
7 | from sklearn.metrics import accuracy_score, average_precision_score
8 |
9 | from tqdm import tqdm
10 | import argparse
11 | import traceback
12 | import time
13 | import json
14 | import pickle
15 | import numpy as np
16 |
17 |
18 | argparser = argparse.ArgumentParser()
19 | argparser.add_argument("--exp_name", type=str, default="discovery_bench")
20 | argparser.add_argument("--model", type=str, default="claude-3-5-sonnet-20241022")
21 | argparser.add_argument('--samples', type=int, default=50)
22 | argparser.add_argument('--num_tests', type=int, default=3)
23 | argparser.add_argument('--starts_from', type=int, default=0)
24 | argparser.add_argument('--permute', action='store_true', default=False)
25 | argparser.add_argument('--llm_approx', action='store_true', default=False)
26 | argparser.add_argument('--e_value', action='store_true', default=False)
27 | argparser.add_argument('--react', action='store_true', default=False)
28 | argparser.add_argument('--relevance_checker', action='store_true', default=False)
29 | argparser.add_argument('--is_locally_served', action='store_true', default=False)
30 | argparser.add_argument('--server_port', type=int, required=False)
31 | argparser.add_argument("--api_key", type=str, default="EMPTY")
32 | argparser.add_argument('--path', type=str, required = True)
33 |
34 | args = argparser.parse_args()
35 |
36 |
37 | exp_name = args.exp_name
38 |
39 | exp_name += args.model
40 |
41 | exp_name += f'_{args.num_tests}tests'
42 |
43 | if args.permute:
44 | # data_loader.permute_selected_columns()
45 | exp_name+='_permuted'
46 |
47 | if args.llm_approx:
48 | exp_name+='_llm_approx'
49 |
50 | if args.e_value:
51 | exp_name+='_e_value'
52 |
53 | if args.react:
54 | exp_name+="_react"
55 |
56 | if args.relevance_checker:
57 | exp_name+="_relevance_checker"
58 |
59 | print(f"Running {exp_name}")
60 |
61 | res = []
62 | samples = args.samples
63 | bm = discovery_bench_hypothesis(num_samples = samples, path=args.path)
64 | predictions = []
65 | targets = []
66 |
67 | start = time.time()
68 |
69 | #response = []
70 | for i, example in tqdm(enumerate(bm.get_iterator()), total=samples, desc="Processing"):
71 | if i < args.starts_from:
72 | print(f"Skipping example {i}")
73 | continue
74 | try:
75 | data_loader = example["data_loader"]
76 |
77 | permuted = "not permuted" if example["answer"] else "permuted"
78 | print("======================================================")
79 | print(f'Processing {(example["task"], example["metadataid"], example["query_id"], permuted)}')
80 | for name, df in data_loader.table_dict.items():
81 | print(name)
82 | print(df.head())
83 | print("------------------------------------")
84 |
85 | agent = SequentialFalsificationTest(llm = args.model, is_local=args.is_locally_served, port=args.server_port, api_key=args.api_key)
86 | if args.llm_approx:
87 | agent.configure(data = data_loader, alpha = 0.1, beta = 0.1, aggregate_test = 'LLM_approx', max_num_of_tests = args.num_tests, max_retry = 5, time_limit = 2, llm_approx = True, domain=example['domain'], relevance_checker=args.relevance_checker)
88 | else:
89 | if args.e_value:
90 | agent.configure(data = data_loader, alpha = 0.1, beta = 0.1, aggregate_test = 'E-value', max_num_of_tests = args.num_tests, max_retry = 5, time_limit = 2, domain=example['domain'], relevance_checker=args.relevance_checker, use_react_agent=args.react, max_failed_tests=args.num_tests)
91 | else:
92 | agent.configure(data = data_loader, alpha = 0.1, beta = 0.1, aggregate_test = 'Fisher', max_num_of_tests = args.num_tests, max_retry = 5, time_limit = 2, domain=example['domain'], relevance_checker=args.relevance_checker, max_failed_tests=args.num_tests)
93 |
94 | log, last_message, parsed_result = agent.go(example['prompt'])
95 | predictions.append((agent.res_stat, agent.res))
96 | targets.append(example['answer'])
97 | res.append({
98 | "task": example["task"],
99 | "metadataid": example["metadataid"],
100 | "query_id": example["query_id"],
101 | "log": log,
102 | "last_message": last_message,
103 | "parsed_result": parsed_result,
104 | "res": agent.res,
105 | "res_stat": agent.res_stat,
106 | "answer": example["answer"]
107 | })
108 | # res[(example["task"], example["metadataid"], example["query_id"])] = (log, last_message, parsed_result, agent.res, agent.res_stat, example['answer'])
109 | except Exception as e:
110 | print(f"Error for prompt '{example['prompt']}': {e}")
111 | print(traceback.format_exc()) # Print the full traceback for debugging
112 | res.append({
113 | "task": example["task"],
114 | "metadataid": example["metadataid"],
115 | "query_id": example["query_id"],
116 | "error": traceback.format_exc()
117 | })
118 | # res[(example["task"], example["metadataid"], example["query_id"])] = ('Error', traceback.format_exc())
119 | # predictions.append((0.0, False))
120 | continue
121 |
122 | output_path = os.path.join(os.getcwd(), 'data/' + exp_name + '.json')
123 |
124 | # with open(os.path.join(os.getcwd(), 'data/' + exp_name + '.pkl'), 'wb') as f:
125 | with open(output_path, 'w') as f:
126 | # pickle.dump(res, f)
127 | json.dump(res, f, indent=4)
128 |
129 | print(f"Results saved to {output_path}")
130 | print("------------------------------------")
131 | print("------------------------------------")
132 |
133 | end = time.time()
134 | print(f"Total elapsed time: {end - start}")
135 |
136 | print("Benchmark Results:")
137 | print(f"Predictions: {predictions}")
138 | print(f"Targets: {targets}")
139 | eval_results = bm.evaluate(predictions, targets)
140 |
141 | for metric in eval_results:
142 | print(f"{metric}: {eval_results[metric]}")
--------------------------------------------------------------------------------
/baseline_agents/coder_agent.py:
--------------------------------------------------------------------------------
1 | from baseline_agents.coder_utils import create_agent
2 | from baseline_agents.utils.dv_log import DVLogger
3 | from langchain_anthropic import ChatAnthropic
4 | from langchain_together import Together
5 | from langchain_openai import ChatOpenAI
6 | from langchain_google_genai import ChatGoogleGenerativeAI
7 | from langchain_core.callbacks import FileCallbackHandler, StdOutCallbackHandler
8 |
9 | import os
10 | import json
11 | import uuid
12 | import traceback
13 |
14 | # uncomment the following line to enable debug mode
15 | # langchain.debug = True
16 |
17 |
18 | def get_prompt_data(
19 | prompt_config: str = None
20 | ):
21 | if prompt_config is None and os.environ.get("PROMPT_CONFIG") is None:
22 | raise ValueError("PROMPT_CONFIG not set and prompt_config not provided")
23 | else:
24 | prompt_config = prompt_config or os.environ.get("PROMPT_CONFIG")
25 |
26 | with open(prompt_config, "r") as file:
27 | return json.load(file)
28 |
29 |
30 | class BaseAgent():
31 | def __init__(
32 | self,
33 | model_config: str = None,
34 | api_config: str = None,
35 | model_name: str = "claude-3-5-sonnet",
36 | log_file: str = "base_coder_agent.log",
37 | max_iterations: int = 5,
38 | simple_prompt: bool = False
39 | ):
40 | self.simple_prompt = simple_prompt
41 | self.logfile = log_file
42 | # logger.add(log_file, format="{time} {level} {message}", level="INFO")
43 | self.logger = DVLogger(f"{model_name}_{uuid.uuid4()}", log_file)
44 | self.file_handler = FileCallbackHandler(self.logfile)
45 | self.stdout_handler = StdOutCallbackHandler()
46 | self.max_iterations = max_iterations
47 |
48 | # check if model config is provided
49 | if model_config is None and os.environ.get("MODEL_CONFIG") is None:
50 | raise ValueError("MODEL_CONFIG not set and model_config not provided")
51 | else:
52 | # override environment variable config path if model_config is provided
53 | model_config = model_config or os.environ.get("MODEL_CONFIG")
54 |
55 | # do a similar check for api_config
56 | if api_config is None and os.environ.get("API_CONFIG") is None:
57 | raise ValueError("API_CONFIG not set and api_config not provided")
58 | else:
59 | api_config = api_config or os.environ.get("API_CONFIG")
60 |
61 | # load model config
62 | self.model_name = model_name
63 | with open(model_config, "r") as file:
64 | self.model_config = json.load(file)
65 |
66 | # load api config
67 | with open(api_config, "r") as file:
68 | self.api_config = json.load(file)
69 |
70 | # get full model name and type
71 | try:
72 | self.full_model_name = self.model_config['models'][self.model_name]['model_name']
73 | self.model_type = self.model_config['models'][self.model_name]['model_type']
74 | except KeyError:
75 | raise ValueError(f"Model {model_name} not found in model config")
76 |
77 | try:
78 | # get api key using model type
79 | self.api_key = self.api_config[self.model_type]
80 | if self.api_key == "":
81 | raise ValueError(f"API key for {self.model_type} in api config is empty")
82 | except KeyError:
83 | raise ValueError(f"API key for {self.model_type} not found in api config")
84 |
85 | # get the model
86 | self.llm = self.get_model(
87 | api=self.model_type,
88 | model=self.full_model_name,
89 | api_key=self.api_key
90 | )
91 |
92 | # create agent
93 | self.agent = create_agent(
94 | llm=self.llm,
95 | handlers=[self.file_handler, self.stdout_handler],
96 | max_iterations=self.max_iterations,
97 | simple_template=self.simple_prompt
98 | )
99 |
100 | def get_model(
101 | self,
102 | api,
103 | api_key,
104 | model,
105 | **kwargs
106 | ):
107 | llm = None
108 | if (api == "together"):
109 | llm = Together(
110 | model=model,
111 | together_api_key=api_key,
112 | **kwargs
113 | )
114 | elif (api == "anthropic"):
115 | llm = ChatAnthropic(
116 | model=model,
117 | api_key=api_key,
118 | **kwargs
119 | )
120 | elif (api == "openai"):
121 | llm = ChatOpenAI(
122 | model=model,
123 | api_key=api_key,
124 | **kwargs
125 | )
126 | elif (api == "google"):
127 | llm = ChatGoogleGenerativeAI(
128 | model=model,
129 | google_api_key=api_key,
130 | **kwargs
131 | )
132 | else:
133 | raise ValueError(f"Invalid API: {api}")
134 | return llm
135 |
136 | def generate(self, data_loader, query):
137 | dataset_desc = data_loader.data_desc
138 |
139 | if self.simple_prompt:
140 | system_prompt = "You are a scientific agent who can execute python codes to answer a query based on one or more datasets. All datasets have already been loaded into the global namespace as pandas dataframes."
141 | else:
142 |             system_prompt = "You are a scientific agent who can execute python code only once to answer a query based on one or more datasets. All datasets have already been loaded into the global namespace as pandas dataframes."
143 |
144 | try:
145 | output = self.agent.invoke(input={
146 | "system_prompt": system_prompt,
147 | "input": f"""Question: Is the following hypothesis true or false? {query}
148 | Datasets: {dataset_desc}"""
149 | })
150 | self.logger.log_json(output)
151 | return output
152 | except Exception as e:
153 | print("Execution Stopped due to : ", e)
154 | print(traceback.format_exc())
155 | self.logger.logger.error(f"Execution Stopped due to : {e}")
156 | self.logger.close()
157 |
--------------------------------------------------------------------------------
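
Usage note: a minimal sketch (not from the repository) of driving the coder baseline directly, mirroring how the benchmark scripts construct it; the data path, query, and log location are hypothetical, and the referenced config files are assumed to contain valid model and API entries.

from baseline_agents.coder_agent import BaseAgent
from popper.utils import ExperimentalDataLoader

loader = ExperimentalDataLoader("/path/to/popper_data_processed", table_dict_selection="default")
agent = BaseAgent(
    model_name="claude-3-5-sonnet",
    api_config="baseline_agents/config/api_config.json",
    model_config="baseline_agents/config/model_config.json",
    log_file="coder_example.log",
)
result = agent.generate(data_loader=loader, query="Perturbing IFNG-pathway genes changes the readout.")
if result is not None:
    print(result["output"])  # the agent's final true/false answer payload
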
/popper/react_agent.py:
--------------------------------------------------------------------------------
1 | from popper.react_utils import create_agent
2 | from popper.prompt_utils import get_react_coding_agent_system_prompt
3 | from popper.llm.custom_model import CustomChatModel
4 | from langchain_anthropic import ChatAnthropic
5 | from langchain_openai import ChatOpenAI
6 | import os
7 | import json
8 | import langchain
9 | import openai
10 | from langchain_core.callbacks import FileCallbackHandler, StdOutCallbackHandler
11 | import logging
12 | import uuid
13 | import traceback
14 | import io
15 | import contextlib
16 | import sys
17 | import re
18 |
19 | # uncomment the following line to enable debug mode
20 | # langchain.debug = True
21 |
22 | class LiveLogger:
23 | """Custom stdout handler that logs in real-time while also printing output."""
24 | def __init__(self, log):
25 | self.original_stdout = sys.stdout # Store original stdout
26 | self.log = log # Log dictionary
27 | self.current_buffer = [] # Store intermediate logs
28 |
29 | def clean_message(self, message):
30 | """Remove ANSI escape codes and filter out unnecessary logs."""
31 | # Remove ANSI escape codes (color formatting)
32 | message = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', message)
33 |
34 | # Filter out specific unwanted messages
35 | unwanted_logs = [
36 | "> Entering new AgentExecutor chain...",
37 | "> Finished chain."
38 | ]
39 | if any(unwanted in message for unwanted in unwanted_logs):
40 | return None # Skip logging this message
41 |
42 | return message.strip() if message.strip() else None
43 |
44 | def write(self, message):
45 | cleaned_message = self.clean_message(message)
46 | if cleaned_message:
47 | self.original_stdout.write(cleaned_message + "\n") # Print to console
48 | self.original_stdout.flush() # Ensure immediate output
49 |
50 | # Append each new log update separately in Markdown format
51 | self.current_buffer.append(cleaned_message)
52 | self.log['executor'].append(f"```\n{cleaned_message}\n```") # Markdown formatting
53 |
54 | def flush(self):
55 | self.original_stdout.flush()
56 |
57 |
58 | def get_prompt_data(
59 | prompt_config: str = None
60 | ):
61 | if prompt_config is None and os.environ.get("PROMPT_CONFIG") is None:
62 | raise ValueError("PROMPT_CONFIG not set and prompt_config not provided")
63 | else:
64 | prompt_config = prompt_config or os.environ.get("PROMPT_CONFIG")
65 |
66 | with open(prompt_config, "r") as file:
67 | return json.load(file)
68 |
69 |
70 | class ReactAgent():
71 | def __init__(
72 | self,
73 | model_name: str = "claude-3-5-sonnet-20241022",
74 | model_config: str = None,
75 | api_config: str = None,
76 | max_iterations: int = 25,
77 | prompt_revision: bool = False,
78 | port=None,
79 | api_key="EMPTY",
80 | ):
81 | self.prompt_revision = prompt_revision
82 | self.api = "custom"
83 | if model_name[:7] == 'claude-':
84 | self.api = 'anthropic'
85 | elif model_name[:4] == 'gpt-':
86 | self.api = 'openai'
87 | else:
88 | self.api = 'local'
89 |
90 | # logger.add(log_file, format="{time} {level} {message}", level="INFO")
91 | self.stdout_handler = StdOutCallbackHandler()
92 |
93 | # set max iterations
94 | self.max_iterations = max_iterations
95 |
96 | # load model config
97 | self.model_name = model_name
98 |
99 | # get the model
100 | self.llm = self.get_model(
101 | api=self.api,
102 | model=self.model_name,
103 | port=port,
104 | api_key=api_key
105 | )
106 |
107 | # create agent
108 | self.agent = create_agent(
109 | llm=self.llm,
110 | handlers=[self.stdout_handler],
111 | max_iterations=self.max_iterations
112 | )
113 |
114 | def get_model(
115 | self,
116 | api,
117 | model,
118 | port=None,
119 | api_key=None,
120 | **kwargs
121 | ):
122 | llm = None
123 | if (api == "anthropic"):
124 | llm = ChatAnthropic(
125 | model=model,
126 | api_key=os.environ["ANTHROPIC_API_KEY"],
127 | **kwargs
128 | )
129 | elif (api == "openai"):
130 | llm = ChatOpenAI(
131 | model=model,
132 | api_key=os.environ["OPENAI_API_KEY"],
133 | **kwargs
134 | )
135 | # elif (api == 'llama'):
136 | # llm = CustomChatModel(
137 | # model=model,
138 | # model_type='custom',
139 | # **kwargs
140 | # )
141 | # llm.client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY").chat.completions
142 | else:
143 | # Llama or other locally-served models
144 | assert port is not None, "Port must be specified for local models"
145 | llm = CustomChatModel(
146 | model=model,
147 | model_type='custom',
148 | **kwargs
149 | )
150 | api_key = "EMPTY" if api_key is None else api_key
151 | llm.client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key=api_key).chat.completions
152 | return llm
153 |
154 | def generate(self, data_loader, test_spec, domain, log=None):
155 | try:
156 | self.agent.tools[0]._set_globals(data_loader.table_dict)
157 | dataset_desc = data_loader.data_desc
158 |
159 | # Use LiveLogger only if a log is provided
160 | logger = LiveLogger(log) if log is not None else sys.stdout
161 |
162 | # Redirect stdout to capture real-time logs
163 | sys.stdout = logger
164 | try:
165 | output = self.agent.invoke(input={
166 | "system_prompt": get_react_coding_agent_system_prompt(domain=domain, prompt_revision=self.prompt_revision),
167 | "input": f"""Falsification Test: {test_spec}
168 | Datasets: {dataset_desc}
169 | Thought:"""
170 | })
171 | finally:
172 | sys.stdout = logger.original_stdout # Restore stdout
173 |
174 | return output['output']
175 |
176 | except Exception as e:
177 | error_message = f"Execution Stopped due to: {e}\n{traceback.format_exc()}"
178 | print(error_message)
179 | if log is not None:
180 | log['executor'].append(f"```\n{error_message}\n```") # Markdown format
181 | return None
182 | '''
183 | def generate(self, data_loader, test_spec, domain, log = None):
184 | try:
185 | self.agent.tools[0]._set_globals(data_loader.table_dict)
186 | dataset_desc = data_loader.data_desc
187 | output = self.agent.invoke(input={
188 | "system_prompt": get_react_coding_agent_system_prompt(domain=domain, prompt_revision=self.prompt_revision),
189 | "input": f"""Falsification Test: {test_spec}
190 | Datasets: {dataset_desc}
191 | Thought:"""
192 | })
193 | return output['output']
194 | except Exception as e:
195 | print("Execution Stopped due to : ", e)
196 | print(traceback.format_exc())
197 | return None
198 | '''
199 |
--------------------------------------------------------------------------------
/baseline_agents/coder_utils.py:
--------------------------------------------------------------------------------
1 | # Set up the base template
2 | from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser
3 | from langchain.prompts import StringPromptTemplate
4 | from langchain.tools import BaseTool
5 | from langchain.chains.llm import LLMChain
6 | from langchain_experimental.tools.python.tool import PythonAstREPLTool
7 | from langchain.schema import AgentAction, AgentFinish
8 | from pydantic import Field
9 | from typing import List, Union
10 | import contextlib
11 | import io
12 | import logging
13 | import re
14 |
15 |
16 | logging.basicConfig(level=logging.INFO)
17 |
18 |
19 | # set the maximum number of python code blocks that can be run
20 | MAX_TURNS = 1
21 |
22 | template = """{system_prompt}
23 |
24 | You have access to the following tools:
25 | {tools}
26 |
27 | Use the following format:
28 |
29 | Question: an input hypothesis that you must decide if it is True or False
30 | Datasets: the names and descriptions of datasets relevant to the input hypothesis
31 | Action: the action to take, should be one of [{tool_names}]
32 | Action Input: the input to the action
33 | Observation: the result of the action
34 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer
35 | Final Answer: True/False. Please output True if the input hypothesis is valid (e.g., you are able to reject the null hypothesis with statistical significance) and False if the input hypothesis is invalid (e.g., if you fail to reject the null hypothesis).
36 |
37 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY:
38 |
39 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section.
40 |
41 | NOTE: You will be able to execute the python code ONLY ONCE. So you will need to generate the complete code to solve the query in one go. Please provide the final answer after that.
42 |
43 | Begin!
44 |
45 | {input}
46 | {agent_scratchpad}"""
47 |
48 |
49 | template_v2 = """{system_prompt}
50 |
51 | You have access to the following tools:
52 | {tools}
53 |
54 | Use the following format:
55 |
56 | Question: an input hypothesis that you must decide if it is True or False
57 | Datasets: the names and descriptions of datasets relevant to the input hypothesis
58 | Action: the action to take, should be one of [{tool_names}]
59 | Action Input: the input to the action
60 | Observation: the result of the action
61 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer
62 | Final Answer: True/False. Please output True if you believe the input hypothesis is correct and False if the input hypothesis is not based on your analysis.
63 |
64 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY:
65 |
66 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section.
67 |
68 | NOTE: You will be able to execute the python code ONLY ONCE. So you will need to generate the complete code to solve the query in one go. Please provide the final answer after that.
69 |
70 | Begin!
71 |
72 | {input}
73 | {agent_scratchpad}"""
74 |
75 |
76 | def load_data_to_coder_globals(data_loader):
77 | for name, df in data_loader.table_dict.items():
78 | globals()[name] = df
79 |
80 |
81 | # Set up a prompt template
82 | class CustomPromptTemplate(StringPromptTemplate):
83 | # The template to use
84 | template: str
85 | # The list of tools available
86 | tools: List[BaseTool]
87 |
88 | def format(self, **kwargs) -> str:
89 | # Get the intermediate steps (AgentAction, Observation tuples)
90 | # Format them in a particular way
91 | intermediate_steps = kwargs.pop("intermediate_steps")
92 | thoughts = ""
93 | for action, observation in intermediate_steps:
94 | thoughts += action.log
95 | thoughts += f"\nObservation: {observation}\nThought: "
96 | # Set the agent_scratchpad variable to that value
97 | kwargs["agent_scratchpad"] = thoughts
98 | # Create a tools variable from the list of tools provided
99 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
100 | # Create a list of tool names for the tools provided
101 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
102 | return self.template.format(**kwargs)
103 |
104 |
105 | # CustomOutputParser to parse the output of the LLM and execute actions
106 | class CustomOutputParser(AgentOutputParser):
107 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
108 | # Check if agent should finish
109 | if "Final Answer:" in llm_output:
110 | output = llm_output.split("Final Answer:")[-1].split()[0].strip().lower()
111 | if output not in ["true", "false", "yes", "no", "y", "n"]:
112 | raise ValueError(f"Could not parse LLM output: `{llm_output}`")
113 | return AgentFinish(
114 | return_values={"output": output in ["true", "yes", 'y']},
115 | log=llm_output,
116 | )
117 | # Parse out the action and action input
118 | regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
119 | match = re.search(regex, llm_output, re.DOTALL)
120 | if not match:
121 | raise ValueError(f"Could not parse LLM output: `{llm_output}`")
122 | action = match.group(1).strip()
123 | action_input = match.group(2)
124 | # Return the action and action input
125 | return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
126 |
127 |
128 | class CustomPythonAstREPLTool(PythonAstREPLTool):
129 | max_turns: int = Field(0, exclude=True)
130 |
131 | def _run(self, query: str, run_manager=None):
132 | if self.max_turns >= MAX_TURNS:
133 | return 'You cannot run the code more than once - you have already run it earlier. Please provide the "Final Answer:" immediately after "Thought:", based on whatever information you got till now. Do not attempt to output an "Action:" or run the code again.'
134 | self.max_turns += 1
135 |
136 | code_match = re.search(r"```(.*?)```", query, re.DOTALL)
137 | if code_match:
138 | # Extract code within backticks
139 | code = code_match.group(1)
140 | else:
141 | code = query
142 | code = code.strip()
143 | if code.startswith('"""'):
144 | code = code[3:].lstrip()
145 | if code.endswith('"""'):
146 | code = code[:-3].rstrip()
147 | if code.startswith("python"):
148 | code = code[len("python"):].lstrip()
149 |
150 | code = "import pandas as pd\n" + code
151 |
152 | exec_globals = globals().copy()
153 | exec_globals.update(__builtins__)
154 |
155 | output_capture = io.StringIO()
156 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture):
157 | logging.getLogger().handlers[0].stream = output_capture
158 | try:
159 | exec(code, exec_globals)
160 | except Exception as e:
161 | return str(e)
162 |
163 | # Retrieve the output and return it
164 | output = output_capture.getvalue()
165 | return output if output else "Execution completed without output."
166 |
167 | def create_agent(
168 | llm,
169 | handlers,
170 | max_iterations = None,
171 | early_stopping_method: str = "force",
172 | simple_template = False
173 | ):
174 | output_parser = CustomOutputParser()
175 | python_tool = CustomPythonAstREPLTool(callbacks=handlers)
176 | tools = [python_tool]
177 | tool_names = [tool.name for tool in tools]
178 |
179 | if simple_template:
180 | use_template = template_v2
181 | else:
182 | use_template = template
183 |
184 | print('Creating agent with template:', use_template)
185 |
186 | prompt = CustomPromptTemplate(
187 | template=use_template,
188 | tools=tools,
189 | input_variables=["system_prompt", "input", "intermediate_steps"]
190 | )
191 | llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=handlers)
192 |
193 | agent = LLMSingleActionAgent(
194 | llm_chain=llm_chain,
195 | output_parser=output_parser,
196 | stop=["\nObservation:"],
197 | allowed_tools=tool_names
198 | )
199 |
200 | return AgentExecutor.from_agent_and_tools(
201 | agent=agent,
202 | tools=tools,
203 | verbose=True,
204 | max_iterations=max_iterations,
205 | callbacks=handlers,
206 | early_stopping_method=early_stopping_method
207 | )
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # POPPER: Automated Hypothesis Validation with Agentic Sequential Falsifications
2 |
3 | This repository hosts the code base for the paper
4 |
5 | **Automated Hypothesis Validation with Agentic Sequential Falsifications**
6 |
7 | Kexin Huang*, Ying Jin*, Ryan Li*, Michael Y. Li, Emmanuel Candès, Jure Leskovec\
8 | [Link to Paper](https://arxiv.org/abs/2502.09858)
9 |
10 |
11 | If you find this work useful, please consider citing:
12 |
13 | ```
14 | @misc{popper,
15 | title={Automated Hypothesis Validation with Agentic Sequential Falsifications},
16 | author={Kexin Huang and Ying Jin and Ryan Li and Michael Y. Li and Emmanuel Candès and Jure Leskovec},
17 | year={2025},
18 | eprint={2502.09858},
19 | archivePrefix={arXiv}
20 | }
21 | ```
22 |
23 |
24 | ### Overview
25 | Hypotheses are central to information acquisition, decision-making, and discovery. However, many real-world hypotheses are abstract, high-level statements that are difficult to validate directly.
26 | This challenge is further intensified by the rise of hypothesis generation from Large Language Models (LLMs), which are prone to hallucination and produce hypotheses in volumes that make manual validation impractical. Here we propose Popper, an agentic framework for rigorous automated validation of free-form hypotheses.
27 | Guided by Karl Popper's principle of falsification, Popper validates a hypothesis using LLM agents that design and execute falsification experiments targeting its measurable implications. A novel sequential testing framework ensures strict Type-I error control while actively gathering evidence from diverse observations, whether drawn from existing data or newly conducted procedures.
28 | We demonstrate Popper on six domains including biology, economics, and sociology. Popper delivers robust error control, high power, and scalability. Furthermore, compared to human scientists, Popper achieved comparable performance in validating complex biological hypotheses while reducing time by 10-fold, providing a scalable, rigorous solution for hypothesis validation.
29 |
30 |
31 | ![popper](https://raw.githubusercontent.com/snap-stanford/POPPER/HEAD/figs/popper_agent_illustration.png)
32 |
33 |
34 | ## Installation
35 |
36 | We highly recommend using a virtual environment to manage the dependencies.
37 |
38 | ```bash
39 | conda create -n popper_env python=3.10
40 | conda activate popper_env
41 | ```
42 |
43 | For direct usage of Popper, you can install the package via pip:
44 | ```bash
45 | pip install popper_agent
46 | ```
47 |
48 | For source code development, you can clone the repository and install the dependencies:
49 | ```bash
50 | git clone https://github.com/snap-stanford/POPPER.git
51 | cd POPPER
52 | pip install -r requirements.txt
53 | ```
54 |
55 | Add the OpenAI/Anthropic API key to the environment variables:
56 | ```bash
57 | export OPENAI_API_KEY="YOUR_API_KEY"
58 | export ANTHROPIC_API_KEY="YOUR_API_KEY"
59 | ```
60 |
61 | Datasets will be automatically downloaded to the specified data folder when you run the code.
62 |
63 | ## Demo
64 |
65 | A demo is provided [here](demo.ipynb) showing how to use the Popper agent to validate a hypothesis, along with its basic functionality.
66 |
67 | ## Core API Usage
68 |
69 | ```python
70 | from popper import Popper
71 |
72 | # Initialize the Popper agent
73 | agent = Popper(llm="claude-3-5-sonnet-20240620")
74 |
75 | # Register data for hypothesis testing;
76 | # for bio/discoverybench data in the paper,
77 | # it will be automatically downloaded to your specified data_path
78 | agent.register_data(data_path='path/to/data', loader_type='bio')
79 |
80 | # Configure the agent with custom parameters
81 | agent.configure(
82 | alpha=0.1,
83 | max_num_of_tests=5,
84 | max_retry=3,
85 | time_limit=2,
86 | aggregate_test='E-value',
87 | relevance_checker=True,
88 | use_react_agent=True
89 | )
90 |
91 | # Validate a hypothesis
92 | results = agent.validate(hypothesis="Your hypothesis here")
93 |
94 | # Print the results
95 | print(results)
96 | ```
97 |
98 | ## Running locally-served LLMs with an OpenAI-compatible API
99 | **Popper** supports inference with local LLM servers such as vLLM, SGLang, and llama.cpp, as long as they expose an OpenAI-compatible API. Here are some usage examples with locally hosted LLMs:
100 |
101 | Using [SGLang](https://github.com/sgl-project/sglang/tree/main):
102 | ```bash
103 | # mistral large 2 with SGLang, using 4 GPUs with 8-bit quantization
104 | python -m sglang.launch_server --model-path mistralai/Mistral-Large-Instruct-2411 --port 40000 --host 0.0.0.0 --tp 4 --quantization fp8 --mem-fraction-static 0.8 --trust-remote-code
105 | ```
106 | ```python
107 | from popper import Popper
108 | agent = Popper(llm="mistralai/Mistral-Large-Instruct-2411", is_locally_served=True, server_port=40000)
109 | agent.configure(alpha=0.1)
110 | agent.register_data(data_path='path/to/data', loader_type='bio')
111 | agent.validate(hypothesis = 'YOUR HYPOTHESIS')
112 | ```
113 |
114 | Using [vLLM](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html):
115 | ```bash
116 | vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123
117 | ```
118 | ```python
119 | from popper import Popper
120 | agent = Popper(llm="NousResearch/Meta-Llama-3-8B-Instruct", is_locally_served=True, server_port=8000, api_key="token-abc123")
121 | ```
122 |
123 | Using [llama.cpp](https://github.com/ggml-org/llama.cpp):
124 | ```bash
125 | llama-server -m model.gguf --port 8080
126 | ```
127 | ```python
128 | from popper import Popper
129 | agent = Popper(llm="qwen2 1.5B", is_locally_served=True, server_port=8080)
130 | ```
131 |
132 | ## Run on your own hypothesis and database
133 |
134 | You can simply dump a set of datasets from your domain (e.g., business, economics, political science) into a folder and run Popper on your own hypothesis.
135 | The only requirement is that every file is in CSV or PKL format.
136 |
137 | ```python
138 | from popper import Popper
139 |
140 | agent = Popper(llm="claude-3-5-sonnet-20240620")
141 | agent.configure(alpha = 0.1)
142 | agent.register_data(data_path='path/to/data', loader_type='custom')
143 | agent.validate(hypothesis = 'YOUR HYPOTHESIS')
144 | ```
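
For instance, a data folder could be prepared as in the minimal sketch below; the file and column names are hypothetical, and any `.csv` or `.pkl` tables dropped into the folder will be picked up:

```python
import os
import pandas as pd

os.makedirs("path/to/data", exist_ok=True)

# Hypothetical example tables for a custom (economics) domain
pd.DataFrame({"country": ["A", "B", "C"], "gdp_growth": [2.1, 3.4, 0.8]}).to_csv(
    "path/to/data/gdp_growth.csv", index=False
)
pd.DataFrame({"country": ["A", "B", "C"], "unemployment": [5.2, 7.1, 3.9]}).to_pickle(
    "path/to/data/unemployment.pkl"
)
```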
145 |
146 | ## Hypothesis in Popper
147 |
148 | You can arbitrarily define any free-form hypothesis. In the paper, we provide two types of hypotheses: biological hypotheses and DiscoveryBench hypotheses.
149 |
150 | You can load the biological hypotheses with:
151 |
152 | ```python
153 | from popper.benchmark import gene_perturb_hypothesis
154 | bm = gene_perturb_hypothesis(num_of_samples = samples, permuted=False, dataset = 'IL2', path = path)
155 | example = bm.get_example(0)
156 | ```
157 | It will return something like:
158 |
159 | ```
160 | {'prompt': 'Gene VAV1 regulates the production of Interleukin-2 (IL-2).',
161 | 'gene': 'VAV1',
162 | 'answer': 2.916,
163 | 'binary_answer': True}
164 | ```
165 |
166 | `num_of_samples` is the number of samples you want to generate, `permuted` indicates whether to permute the dataset for Type-I error estimation, and `dataset` is the dataset to use, chosen from `IL2` and `IFNG`.
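
For example, a minimal sketch of building a null (permuted) benchmark for Type-I error estimation, assuming the same `path` as above:

```python
from popper.benchmark import gene_perturb_hypothesis

# Permuted (null) benchmark: genes are sampled at random and
# every ground-truth binary answer is False.
bm_null = gene_perturb_hypothesis(num_of_samples=50, permuted=True, dataset='IFNG', path=path)
for example in bm_null.get_iterator():
    print(example['prompt'], example['binary_answer'])
```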
167 |
168 | For DiscoveryBench, you can load the hypotheses with:
169 |
170 | ```python
171 | from popper.benchmark import discovery_bench_hypothesis
172 | bm = discovery_bench_hypothesis(num_samples = samples, path = path)
173 | example = bm.get_example(0)
174 | ```
175 |
176 | It will return something like:
177 |
178 | ```
179 | {'task': 'archaeology',
180 | 'domain': 'humanities',
181 | 'metadataid': 5,
182 | 'query_id': 0,
183 | 'prompt': 'From 1700 BCE onwards, the use of hatchets and swords increased while the use of daggers decreased.',
184 |  'data_loader': <popper.utils.DiscoveryBenchDataLoader object at 0x...>,
185 | 'answer': True}
186 | ```
187 |
188 | As each hypothesis in DiscoveryBench has its own associated dataset, each example also returns a `data_loader` for its dataset.
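
A minimal sketch of evaluating a batch of results (`run_popper_on` below is a hypothetical stand-in for however you invoke the agent on each example's `prompt` and `data_loader`; `evaluate` expects a list of `(statistic, conclusion)` pairs in query order):

```python
results = []
for example in bm.get_iterator():
    # Hypothetical helper: returns a (test statistic, True/False conclusion) pair
    stat, conclusion = run_popper_on(example['prompt'], example['data_loader'])
    results.append((stat, conclusion))

# Reports accuracy, power, false discovery rate, and F1
print(bm.evaluate(results))
```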
189 |
190 |
191 | ## Run benchmarks in the paper
192 |
193 | Bash scripts for reproducing the paper's results are provided: `benchmark_scripts/run_targetval.sh` for the `TargetVal` benchmark and `benchmark_scripts/run_discoverybench.sh` for the `DiscoveryBench` benchmark.
194 |
195 | **Note:** the Popper agent can read and write files on your filesystem. We recommend running the benchmark scripts inside a containerized environment. We provide a working `Dockerfile` and an example script, `benchmark_scripts/run_discoverybench_docker.sh`, that launches a Docker container and executes the benchmarks inside it.
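
A minimal sketch of the containerized workflow (the image tag is a placeholder; see `benchmark_scripts/run_discoverybench_docker.sh` for the working version):

```bash
docker build -t popper-bench .
docker run --rm -e OPENAI_API_KEY -e ANTHROPIC_API_KEY \
    popper-bench bash benchmark_scripts/run_discoverybench.sh
```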
196 |
197 | **To run paper benchmarks with locally-served models,** simply pass the extra parameters to the benchmark script, e.g.,
198 | ```bash
199 | python benchmark_scripts/run_discovery_bench.py --exp_name discovery_bench --model llama-3.3-70b --num_tests 5 --samples 100 --permute --e_value --react --relevance_checker --is_locally_served --server_port 30000 --path PATH_TO_YOUR_DATASET
200 | ```
201 |
202 | ## UI interface
203 | You can deploy a simple UI with one line of code, using your own datasets or our bio dataset. A Gradio UI will be generated, and you can interact with it to validate your hypothesis.
204 |
205 | ```python
206 | agent.launch_UI()
207 | ```
208 |
209 | An interface like this will pop up:
210 |
211 | [](https://www.youtube.com/watch?v=jYFEeP2mEY8)
212 |
213 | ## Acknowledgement
214 | The DiscoveryBench benchmark and some of the baseline agents are built on top of [allenai/discoverybench](https://github.com/allenai/discoverybench). Thanks for their awesome work!
215 |
216 | ## Contact
217 |
218 | For any questions, please raise an issue on GitHub or contact Kexin Huang (kexinh@cs.stanford.edu).
219 |
--------------------------------------------------------------------------------
/popper/benchmark.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import traceback
5 | import pandas as pd
6 | import numpy as np
7 |
8 | from popper.utils import DiscoveryBenchDataLoader
9 |
10 | class gene_perturb_hypothesis:
11 | def __init__(self, dataset='IL2',
12 | num_of_samples = 50,
13 | permuted = False,
14 | user_study_neg_genes = False,
15 | path = None):
16 |
17 | path = os.path.join(path, 'benchmark/targetval/')
18 | if dataset == 'IL2':
19 | self.prompt = "Gene {gene} regulates the production of Interleukin-2 (IL-2)."
20 | ground_truth_path = path + 'ground_truth_IL2.csv'
21 | elif dataset == 'IFNG':
22 | self.prompt = "Gene {gene} regulates the production of Interferon-gamma (IFN-g)."
23 | ground_truth_path = path + 'ground_truth_IFNG.csv'
24 |
25 | self.ground_truth = pd.read_csv(ground_truth_path, index_col=0)
26 |
27 | self.query = []
28 | self.ground_truth['abs_score'] = self.ground_truth.Score.abs()
29 | self.ground_truth = self.ground_truth.sort_values('abs_score')
30 | self.hypothesis2score = self.ground_truth.abs_score.to_dict()
31 |
32 | if not permuted:
33 | self.query += self.ground_truth.iloc[-num_of_samples:].index.values.tolist()
34 | self.answer = np.array([True] * num_of_samples)
35 | else:
36 | self.query += self.ground_truth.sample(frac = 1, random_state = 42).iloc[:num_of_samples].index.values.tolist()
37 | self.answer = np.array([False] * num_of_samples)
38 |
39 | if user_study_neg_genes:
40 | self.query = ['CD28', 'CD2', 'CD3D', 'MAK16', 'RAC2', 'CD3E', 'VAV1', 'CD247', 'ZAP70', 'CD3G', 'LCP2']
41 | self.answer = np.concatenate([self.answer, np.array([False] * 10)])
42 |
43 |
44 | def get_example(self, index = None):
45 | if index is None:
46 | index = np.random.randint(len(self.query))
47 |
48 | q = self.query[index]
49 | a = self.hypothesis2score[q]
50 | return {"prompt": self.prompt.format(gene = q),
51 | "gene": q,
52 | "answer": a,
53 | "binary_answer": self.answer[index]
54 | }
55 |
56 | def get_iterator(self):
57 | for i in range(len(self.query)):
58 | yield self.get_example(i)
59 |
60 | def output_class(self):
61 | from langchain_core.pydantic_v1 import BaseModel, Field
62 | from typing import Optional
63 | class Output(BaseModel):
64 | """Whether or not this hypothesis is considered true/false."""
65 |
66 | hypothesis_test_result: Optional[bool] = Field(
67 | description="Whether or not this hypothesis is considered true/false."
68 | )
69 | return Output
70 |
71 | def evaluate(self, response, answers=None):
72 | from sklearn.metrics import accuracy_score, average_precision_score, f1_score
73 | predicted = np.array([i[1] for i in response])
74 |         # Fall back to the stored ground-truth labels when no answers are provided
75 |         if answers is None:
76 |             answers = self.answer
76 |
77 | res_stats = np.array([i[0] for i in response])
78 | return {
79 | 'accuracy': accuracy_score(answers, predicted),
80 | 'power': np.sum((predicted == True) & (answers == True)) / np.sum((answers == True)),
81 | 'false discovery rate': np.sum((predicted == True) & (answers == False)) / np.sum((answers == False)),
82 | 'f1': f1_score(answers, predicted)
83 | }
84 |
85 | class discovery_bench_hypothesis:
86 | def __init__(self, split="test", synthetic=False, num_samples=50, seed=1234, path = None):
87 |
88 | print("----------Loading Discovery Bench------------")
89 |
90 | if split != "test":
91 | raise NotImplementedError
92 |
93 | random.seed(seed)
94 | if path is None:
95 | raise ValueError
96 |
97 | root_path = path
98 | self.split = split
99 | self.synthetic = synthetic
100 | self.data_path = os.path.join(root_path, "discoverybench", "synthetic" if self.synthetic else "real", self.split)
101 |
102 | ground_truth_path = os.path.join(root_path, "answer_key/answer_key_synth.csv") if self.synthetic else os.path.join(root_path, "answer_key/answer_key_real_cleaned_1.csv")
103 | self.ground_truth = pd.read_csv(ground_truth_path)
104 |
105 | self.examples = []
106 | for task_dir in os.listdir(self.data_path):
107 | task_path = os.path.join(self.data_path, task_dir)
108 |
109 | for file in os.listdir(task_path):
110 | if file.endswith(".json"):
111 | file_path = os.path.join(task_path, file)
112 | with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
113 | metadata = json.load(f)
114 | metadata_id = int(file.split(".")[0].split("_")[1])
115 | data_loader = DiscoveryBenchDataLoader(task_path, metadata)
116 |
117 | # permute the data loader for negative example
118 | permuted_dataloader = DiscoveryBenchDataLoader(task_path, metadata)
119 | permuted_dataloader.permute_selected_columns()
120 |
121 |
122 | for query_list in metadata["queries"]:
123 | for query in query_list:
124 | try:
125 | # print(file_path)
126 | hypothesis = self.ground_truth.loc[
127 | (self.ground_truth['dataset'] == task_dir) &
128 | (self.ground_truth['metadataid'] == metadata_id) &
129 | (self.ground_truth['query_id'] == query["qid"]),
130 | 'gold_hypo'
131 | ].iloc[0]
132 |
133 | if 'non-trivially falsifiable' in self.ground_truth.columns and self.ground_truth.loc[
134 | (self.ground_truth['dataset'] == task_dir) &
135 | (self.ground_truth['metadataid'] == metadata_id) &
136 | (self.ground_truth['query_id'] == query["qid"]),
137 | 'non-trivially falsifiable'
138 | ].iloc[0] == 0:
139 | continue
140 |
141 | self.examples.append({
142 | "task": task_dir,
143 | "domain": metadata["domain"],
144 | "metadataid": metadata_id,
145 | "query_id": query["qid"],
146 | "prompt": hypothesis,
147 | # "metadata": metadata,
148 | "data_loader": data_loader,
149 | "answer": True,
150 | })
151 |
152 | self.examples.append({
153 | "task": task_dir,
154 | "domain": metadata["domain"],
155 | "metadataid": metadata_id,
156 | "query_id": query["qid"],
157 | "prompt": hypothesis,
158 | # "metadata": metadata,
159 | "data_loader": permuted_dataloader,
160 | "answer": False,
161 | })
162 | except Exception as e:
163 | # print(e)
164 | # print(traceback.format_exc())
165 | pass
166 |
167 |
168 | if num_samples < len(self.examples):
169 | random.shuffle(self.examples)
170 | self.examples = self.examples[:num_samples]
171 |
172 | self.num_samples = len(self.examples)
173 | print(f"Loaded {self.num_samples} hypotheses")
174 | print("--------------------------------------")
175 |
176 | def get_example(self, index = None):
177 | return self.examples[index]
178 |
179 | def get_iterator(self):
180 | for i in range(self.num_samples):
181 | yield self.get_example(i)
182 |
183 | def output_class(self):
184 | from langchain_core.pydantic_v1 import BaseModel, Field
185 | from typing import Optional
186 | class Output(BaseModel):
187 | """Whether or not this hypothesis is considered true/false."""
188 |
189 | hypothesis_test_result: Optional[bool] = Field(
190 | description="Whether or not this hypothesis is considered true/false."
191 | )
192 | return Output
193 |
194 | def evaluate(self, response, answers=None):
195 | ## expected [(res_stat, conclusion)] following the order of the query
196 | from sklearn.metrics import accuracy_score, average_precision_score, f1_score
197 | predicted = np.array([i[1] for i in response])
198 |         if answers is None:
199 | answers = np.array([exp["answer"] for exp in self.examples])
200 | else:
201 | answers = np.array(answers)
202 |
203 | res_stats = np.array([i[0] for i in response])
204 | return {
205 | 'accuracy': accuracy_score(answers, predicted),
206 | 'power': np.sum((predicted == True) & (answers == True)) / np.sum((answers == True)),
207 | 'false discovery rate': np.sum((predicted == True) & (answers == False)) / np.sum((answers == False)),
208 | 'f1': f1_score(answers, predicted)
209 |
210 | # 'auprc': average_precision_score(answers, res_stats),
211 | # 'stat_pos': res_stats[np.where(answers & res_stats)[0]].mean(),
212 | # 'stat_neg': res_stats[np.where(answers == False & res_stats)[0]].mean()
213 | }
--------------------------------------------------------------------------------
/popper/popper.py:
--------------------------------------------------------------------------------
1 | from popper.utils import ExperimentalDataLoader, CustomDataLoader, DiscoveryBenchDataLoader
2 | from popper.agent import SequentialFalsificationTest
3 | from typing import Optional, Dict, Any
4 | import os
5 | import requests
6 | import zipfile
7 | import urllib
8 | from tqdm import tqdm
9 | import tarfile
10 | import subprocess
11 | import shutil
12 |
13 | class Popper:
14 | """Wrapper class for hypothesis validation using sequential falsification testing."""
15 |
16 | def __init__(self, llm: str = "claude-3-5-sonnet-20240620", is_locally_served = False, server_port = None, api_key = "EMPTY", **kwargs):
17 | """Initialize Popper.
18 |
19 | Args:
20 | llm (str): Name of the LLM model to use
21 | **kwargs: Additional arguments to pass to SequentialFalsificationTest
22 | """
23 | self.llm = llm
24 | self.agent = None
25 | self.data_loader = None
26 | self.is_local = is_locally_served
27 | self.port = server_port
28 | self.api_key = api_key
29 | self.kwargs = kwargs
30 |
31 | def register_data(self, data_path: str, data_sampling: int = -1, loader_type: str = 'bio', metadata: Optional[Dict] = None):
32 | """Register data for hypothesis testing.
33 |
34 | Args:
35 | data_path (str): Path to data directory
36 | data_sampling (int): Number of datasets to sample (-1 for all)
37 | loader_type (str): Type of data loader to use ('bio', 'custom', or 'discovery_bench')
38 | metadata (Optional[Dict]): Metadata required for DiscoveryBenchDataLoader
39 | """
40 | if not os.path.exists(data_path):
41 | os.makedirs(data_path)
42 |
43 | self.data_path = data_path
44 | if not os.path.exists(os.path.join(data_path, 'bio_database')):
45 | print('It will take a few minutes to download the data for the first time...')
46 | self.download_all_data()
47 | else:
48 | print('Data already exists, loading...')
49 |
50 | if loader_type == 'bio':
51 | self.data_loader = ExperimentalDataLoader(
52 | data_path=data_path,
53 | table_dict_selection='all_bio',
54 | data_sampling=data_sampling
55 | )
56 | elif loader_type == 'bio_selected':
57 | self.data_loader = ExperimentalDataLoader(
58 | data_path=data_path,
59 | table_dict_selection='default',
60 | data_sampling=data_sampling
61 | )
62 | elif loader_type == 'custom':
63 | self.data_loader = CustomDataLoader(data_folder=data_path)
64 | elif loader_type == 'discovery_bench':
65 | if metadata is None:
66 | raise ValueError("Metadata must be provided for DiscoveryBenchDataLoader")
67 | self.data_loader = DiscoveryBenchDataLoader(data_path=data_path, metadata=metadata)
68 | else:
69 | raise ValueError(f"Unknown loader_type: {loader_type}")
70 |
71 |
72 | def configure(self,
73 | alpha: float = 0.1,
74 | aggregate_test: str = 'E-value',
75 | max_num_of_tests: int = 5,
76 | max_retry: int = 5,
77 | time_limit: int = 2,
78 | relevance_checker: bool = True,
79 | use_react_agent: bool = True):
80 | """Configure the sequential falsification test parameters.
81 |
82 | Args:
83 | alpha (float): Significance level
84 | aggregate_test (str): Test aggregation method
85 | max_num_of_tests (int): Maximum number of tests to run
86 | max_retry (int): Maximum number of retries for failed tests
87 | time_limit (int): Time limit in hours
88 | relevance_checker (bool): Whether to use relevance checker
89 | use_react_agent (bool): Whether to use ReAct agent
90 | """
91 | if self.data_loader is None:
92 | raise ValueError("Please register data first using register_data()")
93 |
94 | self.agent = SequentialFalsificationTest(llm=self.llm, is_local=self.is_local, port=self.port, api_key=self.api_key)
95 | self.agent.configure(
96 | data=self.data_loader,
97 | alpha=alpha,
98 | aggregate_test=aggregate_test,
99 | max_num_of_tests=max_num_of_tests,
100 | max_retry=max_retry,
101 | time_limit=time_limit,
102 | relevance_checker=relevance_checker,
103 | use_react_agent=use_react_agent,
104 | **self.kwargs
105 | )
106 |
107 | def validate(self, hypothesis: str) -> Dict[str, Any]:
108 | """Validate a scientific hypothesis using sequential falsification testing.
109 |
110 | Args:
111 | hypothesis (str): The scientific hypothesis to test
112 |
113 | Returns:
114 | Dict containing the test results including logs, final message, and parsed results
115 | """
116 | if self.agent is None:
117 | raise ValueError("Please configure the agent first using configure()")
118 |
119 | log, last_message, parsed_result = self.agent.go(hypothesis)
120 |
121 | return {
122 | "log": log,
123 | "last_message": last_message,
124 | "parsed_result": parsed_result
125 | }
126 |
127 | def _setup_default_agent(self):
128 | """Set up agent with default configuration if not already configured."""
129 | self.configure(
130 | alpha=0.1,
131 | aggregate_test='E-value',
132 | max_num_of_tests=5,
133 | max_retry=5,
134 | time_limit=2,
135 | relevance_checker=True,
136 | use_react_agent=True
137 | )
138 |
139 | def launch_UI(self):
140 | import gradio as gr
141 | from gradio import ChatMessage
142 | from time import time
143 | import asyncio
144 | import copy
145 |
146 | async def generate_response(prompt,
147 | designer_history=[],
148 | executor_history=[],
149 | relevance_checker_history=[],
150 | error_control_history=[],
151 | summarizer_history=[]):
152 |
153 | #designer_history.append(ChatMessage(role="user", content=prompt))
154 | #yield designer_history, executor_history, relevance_checker_history, error_control_history, summarizer_history
155 |
156 | # Initialize log tracking
157 | prev_log = copy.deepcopy(self.agent.log) # Store initial log state
158 |
159 | # Run the agent asynchronously
160 | task = asyncio.create_task(asyncio.to_thread(self.agent.go, prompt))
161 |
162 | while not task.done(): # Check while the agent is still running
163 | #print("Checking for new log messages...")
164 | await asyncio.sleep(1) # Wait for 1 second
165 |
166 | # Check if log has changed
167 | if self.agent.log != prev_log:
168 | prev_log = copy.deepcopy(self.agent.log) # Update previous log state
169 |
170 | # Convert new log messages to ChatMessage format
171 | designer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['designer']]
172 | executor_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['executor']]
173 | relevance_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['relevance_checker']]
174 | sequential_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['sequential_testing']]
175 | summarizer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['summarizer']]
176 |
177 | yield designer_msgs, executor_msgs, relevance_msgs, sequential_msgs, summarizer_msgs
178 |
179 | # Ensure final result is captured
180 | result = await task
181 |
182 | # Convert final logs to ChatMessage format before yielding
183 | designer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['designer']]
184 | executor_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['executor']]
185 | relevance_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['relevance_checker']]
186 | sequential_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['sequential_testing']]
187 | summarizer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['summarizer']]
188 |
189 | yield designer_msgs, executor_msgs, relevance_msgs, sequential_msgs, summarizer_msgs
190 |
191 |
192 | def like(evt: gr.LikeData):
193 | print("User liked the response")
194 | print(evt.index, evt.liked, evt.value)
195 |
196 | with gr.Blocks() as demo:
197 | with gr.Row():
198 | with gr.Column(scale=1):
199 | designer_chatbot = gr.Chatbot(label="Popper Experiment Designer",
200 | type="messages", height=600,
201 | show_copy_button=True,
202 | show_share_button = True,
203 | group_consecutive_messages = False,
204 | show_copy_all_button = True,
205 | )
206 | relevance_checker_chatbot = gr.Chatbot(label="Relevance Checker",
207 | type="messages", height=300,
208 | show_copy_button=True,
209 | show_share_button = True,
210 | group_consecutive_messages = False,
211 | show_copy_all_button = True
212 | )
213 |
214 | with gr.Column(scale=1):
215 | executor_chatbot = gr.Chatbot(label="Popper Experiment Executor",
216 | type="messages", height=600,
217 | show_copy_button=True,
218 | show_share_button = True,
219 | group_consecutive_messages = False,
220 | show_copy_all_button = True,
221 | )
222 | error_control_chatbot = gr.Chatbot(label="Sequential Error Control",
223 | type="messages", height=300,
224 | show_copy_button=True,
225 | show_share_button = True,
226 | group_consecutive_messages = False,
227 | show_copy_all_button = True
228 | )
229 |
230 | with gr.Row():
231 | summarizer = gr.Chatbot(label="Popper Summarizer",
232 | type="messages", height=300,
233 | show_copy_button=True,
234 | show_share_button = True,
235 | group_consecutive_messages = False,
236 | show_copy_all_button = True,
237 | )
238 | with gr.Row():
239 | # Textbox on the left, and Button with an icon on the right
240 | prompt_input = gr.Textbox(show_label = False, placeholder="What is your hypothesis?", scale=8)
241 | button = gr.Button("Validate")
242 |
243 | button.click(lambda: gr.update(value=""), inputs=None, outputs=prompt_input)
244 | # Bind button click to generate_response function, feeding results to both chatbots
245 | button.click(generate_response, inputs=[prompt_input, designer_chatbot, executor_chatbot, relevance_checker_chatbot, error_control_chatbot, summarizer], outputs=[designer_chatbot, executor_chatbot, relevance_checker_chatbot, error_control_chatbot, summarizer])
246 |
247 | demo.launch(share = True)
248 |
249 |
250 | def download_all_data(self):
251 | url = "https://dataverse.harvard.edu/api/access/datafile/10888484"
252 | file_name = 'popper_data_processed'
253 | self._download_and_extract_data(url, file_name)
254 |
255 | def _download_and_extract_data(self, url, file_name):
256 | """Download, extract, and merge directories using rsync."""
257 | tar_file_path = os.path.join(self.data_path, f"{file_name}.tar.gz")
258 |
259 | if not os.path.exists(tar_file_path):
260 | # Download the file
261 | print(f"Downloading {file_name}.tar.gz...")
262 | self._download_with_progress(url, tar_file_path)
263 | print("Download complete.")
264 |
265 | # Extract the tar.gz file
266 | print("Extracting files...")
267 | with tarfile.open(tar_file_path, 'r:gz') as tar:
268 | for member in tqdm(tar.getmembers(), desc="Extracting: "):
269 | member.name = member.name.split('popper_data_processed/')[-1] # Strip directory structure
270 | tar.extract(member, self.data_path)
271 | print("Extraction complete.")
272 |
273 | def _download_with_progress(self, url, file_path):
274 | """Download a file with a progress bar."""
275 | request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
276 | response = urllib.request.urlopen(request)
277 | total_size = int(response.getheader('Content-Length').strip())
278 | block_size = 1024 # 1 KB
279 |
280 | with open(file_path, 'wb') as file, tqdm(
281 | total=total_size, unit='B', unit_scale=True, desc="Downloading"
282 | ) as pbar:
283 | while True:
284 | buffer = response.read(block_size)
285 | if not buffer:
286 | break
287 | file.write(buffer)
288 | pbar.update(len(buffer))
--------------------------------------------------------------------------------
/popper/react_utils.py:
--------------------------------------------------------------------------------
1 | # Set up the base template
2 | from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser, create_react_agent
3 | from langchain.prompts import StringPromptTemplate
4 | from langchain.tools import BaseTool
5 | from langchain.chains.llm import LLMChain
6 | from langchain_experimental.tools.python.tool import PythonAstREPLTool
7 | from langchain.schema import AgentAction, AgentFinish
8 | from pydantic import Field, PrivateAttr
9 | from typing import List, Union, Dict
10 | import contextlib
11 | import io
12 | import logging
13 | import re
14 |
15 | logging.basicConfig(level=logging.INFO)
16 |
17 | template = """{system_prompt}
18 |
19 | You have access to the following tools:
20 | {tools}
21 |
22 | Use the following format:
23 |
24 | Falsification Test: description of a hypothesis falsification test that you need to implement
25 | Datasets: the names and descriptions of datasets relevant to the input falsification test
26 | Thought: you should always think about what to do
27 | Action: the action to take, should be one of [{tool_names}]
28 | Action Input: the input to the action
29 | Observation: the result of the action
30 | ... (this Thought/Action/Action Input/Observation can repeat N times)
31 | Thought: I now know the final answer
32 | Final Answer: the final output from the falsification test (i.e., whether you are able to reject the null hypothesis with statistical significance). Make sure to also include the p-value of the statistical test written in scientific notations.
33 |
34 | IMPORTANT: Please make sure the Final Answer includes the p-value of the falsification test regardless if you are able to reject the null hypothesis. **Only return the Final Answer if you have obtained a non-zero p-value**. When printing p-values, please use scientific notations instead of the raw number.
35 |
36 | IMPORTANT: Please avoid p-hacking! Be fair and rigorous.
37 |
38 | Note: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section.
39 |
40 | --------------------------------------------
41 | Example
42 | Falsification Test:
43 | {{
44 | "Falsification Test name": "Body Length Evolution and Speciation Rate Relationship Test",
45 | "Falsification Test description": "Testing for a significant positive relationship between maximum body length evolution rate and spatial variation in speciation rates.",
46 | "Falsification Test Null hypothesis": "There is no statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates.",
47 | "Falsification Test Alternate hypothesis": "There is a statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates."
48 | }}
49 | Datasets:
50 | {{
51 | "name": "df_body-size-evolution-in-south-american-freshwater-fishes",
52 | "description": "Data on body size evolution in South American freshwater fishes, including speciation and extinction rates",
53 | "columns": {{
54 | "raw": [
55 | {{
56 | "name": "HYBAS_ID",
57 | "description": "Unique identifier for each hydrological basin"
58 | }},
59 | {{
60 | "name": "long",
61 | "description": "Longitude of the basin location"
62 | }},
63 | {{
64 | "name": "lat",
65 | "description": "Latitude of the basin location"
66 | }},
67 | {{
68 | "name": "BAMM_speciation",
69 | "description": "Rate of speciation as calculated by the BAMM method"
70 | }},
71 | {{
72 | "name": "BAMM_extinction",
73 | "description": "Rate of extinction as calculated by the BAMM method"
74 | }},
75 | {{
76 | "name": "BAMM_NetDiv",
77 | "description": "Net diversification rate, calculated as speciation minus extinction"
78 | }},
79 | {{
80 | "name": "aet",
81 | "description": "Mean annual evapotranspiration for each basin"
82 | }},
83 | {{
84 | "name": "Elevation",
85 | "description": "Average elevation of the basin"
86 | }},
87 | {{
88 | "name": "sgr",
89 | "description": "Species growth rate in each basin"
90 | }},
91 | {{
92 | "name": "soil_div",
93 | "description": "Soil diversity index for each basin"
94 | }},
95 | {{
96 | "name": "area",
97 | "description": "Total area of the basin in square kilometers"
98 | }},
99 | {{
100 | "name": "diversity",
101 | "description": "Diversity index for the species in each basin"
102 | }}
103 | ]
104 | }}
105 | }}
106 | Thought: First, I need to load the dataset from the global namespace in Python and inspect the data to identify the relevant columns for this hypothesis test.
107 | Action: python_repl_ast
108 | Action Input: import pandas as pd\n\ndf = df_body-size-evolution-in-south-american-freshwater-fishes\ndf.head()
109 | Observation:
110 | HYBAS_ID long lat BAMM_speciation BAMM_extinction BAMM_NetDiv ... aet Elevation sgr soil_div area diversity
111 | 0 6050000010 -76.477422 7.742693 0.137392 0.026807 0.110585 ... 1387 330.150088 166 0.482402 72363.7 68
112 | 1 6050000740 -74.628725 9.803586 0.117235 0.025796 0.091438 ... 1082 69.475294 23 0.457436 17944.3 35
113 | 2 6050068100 -75.295995 8.448815 0.119381 0.023826 0.095555 ... 1312 143.032178 74 0.378793 17105.5 44
114 | 3 6050068110 -74.608408 8.922863 0.132477 0.027777 0.104700 ... 1445 14.724138 3 0.468328 610.1 48
115 | 4 6050070260 -75.591588 5.770093 0.120127 0.022940 0.097187 ... 1371 1378.729945 421 0.158870 61901.9 81
116 | [5 rows x 21 columns]
117 | Thought: Now that the dataset is loaded and I can see the columns, I need to perform a statistical test to assess the significance of the relationship between 'BAMM_speciation' and 'BAMM_NetDiv'.
118 | Action: python_repl_ast
119 | Action Input: from scipy.stats import linregress\n\n# Perform linear regression to test for a statistically significant relationship\nresult = linregress(df['BAMM_speciation'], df['BAMM_NetDiv'])\ncoefficient = result.slope\np_value = result.pvalue\ncoefficient, "{{:.2e}}".format(p_value)
120 | Observation: (0.5175306498596297, 3.50e-03)
121 | Thought: The linear regression analysis provides a coefficient of approximately 0.518, indicating a positive relationship, and the p-value is 3.50e-03, which is statistically significant at the 0.05 level. Based on this, I can reject the null hypothesis in the falsification test.
122 | Final Answer: Falsification test passes. The null hypothesis is rejected with a p-value of 3.50e-03.
123 | --------------------------------------------
124 |
125 | Remember, your output should always **exactly** follow the aforementioned format:
126 | Falsification Test: description of a hypothesis falsification test that you need to implement
127 | Datasets: the names and descriptions of datasets relevant to the input falsification test
128 | Thought: you should always think about what to do
129 | Action: the action to take, should be one of [{tool_names}]
130 | Action Input: the input to the action
131 | Observation: the result of the action
132 | ... (this Thought/Action/Action Input/Observation can repeat N times)
133 | Thought: I now know the final answer
134 | Final Answer: the final output from the falsification test (i.e., whether you are able to reject the null hypothesis with statistical significance). Make sure to also include the p-value of the statistical test written in scientific notations.
135 |
136 | **IMPORTANT**
137 | You should ALWAYS report a p-value EXACTLY AS IT IS. If a p-value is 4.2e-01, report 4.2e-01, DO NOT REPORT 4.2e-02!
138 | BE CAREFUL WHEN READING THE P-VALUE RESULTS, MISREPORTING A P-VALUE IS WORSE THAN HAVING NO P-VALUE AT ALL.
139 | When reading p-values, make sure the sample sizes and the statistical test is valid.
140 | Please make sure to always return ONE valid p-value. If there are multiple p-values produced by the test, aggregate them in a meaningful and rigorous way.
141 | ** Always make sure the returned p-value matches your conclusion for the falsification test. For example, if you reject H0 but finds out that H1 is also incorrect (e.g., the suggested shape or relationship is wrong), you SHOULD NOT return a p-value < 0.05.
142 | If you think it's impossible to find a valid p-value for the falsification test, return a p-value of 1.00e+00.
143 | DO NOT perform p-hacking.
144 |
145 | Begin!
146 |
147 | {input} {agent_scratchpad}"""
148 |
149 |
150 | def load_data_to_react_globals(data_loader):
151 | for name, df in data_loader.table_dict.items():
152 | globals()[name] = df
153 |
154 |
155 | # Set up a prompt template
156 | class CustomPromptTemplate(StringPromptTemplate):
157 | # The template to use
158 | template: str
159 | # The list of tools available
160 | tools: List[BaseTool]
161 |
162 | def format(self, **kwargs) -> str:
163 | # Get the intermediate steps (AgentAction, Observation tuples)
164 | # Format them in a particular way
165 | intermediate_steps = kwargs.pop("intermediate_steps")
166 | thoughts = ""
167 | for action, observation in intermediate_steps:
168 | thoughts += action.log
169 | thoughts += f"\nObservation: {observation}\nThought: "
170 | # Set the agent_scratchpad variable to that value
171 | kwargs["agent_scratchpad"] = thoughts
172 | # Create a tools variable from the list of tools provided
173 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
174 | # Create a list of tool names for the tools provided
175 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
176 | prompt = self.template.format(**kwargs)
177 | # print([prompt])
178 | return prompt
179 |
180 | # CustomOutputParser to parse the output of the LLM and execute actions
181 | class CustomOutputParser(AgentOutputParser):
182 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
183 | # Check if agent should finish
184 | if "Final Answer:" in llm_output:
185 | return AgentFinish(
186 | return_values={"output": llm_output},
187 | log=llm_output,
188 | )
189 | # Parse out the action and action input
190 | regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
191 | match = re.search(regex, llm_output, re.DOTALL)
192 | if not match:
193 | # raise ValueError(f"Could not parse LLM output: `{llm_output}`")
194 | print(f"Warning: could not parse LLM output: `{llm_output}`, finishing chain...")
195 | return AgentFinish(
196 | return_values={"output": llm_output},
197 | log=llm_output,
198 | )
199 | action = match.group(1).strip()
200 | action_input = match.group(2)
201 | # Return the action and action input
202 | return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
203 |
204 |
205 | class CustomPythonAstREPLTool(PythonAstREPLTool):
206 | _exec_globals:Dict = PrivateAttr()
207 | def __init__(self, *args, **kwargs):
208 | super().__init__(*args, **kwargs)
209 | # Initialize a persistent global namespace for code execution
210 | self._exec_globals = {}
211 | self._exec_globals.update(__builtins__)
212 |
213 | def _set_globals(self, table_dict=None):
214 | self._exec_globals = {}
215 | self._exec_globals.update(__builtins__)
216 |
217 | if table_dict:
218 | self._exec_globals.update(table_dict)
219 |
220 | def _run(self, query: str, run_manager=None):
221 | code_match = re.search(r"```(.*?)```", query, re.DOTALL)
222 | if code_match:
223 | # Extract code within backticks
224 | code = code_match.group(1)
225 | else:
226 | code = query
227 | code = code.strip()
228 | if code.startswith("python"):
229 | code = code[len("python"):].lstrip()
230 |
231 | if code.endswith("Observation"):
232 | code = code[:-len("Observation")].rstrip()
233 |
234 | code_lines = code.strip().split('\n')
235 | code = '\n'.join(code_lines[:-1]) # avoid printing the last line twice
236 | last_line = code_lines[-1]
237 |
238 | output_capture = io.StringIO()
239 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture):
240 | logging.getLogger().handlers[0].stream = output_capture
241 | try:
242 | exec(code, self._exec_globals)
243 | try:
244 | result = eval(last_line, self._exec_globals)
245 | if result is not None:
246 | print(result, file=output_capture)
247 | except:
248 | pass
249 | except Exception as e:
250 | return str(e)
251 |
252 | # Retrieve the output and return it
253 | output = output_capture.getvalue()
254 | return output if output else "Execution completed without output."
255 |
256 |
257 | def create_agent(
258 | llm,
259 | handlers,
260 | max_iterations = 50,
261 | early_stopping_method: str = "force",
262 | ):
263 | output_parser = CustomOutputParser()
264 | python_tool = CustomPythonAstREPLTool(callbacks=handlers)
265 | tools = [python_tool]
266 | tool_names = [tool.name for tool in tools]
267 |
268 | prompt = CustomPromptTemplate(
269 | template=template,
270 | tools=tools,
271 | input_variables=["system_prompt", "input", "intermediate_steps", "tool_names", "tools", "agent_scratchpad"]
272 | )
273 | # llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=handlers)
274 |
275 | # agent = LLMSingleActionAgent(
276 | # llm_chain=llm_chain,
277 | # output_parser=output_parser,
278 | # stop=["\nObservation:"],
279 | # allowed_tools=tool_names
280 | # )
281 | agent = create_react_agent(
282 | llm=llm,
283 | tools=tools,
284 | prompt=prompt,
285 | output_parser=output_parser,
286 | stop_sequence=["\nObservation:"],
287 | )
288 |
289 | return AgentExecutor.from_agent_and_tools(
290 | agent=agent,
291 | tools=tools,
292 | verbose=True,
293 | max_iterations=max_iterations,
294 | callbacks=handlers,
295 | early_stopping_method=early_stopping_method
296 | )
--------------------------------------------------------------------------------
/popper/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | import json
5 | import time
6 | import openai
7 | from glob import glob
8 | import numpy as np
9 | import pandas as pd
10 |
11 | from popper.llm.custom_model import CustomChatModel
12 | from langchain_anthropic import ChatAnthropic
13 | from langchain_openai import ChatOpenAI
14 | from langchain_core.messages.base import get_msg_title_repr
15 | from langchain_core.utils.interactive_env import is_interactive_env
16 |
17 | def get_llm(model = 'claude-3-5-sonnet-20240620', temperature=0.7, port=30000, api_key = "EMPTY", **kwargs):
18 | source = "Local"
19 | if model[:7] == 'claude-':
20 | source = 'Anthropic'
21 | elif model[:4] == 'gpt-' or model.startswith("o1"):
22 | source = 'OpenAI'
23 | # elif model.startswith('llama'):
24 | # source = "Llama"
25 | # if source not in ['OpenAI', 'Anthropic']:
26 | # raise ValueError('Invalid source')
27 | if source == 'OpenAI':
28 | if model.startswith("o1"):
29 | return ChatOpenAI(model = model, temperature = -1, **kwargs)
30 | return ChatOpenAI(model = model, temperature = temperature, **kwargs)
31 | elif source == 'Anthropic':
32 | return ChatAnthropic(model = model,
33 | temperature = temperature,
34 | max_tokens = 4096,
35 | **kwargs)
36 | else:
37 | # assuming a locally-served model
38 | assert port is not None, f"Model {model} is not supported, please provide a local port if it is a locally-served model."
39 | llm = CustomChatModel(model = model, model_type=source, temperature = temperature)
40 | llm.client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key=api_key).chat.completions
41 | return llm
42 |
43 | class ExperimentalDataLoader:
44 | def __init__(self, data_path, table_dict_selection='default', data_sampling=-1):
45 | self.data_path = os.path.join(data_path, 'bio_database')
46 | self.available_datasets = [
47 | "gtex_tissue_gene_tpm",
48 | "gwas_catalog",
49 | "gene_info",
50 | "genetic_interaction",
51 | "genebass_synonymous_filtered",
52 | "genebass_missense_LC_filtered",
53 | "genebass_pLoF_filtered",
54 | "affinity_capture_ms",
55 | "two_hybrid",
56 | "synthetic_growth_defect",
57 | "affinity_capture_rna",
58 | "co_fractionation",
59 | "synthetic_lethality",
60 | "dosage_growth_defect",
61 | "proximity_label_ms",
62 | "synthetic_rescue",
63 | "reconstituted_complex",
64 | "eqtl_ukbb",
65 | "pqtl_ukbb",
66 | "sqtl_ukbb",
67 | "variant_table",
68 | "trait"
69 | ]
70 |
71 | self.permute_columns = {
72 | 'gtex_tissue_gene_tpm': ['Gene'],
73 | 'gwas_catalog': ['REPORTED GENE(S)', 'MAPPED_GENE', 'UPSTREAM_GENE_ID', 'DOWNSTREAM_GENE_ID', 'SNP_GENE_IDS'],
74 | }
75 |
76 | # Load datasets based on user input (default or all_bio)
77 | if table_dict_selection == 'default':
78 | self.datasets_to_load = [
79 | "gtex_tissue_gene_tpm",
80 | "gwas_catalog",
81 | "gene_info"
82 | ]
83 | elif table_dict_selection == 'all_bio':
84 | self.datasets_to_load = self.available_datasets # Load all datasets
85 |
86 | if data_sampling != -1:
87 |             all_datasets = list(self.available_datasets)  # copy so the master list is not shuffled in place
88 | np.random.seed(42)
89 | np.random.shuffle(all_datasets)
90 | self.datasets_to_load = all_datasets[:data_sampling]
91 | print(f"Sampled datasets: {self.datasets_to_load}")
92 |
93 | # Load the selected datasets into the table_dict
94 | self.table_dict_selection = table_dict_selection
95 | self.table_dict = self._load_selected_datasets()
96 | self.data_desc = self._generate_data_description()
97 |
98 | def _load_selected_datasets(self):
99 | """Loads only the selected datasets and returns a dictionary."""
100 | table_dict = {}
101 | for dataset in self.datasets_to_load:
102 | df_name = f"df_{dataset}"
103 | table_dict[df_name] = self._load_data(f"{dataset}.pkl")
104 | return table_dict
105 |
106 | def _load_data(self, file_name):
107 | """Helper method to load data from a pickle file."""
108 | try:
109 | return pd.read_pickle(os.path.join(self.data_path, file_name))
110 | except FileNotFoundError:
111 | print(f"File {file_name} not found in path {self.data_path}")
112 | return None
113 |
114 | def _generate_data_description(self):
115 | """Generates a description of each dataset's columns and the first row of data."""
116 | desc = ""
117 | for name, df in self.table_dict.items():
118 | if df is not None:
119 | desc += f"{name}:\n{dict(zip(df.columns.values, df.iloc[0].values))}\n\n"
120 | return desc
121 |
122 | def get_data(self, table_name):
123 | """Returns the requested DataFrame."""
124 | return self.table_dict.get(table_name, None)
125 |
126 | def load_into_globals(self):
127 | """Loads each dataset into the global namespace."""
128 | for name, df in self.table_dict.items():
129 | if df is not None:
130 | globals()[name] = df
131 |
132 | def display_data_description(self):
133 | """Prints the data description."""
134 | print(self.data_desc)
135 |
136 |     def permute_selected_columns(self, random_seed = 42):
137 |         """Permutes the specified columns together for each dataset in permute_columns."""
138 | 
139 |         if self.table_dict_selection == 'default':
140 |             self.random_seed = random_seed
141 | for dataset_name, columns_to_permute in self.permute_columns.items():
142 | df_name = f"df_{dataset_name}"
143 | df = self.table_dict.get(df_name, None)
144 |
145 | if df is not None and all(col in df.columns for col in columns_to_permute):
146 | # Set the random seed for reproducibility
147 | if self.random_seed is not None:
148 | np.random.seed(self.random_seed)
149 |
150 | # Permute rows of the selected columns together
151 | # Shuffle the DataFrame rows and then reassign the columns
152 | permuted_df = df[columns_to_permute].sample(frac=1, random_state=self.random_seed).reset_index(drop=True)
153 | df[columns_to_permute] = permuted_df
154 |
155 | # Update the table_dict with the permuted DataFrame
156 | self.table_dict[df_name] = df
157 | print(f"Permuted columns {columns_to_permute} in dataset {df_name}")
158 | else:
159 | print(f"Columns {columns_to_permute} not found in dataset {df_name} or dataset does not exist.")
160 | elif self.table_dict_selection == 'all_bio':
161 | for df_name in self.table_dict:
162 | df = self.table_dict[df_name]
163 | permuted_df = df.copy()
164 | for col in df.columns:
165 | permuted_df[col] = df[col].sample(frac=1, random_state=random_seed).reset_index(drop=True)
166 | random_seed = random_seed * 2 % (2 ** 31)
167 | self.table_dict[df_name] = permuted_df
168 |
169 |
170 | def load_file_dynamic(filepath):
171 | # Read the first few bytes to infer the delimiter
172 | with open(filepath, 'r') as file:
173 | first_line = file.readline()
174 |
175 | # Check for delimiter
176 | if '\t' in first_line:
177 | delimiter = '\t'
178 | elif ',' in first_line:
179 | delimiter = ','
180 | else:
181 | raise ValueError("Unknown delimiter. File is neither CSV nor TSV.")
182 |
183 | # Load the DataFrame using the detected delimiter
184 | df = pd.read_csv(filepath, delimiter=delimiter)
185 | return df
186 |
187 |
188 | class DiscoveryBenchDataLoader:
189 | def __init__(self, data_path, metadata):
190 | self.data_path = data_path
191 | self.metadata = metadata # dictionary/json object
192 | self.available_datasets = metadata["datasets"]
193 | self.table_dict = self._load_datasets()
194 | self.data_desc = self._generate_data_description()
195 |
196 | def _load_datasets(self):
197 | table_dict = {}
198 | for entry in self.available_datasets:
199 | table_path = os.path.join(self.data_path, entry["name"])
200 | table_dict[f'df_{entry["name"].split(".")[0]}'] = load_file_dynamic(table_path)
201 | return table_dict
202 |
203 | def _generate_data_description(self):
204 | desc = {}
205 | for entry in self.available_datasets:
206 | # print(entry["name"])
207 | table_name = f'df_{entry["name"].split(".")[0]}'
208 | df = self.table_dict[table_name]
209 |
210 | columns = entry["columns"]["raw"]
211 | for column in columns:
212 | value = df[column['name']].iloc[0]
213 | if isinstance(value, np.generic):
214 | value = value.item()
215 | column["example_value"] = value
216 | desc_entry = {
217 | "description": entry["description"],
218 | "columns": json.dumps(columns),
219 | }
220 | desc[table_name] = desc_entry
221 |
222 | return json.dumps(desc, indent=4)
223 |
224 | def load_into_globals(self):
225 | """Loads each dataset into the global namespace."""
226 | for name, df in self.table_dict.items():
227 | if df is not None:
228 | globals()[name] = df
229 |
230 | def display_data_description(self):
231 | """Prints the data description."""
232 | print(self.data_desc)
233 |
234 | def permute_selected_columns(self, columns="all", random_seed = 42):
235 | # permute all columns
236 | np.random.seed(random_seed)
237 | for df_name in self.table_dict:
238 | df = self.table_dict[df_name]
239 | self.table_dict[df_name] = df.apply(np.random.permutation)
240 |
241 | # for df_name in self.table_dict:
242 | # df = self.table_dict[df_name]
243 | # permuted_df = df.copy()
244 | # for col in df.columns:
245 | # permuted_df[col] = df[col].sample(frac=1, random_state=random_seed).reset_index(drop=True)
246 | # random_seed = random_seed * 2 % (2 ** 31)
247 | # self.table_dict[df_name] = permuted_df
248 |
249 | def pretty_print(message, printout = True):
250 | if isinstance(message, tuple):
251 | title = message
252 | else:
253 | if isinstance(message.content, list):
254 | title = get_msg_title_repr(message.type.title().upper() + " Message", bold=is_interactive_env())
255 | if message.name is not None:
256 | title += f"\nName: {message.name}"
257 |
258 | for i in message.content:
259 | if i['type'] == 'text':
260 | title += f"\n{i['text']}\n"
261 | elif i['type'] == 'tool_use':
262 | title += f"\nTool: {i['name']}"
263 | title += f"\nInput: {i['input']}"
264 | if printout:
265 | print(f"{title}")
266 | else:
267 | title = get_msg_title_repr(message.type.title() + " Message", bold=is_interactive_env())
268 | if message.name is not None:
269 | title += f"\nName: {message.name}"
270 | title += f"\n\n{message.content}"
271 | if printout:
272 | print(f"{title}")
273 | return title
274 |
275 | class CustomDataLoader:
276 | def __init__(self, data_folder, random_seed=42):
277 | """Initialize data loader with path to data folder.
278 |
279 | Args:
280 | data_folder (str): Path to folder containing pickle files
281 | random_seed (int): Random seed for permutations
282 | """
283 | self.data_path = data_folder
284 | self.random_seed = random_seed
285 | self.table_dict = {}
286 | self._load_all_datasets()
287 | self.data_desc = self._generate_data_description()
288 |
289 | def _load_all_datasets(self):
290 | """Automatically loads all pickle and CSV files found in data_path."""
291 | pickle_files = glob(os.path.join(self.data_path, "*.pkl"))
292 | csv_files = glob(os.path.join(self.data_path, "*.csv"))
293 |
294 | if not pickle_files and not csv_files:
295 | raise ValueError(f"No pickle or CSV files found in {self.data_path}")
296 |
297 | for file_path in pickle_files + csv_files:
298 | file_name = os.path.basename(file_path)
299 | dataset_name = os.path.splitext(file_name)[0]
300 | df_name = f"df_{dataset_name}"
301 | if file_path.endswith('.pkl'):
302 | self.table_dict[df_name] = self._load_data(file_name)
303 | elif file_path.endswith('.csv'):
304 | self.table_dict[df_name] = pd.read_csv(file_path)
305 |
306 | def _load_data(self, file_name):
307 | """Helper method to load data from a pickle file."""
308 | try:
309 | df = pd.read_pickle(os.path.join(self.data_path, file_name))
310 | if not isinstance(df, pd.DataFrame):
311 | raise ValueError(f"File {file_name} does not contain a pandas DataFrame")
312 | return df
313 | except Exception as e:
314 | print(f"Error loading {file_name}: {str(e)}")
315 | return None
316 |
317 | def _generate_data_description(self):
318 | """Generates a description of each dataset's columns and the first row."""
319 | desc = ""
320 | for name, df in self.table_dict.items():
321 | if df is not None:
322 | desc += f"{name}:\nColumns: {df.columns.tolist()}\n"
323 | desc += f"Sample row: {dict(zip(df.columns, df.iloc[0]))}\n\n"
324 | return desc
325 |
326 | def get_data(self, table_name):
327 | """Returns the requested DataFrame."""
328 | return self.table_dict.get(table_name, None)
329 |
330 | def load_into_globals(self):
331 | """Loads each dataset into the global namespace."""
332 | for name, df in self.table_dict.items():
333 | if df is not None:
334 | globals()[name] = df
335 |
336 | def display_data_description(self):
337 | """Prints the data description."""
338 | print(self.data_desc)
339 |
340 | def permute_columns(self, dataset_name, columns_to_permute):
341 | """Permutes specified columns in a dataset.
342 |
343 | Args:
344 | dataset_name (str): Name of dataset (without 'df_' prefix)
345 | columns_to_permute (list): List of column names to permute
346 | """
347 | df_name = f"df_{dataset_name}"
348 | df = self.table_dict.get(df_name, None)
349 |
350 | if df is None:
351 | raise ValueError(f"Dataset {df_name} not found")
352 |
353 | if not all(col in df.columns for col in columns_to_permute):
354 | raise ValueError(f"Not all columns {columns_to_permute} found in {df_name}")
355 |
356 | np.random.seed(self.random_seed)
357 | permuted_df = df[columns_to_permute].sample(frac=1).reset_index(drop=True)
358 | df[columns_to_permute] = permuted_df
359 | self.table_dict[df_name] = df
--------------------------------------------------------------------------------
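A brief sketch of how the loaders and get_llm above are typically combined; the folder path, table name, column name, and model below are placeholders rather than repository defaults:

    # Hypothetical usage of CustomDataLoader and get_llm from popper/utils.py.
    from popper.utils import CustomDataLoader, get_llm

    loader = CustomDataLoader("/path/to/data_folder", random_seed=0)   # reads every *.pkl and *.csv in the folder
    loader.display_data_description()                 # prints columns plus one sample row per table
    df = loader.get_data("df_my_table")               # keys are "df_" + file stem
    loader.permute_columns("my_table", ["gene_id"])   # shuffle selected columns, e.g. for a null calibration

    llm = get_llm(model="gpt-4o", temperature=0.0)    # routed to ChatOpenAI by the "gpt-" prefix check;
                                                      # requires OPENAI_API_KEY in the environment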
/popper/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | CODING_AGENT_SYSTEM_PROMPT_APPROX = '''You are an expert statistician specialized in the field of {domain}. You are tasked with validating a {domain} hypothesis (H) by collecting evidence supporting both the alternative hypothesis (h1) and the null hypothesis (h0).
4 |
5 | You should write code to gather, process, and analyze the available data, collecting evidence favoring both h1 and h0.
6 | The goal is to structure the evidence in a way that allows for a thorough and interpretable comparison, enabling an LLM to estimate the likelihood under both h1 and h0.
7 |
8 | The code should:
9 | - Output data/evidence rather than test statistics.
10 | - Organize the evidence for h1 and h0 in a structured format, including metrics and qualitative descriptors.
11 | - Provide outputs that are interpretable, enabling easy comparison between the likelihoods of h1 and h0.
12 |
13 | You have access to the following pandas dataframe tables, where each table shows precise column names and an example row:
14 |
15 | {{context}}
16 |
17 | Write code based on the user’s request. Ensure any code provided is self-contained, executable, and includes all necessary imports and variable definitions.
18 |
19 | Structure your output as follows:
20 | 1) A brief summary of the approach,
21 | 2) The required imports,
22 | 3) The complete code.
23 |
24 | Include progress bars for lengthy processes where appropriate and optimize for time efficiency by using a small number of permutations (e.g., <1000) where relevant.
25 | Do not use placeholders. Each output should directly relate to the evidence under h1 and h0.
26 |
27 | The output should provide a comparison-ready format for h1 and h0. You should print the output in the code.
28 | --------------------------------------------
29 |
30 | Here is the user-requested falsification test specification:'''
31 |
32 |
33 | CODING_AGENT_SYSTEM_PROMPT = """You are an expert statistician specialized in the field of {domain}. You are tasked with rigorously validating whether a {domain} hypothesis H is true by implementing a falsification test proposed by the user.
34 |
35 | You should write code to implement the falsification test.
36 | The test should be relevant to the main hypothesis and aims to falsify it.
37 | The test should use the available data described below, applying data processing, extraction, and statistical analysis to produce a p-value that measures the falsification of the main hypothesis.
38 | The test should be extremely rigorous. The p-value should be theoretically grounded.
39 | The code should be clear, concise, and efficient. Add a progress bar when necessary. The code will run under a time limit, so be efficient; for example, keep the number of permutations small (e.g. <1000) where possible.
40 | The code should be self-contained and require no additional modifications from the user.
41 |
42 | You have access to the following pandas dataframe tables; for each table, the precise column names and a preview of column values are shown:
43 |
44 | {{context}}
45 |
46 | Each of these dataframes has already been loaded into the global namespace. You may access each dataframe **directly as a variable**. Make sure to use the **EXACT** dataframe names as shown above.
47 |
48 | Create code from the user request. Ensure any code you provide can be executed with all required imports and variables defined.
49 | Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block.
50 | Invoke the code tool to structure the output correctly.
51 | NEVER PRODUCE ANY PLACEHOLDER IN ANY FUNCTION. PLACEHOLDER IS WORSE THAN FAILURE TO PRODUCE CODE.
52 | PLACEHOLDER including coming up with placeholder genes, names, ids, functions, p-value, or any other placeholder.
53 | The output should be a single p-value. If there are multiple p-values produced by the test, you should aggregate them in a meaningful and rigorous way.
54 | When printing p-values, please use scientific notations (e.g. 3.50e-03) instead of the raw number.
55 | For querying biological IDs, write code to look directly at raw datasets to map the exact ID, avoiding the use of LLMs to generate or infer gene names or IDs. Additionally, if the dataset includes p-values in its columns, refrain from using them as direct outputs of the falsification test; instead, process or contextualize them appropriately to maintain analytical rigor.
56 | -------------------------------------------------------
57 |
58 | Here is the user-requested falsification test specification:"""
59 |
60 |
61 | def get_coding_agent_system_prompt(llm_approx, domain="biology"):
62 | if llm_approx:
63 | return CODING_AGENT_SYSTEM_PROMPT_APPROX.format(domain=domain)
64 | return CODING_AGENT_SYSTEM_PROMPT.format(domain=domain)
65 |
66 |
67 | REACT_CODING_AGENT_SYSTEM_PROMPT = """You are an expert statistician specialized in the field of {domain}. Given a Falsification Test, your task is to determine if you can reject the null hypothesis via rigorous data analysis and statistical testing.
68 |
69 | You have access to multiple datasets relevant to the hypothesis, as well as a python code execution environment to run your falsification test. The code execution environment has a persistent global namespace, meaning that state and variable names persist across multiple rounds of code execution. Be sure to take advantage of this by developing your falsification test incrementally and reflecting on the intermediate observations at each step, instead of coding up everything in one go. All datasets have already been loaded into the global namespace as pandas dataframes."""
70 |
71 | PROMPT_REVISION = """
72 | For querying biological IDs, write code to look directly at raw datasets to map the exact ID, avoiding the use of LLMs to generate or infer gene names or IDs. Additionally, if the dataset includes p-values in its columns, refrain from using them as direct outputs of the falsification test; instead, process or contextualize them appropriately to maintain analytical rigor.
73 | """
74 |
75 | def get_react_coding_agent_system_prompt(domain="biology", prompt_revision=False):
76 | if prompt_revision:
77 |         return REACT_CODING_AGENT_SYSTEM_PROMPT.format(domain=domain) + PROMPT_REVISION
78 | else:
79 | return REACT_CODING_AGENT_SYSTEM_PROMPT.format(domain=domain)
80 |
81 |
82 | LIKELIHOOD_ESTIMATION_AGENT_PROMPT = """Given a scientific hypothesis H, you have designed a sub-hypothesis test h to falsify the main hypothesis. You have also collected evidence from data for the null hypothesis (h0) and the alternative hypothesis (h1).
83 |
84 | Your goal is to:
85 | 1. Estimate the probability of this evidence under the alternative hypothesis, P(data|h1).
86 | 2. Estimate the probability of this evidence under the null hypothesis, P(data|h0).
87 |
88 | Follow this rigorous rubric to evaluate estimation precision, focusing on both theoretical grounding and accuracy in likelihood estimation:
89 |
90 | - **0.1**: Extremely poor estimate, lacks theoretical grounding; estimation is inconsistent with evidence and does not consider hypothesis structure.
91 | - **0.2**: Poor estimate; limited theoretical basis, fails to account for evidence specifics, and overlooks key elements of hypothesis testing.
92 | - **0.3**: Weak estimate, marginally considers evidence but lacks appropriate statistical measures or fails to apply probability theory accurately.
93 | - **0.4**: Below average; applies some basic probability theory but lacks rigor, poorly models the relationship between evidence and hypothesis.
94 | - **0.5**: Average estimate; applies probability theory minimally, captures some evidence but with limited specificity to the hypothesis context.
95 | - **0.6**: Above average; uses sound statistical principles, somewhat models the evidence-hypothesis relationship, but with notable gaps or simplifications.
96 | - **0.7**: Good estimate; well-grounded in theory, evidence is modeled with reasonable accuracy but lacks precision or depth in interpretation.
97 | - **0.8**: Very good estimate; rigorous application of probability theory, models evidence in the context of hypothesis well, with minor limitations in capturing uncertainty or alternative explanations.
98 | - **0.9**: Excellent estimate; highly accurate, theoretically sound, robustly interprets evidence under hypothesis, addressing key uncertainties and incorporating evidence nuances.
99 | - **1.0**: Perfect estimate; fully grounded in advanced probability theory, comprehensive and precise, accurately modeling all aspects of evidence given the hypothesis, leaving no uncertainties unaddressed.
100 |
101 | ---
102 | **Process**:
103 | - First, produce an initial estimate proposal.
104 | - In each round i, perform the following steps:
105 | 1. **Critique**: Evaluate the estimation’s reasonableness, theoretical rigor, and alignment with this rubric.
106 | 2. **Reflect**: Identify specific improvements to enhance accuracy and theoretical grounding based on critique.
107 | - If the estimation achieves a rigorous standard (e.g., reaching 0.9 or 1.0), return the final estimates:
108 | - P(data|h1) = [final value]
109 | - P(data|h0) = [final value]
110 | - If refinement is needed, improve or propose a new estimation, then proceed to the next round.
111 |
112 | ---
113 | **Information**:
114 | - Main Scientific Hypothesis H:
115 | {main_hypothesis}
116 |
117 | - Falsification Test Sub-Hypothesis h:
118 | {falsification_test}
119 |
120 | - Evidence:
121 | {data}
122 | """
123 |
124 | def get_likelihood_estimation_agent_prompt(main_hypothesis, falsification_test, data):
125 | return LIKELIHOOD_ESTIMATION_AGENT_PROMPT.format(main_hypothesis=main_hypothesis, falsification_test=falsification_test, data=data)
126 |
127 |
128 | TEST_PROPOSAL_AGENT_SYSTEM_PROMPT = """You are an expert statistician specialized in the field of {domain}."""
129 |
130 | TEST_PROPOSAL_AGENT_USER_PROMPT = '''
131 | Given a {domain} hypothesis "{main_hypothesis}", your goal is to propose a novel falsification test given the available {domain} data sources.
132 | A falsification test is a test that can potentially falsify the main hypothesis.
133 | The outcome of the falsification test is to return a p-value that measures the evidence to falsify the main hypothesis.
134 |
135 | Notably, the falsification test should satisfy the following property: if the main hypothesis is null, then the falsification sub-hypothesis should also be null.
136 |
137 | Here is the list of available data sources. You can use each dataframe directly, as it has already been loaded; there is no need to load from a file path. Each is a pandas dataframe with columns and example rows:
138 |
139 | {data}
140 |
141 | For the final test, return
142 | (1) Name: name of the test
143 | (2) Test description: be clear and concise. Describe the falsification outcomes.
144 | (3) Null sub-hypothesis h_0: what statistical null sub-hypothesis does this falsification test aim to test?
145 | (4) Alternative sub-hypothesis h_1: what statistical alternative sub-hypothesis does this falsification test aim to test?
146 |
147 | Here are the falsification tests that you've created in the previous rounds and their corresponding test results:
148 |
149 | """
150 | {existing_falsification_test}
151 | """
152 |
153 | You may use this information to formulate your next sub-hypothesis and falsification test, but make sure the proposed falsification test is not redundant with any of the existing tests.
154 |
155 | The proposed test should also avoid these failed falsification tests in the previous rounds:
156 |
157 | """
158 | {failed_falsification_test}
159 | """
160 |
161 | A good falsification test should serve as strong evidence for the main hypothesis. However, make sure it is answerable with the available data sources.
162 | You should aim to maximize the implication strength of the proposed falsification test using the relevant parts of the provided data.
163 |
164 | ----
165 | First produce an initial falsification test proposal.
166 |
167 | Then, in each round i, you will do the following:
168 | (1) critique: ask, if the main hypothesis is null, is this test also null? Be rigorous; this is critically important, otherwise the test is invalid. Is it redundant with existing tests? Does it overlap with failed tests? Can it be answered and implemented with the given data?
169 | (2) reflect: identify how to improve the test definition.
170 |
171 | If you think the test definition is good enough, return the final test definition to the user.
172 | If not, either refine the test definition so that it improves on the previous one or propose a new test definition, then go to the next round.
173 | '''
174 |
175 | def get_test_proposal_agent_system_prompt(domain):
176 | return TEST_PROPOSAL_AGENT_SYSTEM_PROMPT.format(domain=domain)
177 |
178 | def get_test_proposal_agent_user_prompt(domain, main_hypothesis, data, existing_tests, failed_tests):
179 | return TEST_PROPOSAL_AGENT_USER_PROMPT.format(domain = domain, main_hypothesis = main_hypothesis, data = data, existing_falsification_test = existing_tests, failed_falsification_test = failed_tests)
180 |
181 |
182 | SUMMARIZER_SYSTEM_PROMPT = """You are a helpful assistant trained to help scientists summarize their experiment observations.
183 | You have observed a sequential falsification test procedure of a scientific hypothesis and your goal is to accurately summarize and extract insights to present to a human scientist.
184 | For the observed list of falsification tests, each test includes the test description and its test results.
185 |
186 | The final output should state the following:
187 | (1) The main scientific hypothesis under study
188 | (2) The result of the sequential falsification test
189 | (3) Reasoning, summarizing, and analyzing these results
190 | (4) Your conclusion on whether or not this hypothesis is true or false; just return True/False
191 | (5) Rationale of the conclusion
192 |
193 | Remember, you MUST STRICTLY ADHERE to the experiment observations WITHOUT your personal bias or interpretations. For example, if the experiments fail to reject the null hypothesis, you MUST output the conclusion as False EVEN IF YOU BELIEVE THE STATEMENT IS TRUE.
194 | """
195 |
196 | def get_summarizer_system_prompt():
197 | return SUMMARIZER_SYSTEM_PROMPT
198 |
199 |
200 | RELEVANCE_PROMPT = """
201 | Given a main hypothesis and a proposed sub-hypothesis test, assess the relevance of this sub-hypothesis test to the main hypothesis.
202 | Use the following rubric to guide your response, providing a score from 0.1 to 1.0 and a brief justification for the score.
203 | Each score level represents a different degree of relevance based on evidence strength, mechanistic connection, and predictive value of the test results.
204 |
205 | Rubric:
206 |
207 | 1.0 - Highly Relevant: The sub-hypothesis provides direct evidence or a clear mechanistic insight that strongly supports or refutes the main hypothesis. The test is specific to variables or mechanisms involved in the main hypothesis, with significant predictive value.
208 | 0.8 - Strongly Relevant: The test addresses a major component of the main hypothesis, providing substantial supporting or refuting evidence, and shows strong mechanistic alignment. The results would significantly impact the confidence in the main hypothesis.
209 | 0.6 - Moderately Relevant: The test examines elements supporting the main hypothesis without direct mechanistic insight. Some aspects align with the main hypothesis, offering moderate predictive value.
210 | 0.4 - Slightly Relevant: The test is related to the main hypothesis but provides limited direct evidence. It explores loosely associated variables and has minimal predictive value.
211 | 0.2 - Barely Relevant: The test is tangentially related, providing minimal information that could impact the main hypothesis, with no clear mechanistic link and negligible predictive value.
212 | 0.1 - Irrelevant: The sub-hypothesis does not provide relevant evidence or mechanistic connection to the main hypothesis, with no predictive value.
213 |
214 | Instructions:
215 | 1. Read the main hypothesis and the sub-hypothesis test carefully.
216 | 2. Choose the relevance score from the rubric that best matches the relationship.
217 | 3. Explain your reasoning for selecting this score, referring to evidence strength, mechanistic connection, and predictive value of the sub-hypothesis test results.
218 | """
219 |
220 | def get_relevance_prompt():
221 | return RELEVANCE_PROMPT
222 |
223 |
224 |
225 |
226 | def bind_tools_to_system_prompt(system_prompt, tools):
227 | return f'''You are an intelligent agent capable of calling tools to complete user-assigned tasks.
228 | Here are the instructions specified by the user:
229 | """{system_prompt}"""
230 |
231 | In addition, you have access to the following tools:
232 | {json.dumps(tools, indent=4)}
233 |
234 | You may output any intermediate thoughts or reasoning before delivering your final response.
235 | Your final response must be either one or more tool calls or a single response message to the user.
236 |
237 | To make one or more tool calls, wrap your final response in the following JSON format:
238 | {{
239 | "type": "tool_calls",
240 | "content": [
241 | {{
242 | "name": "name of the function to call",
243 | "id": "an unique id for this tool call",
244 | "arguments": {{
245 | "argument1": value1,
246 | "argument2": value2,
247 | ...
248 | }}
249 | }},
250 | ...
251 | ]
252 | }}
253 |
254 | To send a direct response message to the user, wrap your final response in the following JSON format:
255 | {{
256 | "type": "text_message",
257 | "content": "content of the message according to the user instructions"
258 | }}
259 |
260 | You must choose either to send tool calls or to send a direct response message. Be sure to format the final response properly according to the given JSON specs.
261 |
262 | DO NOT put anything after the final response JSON object.'''
--------------------------------------------------------------------------------
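A short sketch of how the prompt builders above are intended to be called; the hypothesis text, data description, and tool schema below are invented placeholders:

    # Hypothetical call sites for the prompt builders in popper/prompt_utils.py.
    from popper.prompt_utils import (
        get_test_proposal_agent_system_prompt,
        get_test_proposal_agent_user_prompt,
        bind_tools_to_system_prompt,
    )

    system = get_test_proposal_agent_system_prompt(domain="biology")
    user = get_test_proposal_agent_user_prompt(
        domain="biology",
        main_hypothesis="Gene X regulates trait Y.",
        data="df_example: {'gene': 'X', 'trait': 'Y', 'effect': 0.3}",  # normally a data loader's data_desc string
        existing_tests="None so far.",
        failed_tests="None so far.",
    )

    # bind_tools_to_system_prompt expects OpenAI-style tool schemas as plain dicts:
    tools = [{
        "type": "function",
        "function": {
            "name": "execute_code",
            "parameters": {"type": "object", "properties": {"code": {"type": "string"}}},
        },
    }]
    prompt_with_tools = bind_tools_to_system_prompt(system, tools)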
/baseline_agents/react_utils.py:
--------------------------------------------------------------------------------
1 | # Set up the base template
2 | from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser
3 | from langchain.prompts import StringPromptTemplate
4 | from langchain.tools import BaseTool
5 | from langchain.chains.llm import LLMChain
6 | from langchain_experimental.tools.python.tool import PythonAstREPLTool
7 | from typing import List, Union, Dict
8 | from langchain.schema import AgentAction, AgentFinish
9 | from pydantic import Field, PrivateAttr
10 | import contextlib
11 | import io
12 | import logging
13 | import re
14 |
15 | logging.basicConfig(level=logging.INFO)
16 |
17 | template = """{system_prompt}
18 |
19 | You have access to the following tools:
20 | {tools}
21 |
22 | Use the following format:
23 |
24 | Question: an input hypothesis that you must decide if it is True or False
25 | Datasets: the names and descriptions of datasets relevant to the input hypothesis
26 | Thought: you should always think about what to do
27 | Action: the action to take, should be one of [{tool_names}]
28 | Action Input: the input to the action
29 | Observation: the result of the action
30 | ... (this Thought/Action/Action Input/Observation can repeat N times)
31 | Thought: I now know the final answer
32 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer
33 | Final Answer: True/False. Please output True if the input hypothesis is valid (e.g., you are able to reject the null hypothesis with statistical significance) and False if the input hypothesis is invalid (e.g., if you fail to reject the null hypothesis).
34 |
35 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as "WORKFLOW SUMMARY:"
36 |
37 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section.
38 |
39 |
40 | Example
41 | Question: Is the following hypothesis True or False? There is a statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates.
42 | Datasets:
43 | {{
44 | "name": "df_body-size-evolution-in-south-american-freshwater-fishes",
45 | "description": "Data on body size evolution in South American freshwater fishes, including speciation and extinction rates",
46 | "columns": {{
47 | "raw": [
48 | {{
49 | "name": "HYBAS_ID",
50 | "description": "Unique identifier for each hydrological basin"
51 | }},
52 | {{
53 | "name": "long",
54 | "description": "Longitude of the basin location"
55 | }},
56 | {{
57 | "name": "lat",
58 | "description": "Latitude of the basin location"
59 | }},
60 | {{
61 | "name": "BAMM_speciation",
62 | "description": "Rate of speciation as calculated by the BAMM method"
63 | }},
64 | {{
65 | "name": "BAMM_extinction",
66 | "description": "Rate of extinction as calculated by the BAMM method"
67 | }},
68 | {{
69 | "name": "BAMM_NetDiv",
70 | "description": "Net diversification rate, calculated as speciation minus extinction"
71 | }},
72 | {{
73 | "name": "aet",
74 | "description": "Mean annual evapotranspiration for each basin"
75 | }},
76 | {{
77 | "name": "Elevation",
78 | "description": "Average elevation of the basin"
79 | }},
80 | {{
81 | "name": "sgr",
82 | "description": "Species growth rate in each basin"
83 | }},
84 | {{
85 | "name": "soil_div",
86 | "description": "Soil diversity index for each basin"
87 | }},
88 | {{
89 | "name": "area",
90 | "description": "Total area of the basin in square kilometers"
91 | }},
92 | {{
93 | "name": "diversity",
94 | "description": "Diversity index for the species in each basin"
95 | }}
96 | ]
97 | }}
98 | }}
99 | Thought: First, I need to load the dataset from the global namespace in Python and inspect the data to identify the relevant columns for this hypothesis test.
100 | Action: python_repl_ast
101 | Action Input: import pandas as pd\n\ndf = df_body-size-evolution-in-south-american-freshwater-fishes\ndf.head()
102 | Observation:
103 | HYBAS_ID long lat BAMM_speciation BAMM_extinction BAMM_NetDiv ... aet Elevation sgr soil_div area diversity
104 | 0 6050000010 -76.477422 7.742693 0.137392 0.026807 0.110585 ... 1387 330.150088 166 0.482402 72363.7 68
105 | 1 6050000740 -74.628725 9.803586 0.117235 0.025796 0.091438 ... 1082 69.475294 23 0.457436 17944.3 35
106 | 2 6050068100 -75.295995 8.448815 0.119381 0.023826 0.095555 ... 1312 143.032178 74 0.378793 17105.5 44
107 | 3 6050068110 -74.608408 8.922863 0.132477 0.027777 0.104700 ... 1445 14.724138 3 0.468328 610.1 48
108 | 4 6050070260 -75.591588 5.770093 0.120127 0.022940 0.097187 ... 1371 1378.729945 421 0.158870 61901.9 81
109 | [5 rows x 21 columns]
110 | Thought: Now that the dataset is loaded and I can see the columns, I need to perform a statistical test to assess the significance of the relationship between 'BAMM_speciation' and 'BAMM_NetDiv'.
111 | Action: python_repl_ast
112 | Action Input: from scipy.stats import linregress\n\n# Perform linear regression to test for a statistically significant relationship\nresult = linregress(df['BAMM_speciation'], df['BAMM_NetDiv'])\ncoefficient = result.slope\np_value = result.pvalue\ncoefficient, p_value
113 | Observation: (0.5175306498596297, 0.0035)
114 | Thought: The linear regression analysis provides a coefficient of approximately 0.518, indicating a positive relationship, and the p-value is 0.0035, which is statistically significant at the 0.05 level. Based on this, I can conclude that the hypothesis is true.
115 | WORKFLOW SUMMARY:
116 | 1. Data Loading: Loaded the dataset from the global namespace using Python.
117 | 2. Data Inspection: Displayed the first few rows of the dataset to confirm relevant columns.
118 | 3. Statistical Analysis: Performed a linear regression analysis between 'BAMM_speciation' (predictor) and 'BAMM_NetDiv' (response). The analysis yielded a positive coefficient and a p-value of 0.0035, indicating statistical significance.
119 | Final Answer: True
120 |
121 |
122 | Begin!
123 |
124 | {input}
125 | {agent_scratchpad}"""
126 |
127 |
128 | template_v2 = """{system_prompt}
129 |
130 | You have access to the following tools:
131 | {tools}
132 |
133 | Use the following format:
134 |
135 | Question: an input hypothesis that you must decide if it is True or False
136 | Datasets: the names and descriptions of datasets relevant to the input hypothesis
137 | Thought: you should always think about what to do
138 | Action: the action to take, should be one of [{tool_names}]
139 | Action Input: the input to the action
140 | Observation: the result of the action
141 | ... (this Thought/Action/Action Input/Observation can repeat N times)
142 | Thought: I now know the final answer
143 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer
144 | Final Answer: True/False. Please output True if you believe the input hypothesis is correct and False if the input hypothesis is not based on your analysis.
145 |
146 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as "WORKFLOW SUMMARY:"
147 |
148 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section.
149 |
150 |
151 | Example
152 | Question: Is the following hypothesis True or False? There is a statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates.
153 | Datasets:
154 | {{
155 | "name": "df_body-size-evolution-in-south-american-freshwater-fishes",
156 | "description": "Data on body size evolution in South American freshwater fishes, including speciation and extinction rates",
157 | "columns": {{
158 | "raw": [
159 | {{
160 | "name": "HYBAS_ID",
161 | "description": "Unique identifier for each hydrological basin"
162 | }},
163 | {{
164 | "name": "long",
165 | "description": "Longitude of the basin location"
166 | }},
167 | {{
168 | "name": "lat",
169 | "description": "Latitude of the basin location"
170 | }},
171 | {{
172 | "name": "BAMM_speciation",
173 | "description": "Rate of speciation as calculated by the BAMM method"
174 | }},
175 | {{
176 | "name": "BAMM_extinction",
177 | "description": "Rate of extinction as calculated by the BAMM method"
178 | }},
179 | {{
180 | "name": "BAMM_NetDiv",
181 | "description": "Net diversification rate, calculated as speciation minus extinction"
182 | }},
183 | {{
184 | "name": "aet",
185 | "description": "Mean annual evapotranspiration for each basin"
186 | }},
187 | {{
188 | "name": "Elevation",
189 | "description": "Average elevation of the basin"
190 | }},
191 | {{
192 | "name": "sgr",
193 | "description": "Species growth rate in each basin"
194 | }},
195 | {{
196 | "name": "soil_div",
197 | "description": "Soil diversity index for each basin"
198 | }},
199 | {{
200 | "name": "area",
201 | "description": "Total area of the basin in square kilometers"
202 | }},
203 | {{
204 | "name": "diversity",
205 | "description": "Diversity index for the species in each basin"
206 | }}
207 | ]
208 | }}
209 | }}
210 | Thought: First, I need to load the dataset from the global namespace in Python and inspect the data to identify the relevant columns for this hypothesis test.
211 | Action: python_repl_ast
212 | Action Input: import pandas as pd\n\ndf = df_body-size-evolution-in-south-american-freshwater-fishes\ndf.head()
213 | Observation:
214 | HYBAS_ID long lat BAMM_speciation BAMM_extinction BAMM_NetDiv ... aet Elevation sgr soil_div area diversity
215 | 0 6050000010 -76.477422 7.742693 0.137392 0.026807 0.110585 ... 1387 330.150088 166 0.482402 72363.7 68
216 | 1 6050000740 -74.628725 9.803586 0.117235 0.025796 0.091438 ... 1082 69.475294 23 0.457436 17944.3 35
217 | 2 6050068100 -75.295995 8.448815 0.119381 0.023826 0.095555 ... 1312 143.032178 74 0.378793 17105.5 44
218 | 3 6050068110 -74.608408 8.922863 0.132477 0.027777 0.104700 ... 1445 14.724138 3 0.468328 610.1 48
219 | 4 6050070260 -75.591588 5.770093 0.120127 0.022940 0.097187 ... 1371 1378.729945 421 0.158870 61901.9 81
220 | [5 rows x 21 columns]
221 | Thought: Now that the dataset is loaded and I can see the columns, I need to perform a statistical test to assess the significance of the relationship between 'BAMM_speciation' and 'BAMM_NetDiv'.
222 | Action: python_repl_ast
223 | Action Input: from scipy.stats import linregress\n\n# Perform linear regression to test for a statistically significant relationship\nresult = linregress(df['BAMM_speciation'], df['BAMM_NetDiv'])\ncoefficient = result.slope\np_value = result.pvalue\ncoefficient, p_value
224 | Observation: (0.5175306498596297, 0.0035)
225 | Thought: The linear regression analysis provides a coefficient of approximately 0.518, indicating a positive relationship, and the p-value is 0.0035, which is statistically significant at the 0.05 level. Based on this, I can conclude that the hypothesis is true.
226 | WORKFLOW SUMMARY:
227 | 1. Data Loading: Loaded the dataset from the global namespace using Python.
228 | 2. Data Inspection: Displayed the first few rows of the dataset to confirm relevant columns.
229 | 3. Statistical Analysis: Performed a linear regression analysis between 'BAMM_speciation' (predictor) and 'BAMM_NetDiv' (response). The analysis yielded a positive coefficient and a p-value of 0.0035, indicating statistical significance.
230 | Final Answer: True
231 |
232 |
233 | Begin!
234 |
235 | {input}
236 | {agent_scratchpad}"""
237 |
238 |
239 | def load_data_to_react_globals(data_loader):
240 | for name, df in data_loader.table_dict.items():
241 | globals()[name] = df
242 |
243 |
244 | # Set up a prompt template
245 | class CustomPromptTemplate(StringPromptTemplate):
246 | # The template to use
247 | template: str
248 | # The list of tools available
249 | tools: List[BaseTool]
250 |
251 | def format(self, **kwargs) -> str:
252 | # Get the intermediate steps (AgentAction, Observation tuples)
253 | # Format them in a particular way
254 | intermediate_steps = kwargs.pop("intermediate_steps")
255 | thoughts = ""
256 | for action, observation in intermediate_steps:
257 | thoughts += action.log
258 | thoughts += f"\nObservation: {observation}\nThought: "
259 | # Set the agent_scratchpad variable to that value
260 | kwargs["agent_scratchpad"] = thoughts
261 | # Create a tools variable from the list of tools provided
262 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
263 | # Create a list of tool names for the tools provided
264 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
265 | return self.template.format(**kwargs)
266 |
267 | # CustomOutputParser to parse the output of the LLM and execute actions
268 | class CustomOutputParser(AgentOutputParser):
269 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
270 | # Check if agent should finish
271 | if "Final Answer:" in llm_output:
272 | output = llm_output.split("Final Answer:")[-1].split()[0].strip().lower()
273 | if output not in ["true", "false", "yes", "no", "y", "n"]:
274 | raise ValueError(f"Could not parse LLM output: `{llm_output}`")
275 | return AgentFinish(
276 | return_values={"output": output in ["true", "yes", 'y']},
277 | log=llm_output,
278 | )
279 | # Parse out the action and action input
280 | regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
281 | match = re.search(regex, llm_output, re.DOTALL)
282 | if not match:
283 | raise ValueError(f"Could not parse LLM output: `{llm_output}`")
284 | action = match.group(1).strip()
285 | action_input = match.group(2)
286 | # Return the action and action input
287 | return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
288 |
289 |
290 | class CustomPythonAstREPLTool(PythonAstREPLTool):
291 | _exec_globals:Dict = PrivateAttr()
292 | def __init__(self, *args, **kwargs):
293 | super().__init__(*args, **kwargs)
294 | # Initialize a persistent global namespace for code execution
295 | self._exec_globals = {}
296 | self._exec_globals.update(__builtins__)
297 |
298 | def _set_globals(self, table_dict=None):
299 | self._exec_globals = {}
300 | self._exec_globals.update(__builtins__)
301 |
302 | if table_dict:
303 | self._exec_globals.update(table_dict)
304 |
305 | def _run(self, query: str, run_manager=None):
306 | code_match = re.search(r"```(.*?)```", query, re.DOTALL)
307 | if code_match:
308 | # Extract code within backticks
309 | code = code_match.group(1)
310 | else:
311 | code = query
312 | code = code.strip()
313 | if code.startswith("python"):
314 | code = code[len("python"):].lstrip()
315 |
316 | code_lines = code.strip().split('\n')
317 | code = '\n'.join(code_lines[:-1]) # avoid printing the last line twice
318 | last_line = code_lines[-1]
319 |
320 | output_capture = io.StringIO()
321 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture):
322 | logging.getLogger().handlers[0].stream = output_capture
323 | try:
324 | exec(code, self._exec_globals)
325 |                 try:
326 |                     result = eval(last_line, self._exec_globals)
327 |                     if result is not None:
328 |                         print(result, file=output_capture)
329 |                 except Exception:
330 |                     exec(last_line, self._exec_globals)  # last line is a statement (e.g. an assignment), not an expression
331 | except Exception as e:
332 | return str(e)
333 |
334 | # Retrieve the output and return it
335 | output = output_capture.getvalue()
336 | return output if output else "Execution completed without output."
337 |
338 |
339 | def create_agent(
340 | llm,
341 | handlers,
342 | max_iterations = None,
343 | early_stopping_method: str = "force",
344 | simple_template = False
345 | ):
346 | output_parser = CustomOutputParser()
347 | python_tool = CustomPythonAstREPLTool(callbacks=handlers)
348 | tools = [python_tool]
349 | tool_names = [tool.name for tool in tools]
350 |
351 | if simple_template:
352 | use_template = template_v2
353 | else:
354 | use_template = template
355 |
356 | prompt = CustomPromptTemplate(
357 | template=use_template,
358 | tools=tools,
359 | input_variables=["system_prompt", "input", "intermediate_steps"]
360 | )
361 | llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=handlers)
362 |
363 | agent = LLMSingleActionAgent(
364 | llm_chain=llm_chain,
365 | output_parser=output_parser,
366 | stop=["\nObservation:"],
367 | allowed_tools=tool_names
368 | )
369 |
370 | return AgentExecutor.from_agent_and_tools(
371 | agent=agent,
372 | tools=tools,
373 | verbose=True,
374 | max_iterations=max_iterations,
375 | callbacks=handlers,
376 | early_stopping_method=early_stopping_method
377 | )
--------------------------------------------------------------------------------
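A hedged end-to-end sketch for the baseline agent above; the model choice and the truncated input text are illustrative only:

    # Hypothetical wiring of the baseline ReAct agent (see create_agent above).
    from langchain_openai import ChatOpenAI
    from baseline_agents.react_utils import create_agent

    llm = ChatOpenAI(model="gpt-4o", temperature=0)   # any LangChain chat model should work here
    executor = create_agent(llm, handlers=[], max_iterations=25, simple_template=True)

    # The calling script is expected to preload datasets into the REPL tool's namespace,
    # e.g. executor.tools[0]._set_globals(loader.table_dict).
    result = executor.invoke({
        "system_prompt": "You are an expert statistician.",
        "input": "Question: Is the following hypothesis True or False? ...\nDatasets: ...",
    })
    print(result["output"])   # boolean parsed from "Final Answer:" by CustomOutputParser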
/popper/llm/custom_model.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import sys
4 | import uuid
5 | import json
6 | from typing import (
7 | TYPE_CHECKING,
8 | Any,
9 | AsyncIterator,
10 | Callable,
11 | Dict,
12 | Iterator,
13 | List,
14 | Literal,
15 | Mapping,
16 | Optional,
17 | Sequence,
18 | Tuple,
19 | Type,
20 |     Union, cast,
21 | )
22 |
23 | from langchain_core.callbacks import (
24 | AsyncCallbackManagerForLLMRun,
25 | CallbackManagerForLLMRun,
26 | )
27 | from langchain_core.messages import (
28 | AIMessage,
29 | AIMessageChunk,
30 | BaseMessage,
31 | BaseMessageChunk,
32 | ChatMessage,
33 | ChatMessageChunk,
34 | FunctionMessage,
35 | FunctionMessageChunk,
36 | HumanMessage,
37 | HumanMessageChunk,
38 | InvalidToolCall,
39 | SystemMessage,
40 | SystemMessageChunk,
41 | ToolCall,
42 | ToolMessage,
43 | ToolMessageChunk,
44 | )
45 | from langchain_core.language_models import BaseChatModel, SimpleChatModel, LanguageModelInput
46 | from langchain_core.language_models.llms import create_base_retry_decorator
47 | from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage
48 | from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
49 | # from langchain_core.pydantic import BaseModel, Field
50 | from pydantic import BaseModel, Field, PrivateAttr
51 | from langchain_core.runnables import run_in_executor, Runnable
52 | from langchain_core.output_parsers.openai_tools import (
53 | JsonOutputKeyToolsParser,
54 | PydanticToolsParser,
55 | make_invalid_tool_call,
56 | parse_tool_call,
57 | )
58 | from langchain_core.utils import secret_from_env
59 | from langchain_core.utils.function_calling import convert_to_openai_tool
60 | from langchain_core.tools import BaseTool
61 | # from langchain_community.adapters.openai import (
62 | # convert_dict_to_message,
63 | # convert_message_to_dict,
64 | # )
65 | from langchain_openai import ChatOpenAI
66 | # from langchain_openai.chat_models import _convert_dict_to_message
67 | # from langchain_community.chat_models.openai import ChatOpenAI
68 | from popper.llm.prompt_utils import bind_tools_to_system_prompt
69 | from popper.llm.utils import parse_llm_output
70 |
71 |
72 | def _convert_message_to_dict(message: BaseMessage) -> dict:
73 | """Convert a LangChain message to a dictionary.
74 |
75 | Args:
76 | message: The LangChain message.
77 |
78 | Returns:
79 | The dictionary.
80 | """
81 | message_dict: Dict[str, Any]
82 | if isinstance(message, ChatMessage):
83 | message_dict = {"role": message.role, "content": message.content}
84 | elif isinstance(message, HumanMessage):
85 | message_dict = {"role": "user", "content": message.content}
86 | elif isinstance(message, AIMessage):
87 | message_dict = {"role": "assistant", "content": message.content}
88 | if "function_call" in message.additional_kwargs:
89 | message_dict["function_call"] = message.additional_kwargs["function_call"]
90 | # If function call only, content is None not empty string
91 | if message_dict["content"] == "":
92 | message_dict["content"] = None
93 | if "tool_calls" in message.additional_kwargs:
94 | message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
95 | # If tool calls only, content is None not empty string
96 | if message_dict["content"] == "":
97 | message_dict["content"] = None
98 | if "full_message" in message.additional_kwargs and message.additional_kwargs["full_message"]:
99 | # full_message is not empty, replace the current content with full message
100 | message_dict["content"] = message.additional_kwargs["full_message"]
101 | elif isinstance(message, SystemMessage):
102 | message_dict = {"role": "system", "content": message.content}
103 | elif isinstance(message, FunctionMessage):
104 | message_dict = {
105 | "role": "function",
106 | "content": message.content,
107 | "name": message.name,
108 | }
109 | elif isinstance(message, ToolMessage):
110 | run_results = {
111 | "name": message.name,
112 | "id": message.tool_call_id,
113 | "return_value": message.content
114 | }
115 | message_dict = {
116 | "role": "user",
117 | "content": "Tool call results:\n" + json.dumps(run_results, indent=4),
118 | "tool_call_id": message.tool_call_id,
119 | }
120 | else:
121 | raise TypeError(f"Got unknown type {message}")
122 | if "name" in message.additional_kwargs:
123 | message_dict["name"] = message.additional_kwargs["name"]
124 | return message_dict
125 |
126 |
127 | def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
128 | """Convert a dictionary to a LangChain message.
129 |
130 | Args:
131 | _dict: The dictionary.
132 |
133 | Returns:
134 | The LangChain message.
135 | """
136 | role = _dict.get("role")
137 | name = _dict.get("name")
138 | id_ = _dict.get("id")
139 | if role == "user":
140 | return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
141 | elif role == "assistant":
142 | # Fix for azure
143 | # Also OpenAI returns None for tool invocations
144 | content = _dict.get("content", "") or ""
145 | additional_kwargs: Dict = {}
146 | if function_call := _dict.get("function_call"):
147 | additional_kwargs["function_call"] = dict(function_call)
148 | if full_message := _dict.get("full_message"):
149 | additional_kwargs["full_message"] = full_message
150 | tool_calls = []
151 | invalid_tool_calls = []
152 | if raw_tool_calls := _dict.get("tool_calls"):
153 | additional_kwargs["tool_calls"] = raw_tool_calls
154 | for raw_tool_call in raw_tool_calls:
155 | try:
156 | tool_calls.append(ToolCall(name=raw_tool_call['function']['name'], args=raw_tool_call['function']['arguments'], id=raw_tool_call['id'], type="tool_call"))
157 | except Exception as e:
158 | invalid_tool_calls.append(
159 | make_invalid_tool_call(str(raw_tool_call), str(e))
160 | )
161 | # print(tool_calls)
162 | # print()
163 | # print(invalid_tool_calls)
164 | return AIMessage(
165 | content=content,
166 | additional_kwargs=additional_kwargs,
167 | name=name,
168 | id=id_,
169 | tool_calls=tool_calls,
170 | invalid_tool_calls=invalid_tool_calls,
171 | )
172 | elif role == "system":
173 | return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
174 | elif role == "function":
175 | return FunctionMessage(
176 | content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
177 | )
178 | elif role == "tool":
179 | additional_kwargs = {}
180 | if "name" in _dict:
181 | additional_kwargs["name"] = _dict["name"]
182 | return ToolMessage(
183 | content=_dict.get("content", ""),
184 | tool_call_id=cast(str, _dict.get("tool_call_id")),
185 | additional_kwargs=additional_kwargs,
186 | name=name,
187 | id=id_,
188 | )
189 | else:
190 | return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
191 |
192 | class CustomChatModel(ChatOpenAI):
193 | model_type: str = Field(default="custom-chat")
194 | tools: Optional[List[Any]] = Field(default=None)
195 | tool_choice: Optional[Union[dict, str, Literal["auto", "none", "required", "any"], bool]] = Field(default=None)
196 |
197 | @property
198 | def lc_secrets(self) -> Dict[str, str]:
199 | return {"openai_api_key": "EMPTY"}
200 |
201 | @property
202 | def _llm_type(self) -> str:
203 | """Return type of chat model."""
204 | return self.model_type + "-chat"
205 |
206 | def bind_tools(
207 | self,
208 | tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
209 | tool_choice: Optional[
210 | Union[dict, str, Literal["auto", "none", "required", "any"], bool]
211 | ] = None,
212 | strict: Optional[bool] = None,
213 | **kwargs: Any,
214 | ) -> Runnable[LanguageModelInput, BaseMessage]:
215 | """Bind tool-like objects to this chat model.
216 |
217 | Assumes model is compatible with OpenAI tool-calling API.
218 |
219 | Args:
220 | tools: A list of tool definitions to bind to this chat model.
221 | Supports any tool definition handled by
222 | :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
223 | tool_choice: Which tool to require the model to call.
224 | Must be the name of the single provided function or
225 | "auto" to automatically determine which function to call
226 | (if any), or a dict of the form:
227 | {"type": "function", "function": {"name": <>}}.
228 | kwargs: Any additional parameters are passed directly to
229 | ``self.bind(**kwargs)``.
230 | """
231 | formatted_tools = [convert_to_openai_tool(tool, strict=strict) for tool in tools]
232 |
233 | if tool_choice:
234 | if isinstance(tool_choice, str):
235 | # tool_choice is a tool/function name
236 | if tool_choice not in ("auto", "none", "any", "required"):
237 | tool_choice = {
238 | "type": "function",
239 | "function": {"name": tool_choice},
240 | }
241 | # 'any' is not natively supported by OpenAI API.
242 | # We support 'any' since other models use this instead of 'required'.
243 | if tool_choice == "any":
244 | tool_choice = "required"
245 | elif isinstance(tool_choice, bool):
246 | tool_choice = "required"
247 | elif isinstance(tool_choice, dict):
248 | tool_names = [
249 | formatted_tool["function"]["name"]
250 | for formatted_tool in formatted_tools
251 | ]
252 | if not any(
253 | tool_name == tool_choice["function"]["name"]
254 | for tool_name in tool_names
255 | ):
256 | raise ValueError(
257 | f"Tool choice {tool_choice} was specified, but the only "
258 | f"provided tools were {tool_names}."
259 | )
260 | else:
261 | raise ValueError(
262 | f"Unrecognized tool_choice type. Expected str, bool or dict. "
263 | f"Received: {tool_choice}"
264 | )
265 | kwargs["tool_choice"] = tool_choice
266 |
267 | self.tools = formatted_tools
268 | self.tool_choice = tool_choice
269 | return super().bind(tools=formatted_tools, **kwargs)
270 |
271 |
272 | def bind_functions(
273 | self,
274 | functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
275 | function_call: Optional[str] = None,
276 | **kwargs: Any,
277 | ) -> Runnable[LanguageModelInput, BaseMessage]:
278 | """Bind functions (and other objects) to this chat model.
279 |
280 | Args:
281 | functions: A list of function definitions to bind to this chat model.
282 | Can be a dictionary, pydantic model, or callable. Pydantic
283 | models and callables will be automatically converted to
284 | their schema dictionary representation.
285 | function_call: Which function to require the model to call.
286 | Must be the name of the single provided function or
287 | "auto" to automatically determine which function to call
288 | (if any).
289 | kwargs: Any additional parameters to pass to the
290 | :class:`~langchain.runnable.Runnable` constructor.
291 | """
292 | from langchain.chains.openai_functions.base import convert_to_openai_function
293 |
294 | formatted_functions = [convert_to_openai_function(fn) for fn in functions]
295 |         self.tools = formatted_functions
296 | if function_call is not None:
297 | if len(formatted_functions) != 1:
298 | raise ValueError(
299 | "When specifying `function_call`, you must provide exactly one "
300 | "function."
301 | )
302 | if formatted_functions[0]["name"] != function_call:
303 | raise ValueError(
304 | f"Function call {function_call} was specified, but the only "
305 | f"provided function was {formatted_functions[0]['name']}."
306 | )
307 | function_call_ = {"name": function_call}
308 | kwargs = {**kwargs, "function_call": function_call_}
309 | return super().bind(
310 | functions=formatted_functions,
311 | **kwargs,
312 | )
313 |
314 |
315 | def _get_request_payload(
316 | self,
317 | input_: LanguageModelInput,
318 | *,
319 | stop: Optional[List[str]] = None,
320 | **kwargs: Any,
321 | ) -> dict:
322 |
323 | if "tools" in kwargs:
324 | self.tools = kwargs['tools']
325 | if "tool_choice" in kwargs:
326 | self.tool_choice = kwargs["tool_choice"]
327 |
328 | messages = self._convert_input(input_).to_messages()
329 | if stop is not None:
330 | kwargs["stop"] = stop
331 |
332 | message_dicts = [_convert_message_to_dict(m) for m in messages]
333 | if self.tools:
334 | has_system_prompt = False
335 | for msg in message_dicts:
336 | if msg['role'] == 'system':
337 | system_prompt = msg['content']
338 | msg['content'] = bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice)
339 | has_system_prompt = True
340 | break
341 | if not has_system_prompt:
342 | system_prompt = "You are a helpful assistant"
343 | message_dicts = [{
344 | 'role': 'system',
345 | 'content': bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice),
346 | }] + message_dicts
347 |
348 | if self.tool_choice is not None and message_dicts[-1]['role'] == 'user':
349 | last_user_message = message_dicts[-1]['content']
350 | message_dicts[-1]['content'] = f"""{last_user_message}
351 |
352 | Remember to format your response as a call to one of the following tools:
353 | {json.dumps(self.tools, indent=4)}
354 | Your tool call should have the following JSON format:
355 | 
356 | {{
357 | "type": "tool_calls",
358 | "content": [
359 | {{
360 | "name": "name of the function to call",
361 |             "id": "a unique id for this tool call",
362 | "arguments": {{
363 | "argument1": value1,
364 | "argument2": value2,
365 | ...
366 | }}
367 | }},
368 | ...
369 | ]
370 | }}
371 | """
372 | # print(message_dicts)
373 |
374 | return {
375 | "messages": message_dicts,
376 | **self._default_params,
377 | **kwargs,
378 | }
379 |
380 | def _create_message_dicts(
381 | self, messages: List[BaseMessage], stop: Optional[List[str]]
382 | ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
383 | params = self._client_params
384 | if stop is not None:
385 | if "stop" in params:
386 | raise ValueError("`stop` found in both the input and default params.")
387 | params["stop"] = stop
388 | message_dicts = [_convert_message_to_dict(m) for m in messages]
389 | if self.tools:
390 | has_system_prompt = False
391 | for msg in message_dicts:
392 | if msg['role'] == 'system':
393 | system_prompt = msg['content']
394 | msg['content'] = bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice)
395 | has_system_prompt = True
396 | break
397 | if not has_system_prompt:
398 | system_prompt = "You are a helpful assistant"
399 | message_dicts = [{
400 | 'role': 'system',
401 | 'content': bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice),
402 | }] + message_dicts
403 |
404 | return message_dicts, params
405 |
406 | def _create_chat_result(self, response: Union[dict, BaseModel], generation_info: Optional[Dict] = None) -> ChatResult:
407 | generations = []
408 | if not isinstance(response, dict):
409 | response = response.dict()
410 | for res in response["choices"]:
411 | # print(res)
412 | if self.tools:
413 | # attempt to parse the tool calls
414 | full_message = res["message"]["content"]
415 | scratchpad, parsed_message = parse_llm_output(full_message)
416 | if parsed_message['type'] == 'text_message':
417 | # res["message"]['full_message'] = full_message
418 | res["message"]["content"] = parsed_message["content"]
419 | else:
420 | assert parsed_message['type'] == 'tool_calls'
421 | tool_calls = []
422 | for tool_call in parsed_message['content']:
423 | if 'id' not in tool_call:
424 | tool_call['id'] = 'call_' + str(uuid.uuid4())
425 | tool_calls.append({
426 | 'id': tool_call['id'],
427 | 'type': 'function',
428 | 'function': {
429 | 'name': tool_call['name'],
430 | 'arguments': tool_call['arguments']
431 | }
432 | })
433 | res["message"]["tool_calls"] = tool_calls
434 | res["message"]['content'] = None
435 | res['finish_reason'] = 'tool_calls'
436 | res["message"]['full_message'] = scratchpad + json.dumps(parsed_message, indent=4)
437 | message = _convert_dict_to_message(res["message"])
438 | generation_info = dict(finish_reason=res.get("finish_reason"))
439 | if "logprobs" in res:
440 | generation_info["logprobs"] = res["logprobs"]
441 | gen = ChatGeneration(
442 | message=message,
443 | generation_info=generation_info,
444 | )
445 | generations.append(gen)
446 | # print(message)
447 | token_usage = response.get("usage", {})
448 | llm_output = {
449 | "token_usage": token_usage,
450 | "model_name": self.model_name,
451 | "system_fingerprint": response.get("system_fingerprint", ""),
452 | }
453 | return ChatResult(generations=generations, llm_output=llm_output)
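454 | 
455 | 
456 | if __name__ == "__main__":
457 |     # Usage sketch (illustrative only, not part of the package): wire CustomChatModel
458 |     # to a local OpenAI-compatible endpoint and bind a single tool. The model name,
459 |     # port, and tool schema below are assumptions for demonstration; swap in whatever
460 |     # your local server actually serves.
461 |     import openai
462 | 
463 |     llm = CustomChatModel(model="local-model", model_type="custom", temperature=0.7, api_key="EMPTY")
464 |     llm.client = openai.Client(base_url="http://127.0.0.1:40000/v1", api_key="EMPTY").chat.completions
465 | 
466 |     search_tool = {
467 |         "type": "function",
468 |         "function": {
469 |             "name": "search",
470 |             "description": "Search a corpus for a query string.",
471 |             "parameters": {
472 |                 "type": "object",
473 |                 "properties": {"query": {"type": "string"}},
474 |                 "required": ["query"],
475 |             },
476 |         },
477 |     }
478 |     # bind_tools stores the formatted tool schema; at request time _get_request_payload
479 |     # injects it into the system prompt via bind_tools_to_system_prompt, and
480 |     # _create_chat_result parses the model's JSON tool call via parse_llm_output.
481 |     llm_with_tools = llm.bind_tools([search_tool], tool_choice="auto")
482 |     response = llm_with_tools.invoke("Find datasets relevant to the hypothesis.")
483 |     print(response)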
--------------------------------------------------------------------------------
/baseline_agents/self_refine_agent.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import io
3 | import logging
4 | import traceback
5 | import multiprocessing
6 | import re
7 | import sys
8 | import numpy as np
9 |
10 | from typing import Optional, Tuple, Dict, Any, List, Literal
11 | from typing_extensions import TypedDict, Annotated
12 | from pydantic import BaseModel, Field
13 |
14 | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15 | from langchain_core.messages.base import get_msg_title_repr
16 | from langchain_core.utils.interactive_env import is_interactive_env
17 | from langgraph.graph import StateGraph, START, END
18 |
19 | import openai
20 | from popper.llm.custom_model import CustomChatModel
21 | from langchain_anthropic import ChatAnthropic
22 | from langchain_ollama import ChatOllama
23 | from langchain_openai import ChatOpenAI
24 |
25 | logging.getLogger("httpx").setLevel(logging.WARNING)
26 |
27 |
28 | CODING_AGENT_SYSTEM_PROMPT = """You are an expert statistician. You are tasked to validate rigorously if a scientific hypothesis H is true by implementing a rigorous statistical test.
29 |
30 | You should write code to implement a falsification test for the given hypothesis.
31 | The test should be relevant to the main hypothesis and aims to falsify it.
32 | The test should use the available data described below, and use data processing, extraction, and perform statistical analysis to produce a p-value measuring the falsification of the main hypothesis.
33 | The test should be extremely rigorous. The p-value should be theoretically grounded.
34 | The code should be clear, concise, and efficient. Show a progress bar when necessary. The code will run under a time limit, so be efficient; for example, keep the number of permutations small (e.g. <1000) when possible.
35 | The code should be self-contained, and do not need additional modifications from user.
36 |
37 | Create a code from the user request. Ensure any code you provide can be executed with all required imports and variables defined.
38 | Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block.
39 | Invoke the CodeOutputSpec tool to structure the output correctly.
40 | NEVER PRODUCE ANY PLACEHOLDER IN ANY FUNCTION. PLACEHOLDER IS WORSE THAN FAILURE TO PRODUCE CODE.
41 | PLACEHOLDERS include made-up genes, names, ids, functions, p-values, or any other fabricated values.
42 | The output should be a single p-value. If there are multiple p-values produced by the test, you should aggregate them in a meaningful and rigorous way.
43 | When printing p-values, please use scientific notations (e.g. 3.50e-03) instead of the raw number.
44 | -------------------------------------------------------
45 |
46 | Here is the user requested falsification test specification:"""
47 |
48 | FEEDBACK_AGENT_SYSTEM_PROMPT = """You are evaluateGPT. Given a test specification, an implementation of the test, and the results from executing the implemented test, your task is to evaluate whether the test implementation and output are valid, and to provide detailed feedback on any identified issues along with suggestions for improving the test.
49 |
50 | To evaluate the validity of the test, you should
51 | 1. Check if the test implementation strictly follows the test specification
52 | 2. Make sure the output shows a valid p-value without any errors. The p-value should have a reasonable value: it cannot be smaller than or equal to 0 or larger than 1.
53 | 3. Double-check that all data inputs used in the experiment are accessed from the provided data sources, and there are no fake/made-up data entries.
54 | 4. Examine the experiment implementation carefully; make sure it uses a rigorous statistical test and that there are no bugs or logical issues in the code.
55 | 5. Carefully examine any other potential problems of the experiment, such as unhandled edge-cases, invalid sample sizes, or other subtle issues that might lead to a misleading result.
56 |
57 | If you find any problems with the test implementation or experiment output, please provide detailed feedback on how to fix and refine the test. If the test is valid, please output the final p-value formatted in scientific notation."""
58 |
59 | def get_llm(model = 'claude-3-5-sonnet-20240620', temperature=0.7, **kwargs):
60 | source = "custom"
61 | if model[:7] == 'claude-':
62 | source = 'Anthropic'
63 | elif model[:4] == 'gpt-':
64 | source = 'OpenAI'
65 | elif model.startswith('llama'):
66 | source = "Llama"
67 | # if source not in ['OpenAI', 'Anthropic']:
68 | # raise ValueError('Invalid source')
69 | if source == 'OpenAI':
70 | return ChatOpenAI(model = model, temperature = temperature, **kwargs)
71 | elif source == 'Anthropic':
72 | return ChatAnthropic(model = model,
73 | temperature = temperature,
74 | max_tokens = 4096,
75 | **kwargs)
76 | elif source == 'Llama':
77 | llm = CustomChatModel(model = model, model_type=source, temperature = temperature)
78 | llm.client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY").chat.completions
79 | return llm
80 | else:
81 | # listen to a different port
82 | llm = CustomChatModel(model = model, model_type=source, temperature = temperature)
83 | llm.client = openai.Client(base_url="http://127.0.0.1:40000/v1", api_key="EMPTY").chat.completions
84 | return llm
85 |
86 |
87 | class FeedbackOutputSpec(BaseModel):
88 | """Output specification for the experiment evaluation & feedback"""
89 |
90 | is_valid: str = Field(
91 | description="The validity of the experiment implementation and output. Use 'Yes' if the experiment is valid, or 'No' if any issues are detected."
92 | )
93 | p_value: Optional[str] = Field(description="The p-value from the experiment, formatted in scientific notation. Include this field only if the experiment is valid.")
94 | feedback: str = Field(
95 | description="Detailed feedback to address identified issues and suggestions for improving the experiment."
96 | )
97 |
98 | class CodeOutputSpec(BaseModel):
99 | """Code output"""
100 |
101 | prefix: str = Field(description="Description of the problem and approach")
102 | imports: str = Field(description="Code block import statements")
103 | code: str = Field(description="Code block not including import statements")
104 |
105 |
106 | class SelfRefineState(TypedDict):
107 | iteration: int
108 | messages: List[Tuple[str, str]]
109 | code_impl: Optional[CodeOutputSpec]
110 | code_out: Optional[str]
111 | feedback_valid: bool
112 | p_value: Optional[float]
113 | feedback_text: Optional[str]
114 | done: bool
115 |
116 |
117 | class CodingAgent:
118 | """
119 | - Takes test_specification & data as input.
120 | - Produces code using CODING_AGENT_SYSTEM_PROMPT + user specification.
121 | - Executes the code and captures the output.
122 | """
123 |
124 | def __init__(self, llm="claude-3-5-sonnet-20241022", time_limit: int = 3):
125 | """
126 | Args:
127 | llm: name of the language model
128 | time_limit: Time limit (in minutes) to run code.
129 | """
130 | self.llm = get_llm(llm)
131 | self.time_limit = time_limit
132 |
133 | # We will build a ChatPromptTemplate for code generation
134 | # Then parse the result with CodeOutputSpec.
135 | self.code_generation_prompt = ChatPromptTemplate.from_messages(
136 | [
137 | (
138 | "system",
139 | CODING_AGENT_SYSTEM_PROMPT,
140 | ),
141 | ("placeholder", "{messages}"),
142 | ]
143 | )
144 |
145 | self.structured_llm = self.llm.with_structured_output(CodeOutputSpec)
146 |
147 |         self._exec_globals = {}
148 |         self._exec_globals.update(__builtins__ if isinstance(__builtins__, dict) else vars(__builtins__))  # __builtins__ may be a dict or a module
149 |
150 | def _set_globals(self, table_dict=None):
151 |         self._exec_globals = {}
152 |         self._exec_globals.update(__builtins__ if isinstance(__builtins__, dict) else vars(__builtins__))  # __builtins__ may be a dict or a module
153 |
154 | if table_dict:
155 | self._exec_globals.update(table_dict)
156 |
157 | def generate_code(self, messages) -> CodeOutputSpec:
158 | """Generate code from the LLM given a test specification."""
159 | # Invoke the pipeline
160 | code_gen_chain = self.code_generation_prompt | self.structured_llm
161 | result: CodeOutputSpec = code_gen_chain.invoke({"messages": messages})
162 |
163 | return result
164 |
165 | def execute_code(self, code_implementation: CodeOutputSpec) -> str:
166 | """
167 | Safely executes the provided code in a subprocess (with a time limit).
168 | Returns stdout/stderr as a string.
169 | """
170 | # Combine imports + code
171 | full_code = code_implementation.imports + "\n\n" + code_implementation.code
172 |
173 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S')
174 | logging.getLogger().setLevel(logging.INFO)
175 |
176 | # Worker function to run code in a separate process
177 | def run_code(queue):
178 | output_capture = io.StringIO()
179 | try:
180 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture):
181 | logging.getLogger().handlers[0].stream = output_capture
182 | exec_globals = self._exec_globals
183 | exec(full_code, exec_globals)
184 | except Exception:
185 | error_message = traceback.format_exc()
186 | queue.put(error_message)
187 | else:
188 | queue.put(output_capture.getvalue())
189 |
190 | # Use multiprocessing to enforce time limit
191 | queue = multiprocessing.Queue()
192 | proc = multiprocessing.Process(target=run_code, args=(queue,))
193 | proc.start()
194 | proc.join(timeout=self.time_limit * 60)
195 |
196 | if proc.is_alive():
197 | # Timed out
198 | proc.terminate()
199 | proc.join()
200 | return "TimeoutError: Code execution took too long."
201 |
202 | # Retrieve captured output
203 | if not queue.empty():
204 | code_output = queue.get()
205 | else:
206 | code_output = "No output was captured."
207 | return code_output
208 |
209 | def run(
210 | self,
211 | messages: Any
212 | ) -> Tuple[CodeOutputSpec, str]:
213 | """
214 | High-level method:
215 | 1) Generate code from LLM
216 | 2) Execute the code
217 | 3) Return code + output
218 | """
219 | code_impl = self.generate_code(messages)
220 | print("---------Code Implementation-----------")
221 | print(code_impl)
222 | code_out = self.execute_code(code_impl)
223 | print("----------Code Output-----------")
224 | print(code_out)
225 | return code_impl, code_out
226 |
227 | class FeedbackAgent:
228 | """
229 | - Takes in the test specification, data, code implementation, and execution result.
230 | - Evaluates using FEEDBACK_AGENT_SYSTEM_PROMPT.
231 | - Produces a FeedbackOutputSpec indicating whether the test is valid, a final p-value if valid, and suggestions.
232 | """
233 |
234 | def __init__(self, llm="claude-3-5-sonnet-20241022"):
235 | self.llm = get_llm(llm)
236 |
237 | # Build a ChatPromptTemplate for feedback
238 | self.feedback_prompt = ChatPromptTemplate.from_messages(
239 | [
240 | ("system", FEEDBACK_AGENT_SYSTEM_PROMPT),
241 | ("user", "{evaluation_input}"),
242 | ]
243 | )
244 | self.structured_llm = self.llm.with_structured_output(FeedbackOutputSpec)
245 |
246 | def run(
247 | self,
248 | test_specification: str,
249 | data: str,
250 | code_implementation: CodeOutputSpec,
251 | code_output: str,
252 | ) -> FeedbackOutputSpec:
253 | """
254 | 1) Concatenate specification + code + output into a single message
255 | 2) Ask LLM for feedback with FEEDBACK_AGENT_SYSTEM_PROMPT
256 | 3) Parse into FeedbackOutputSpec
257 | """
258 | spec_text = (
259 | f"Main Hypothesis: {test_specification}\n"
260 | )
261 |
262 | code_text = (
263 | f"--- Code Implementation ---\n"
264 | f"Prefix:\n{code_implementation.prefix}\n\n"
265 | f"Imports:\n{code_implementation.imports}\n\n"
266 | f"Code:\n{code_implementation.code}\n\n"
267 | )
268 |
269 | output_text = f"--- Execution Result ---\n{code_output}\n\n"
270 |
271 | # Summarize data if needed (or directly pass it)
272 |         # Here we pass the data description through directly:
273 | data_text = f"Data description:\n{data}\n\n"
274 |
275 | evaluation_input = spec_text + code_text + output_text + data_text
276 |
277 | # Now run the LLM
278 | final_chain = self.feedback_prompt | self.structured_llm
279 | feedback_result: FeedbackOutputSpec = final_chain.invoke(
280 | {"evaluation_input": evaluation_input}
281 | )
282 |
283 | print("---------Feedback----------")
284 | print(feedback_result)
285 | return feedback_result
286 |
287 | class SelfRefineAgent:
288 | """
289 | - Orchestrates iterative refinement:
290 | 1) Generate code with CodingAgent
291 | 2) Evaluate with FeedbackAgent
292 | 3) If invalid, refine or retry until success or max attempts
293 | 4) Once valid, parse p-value < 0.05 => {"output": True} else {"output": False}.
294 | """
295 |
296 | def __init__(
297 | self,
298 | llm: str = "claude-3-5-sonnet-20241022",
299 | max_iterations: int = 10,
300 | p_value_threshold: float = 0.05,
301 | ):
302 | self.coding_agent = CodingAgent(llm)
303 | self.feedback_agent = FeedbackAgent(llm)
304 | self.max_iterations = max_iterations
305 | self.p_value_threshold = p_value_threshold
306 |
307 | self.graph = self._build_graph()
308 |
309 | def _build_graph(self):
310 | """
311 | Create and compile the LangGraph for the iterative refinement.
312 | """
313 | def generate_and_run_code(state: SelfRefineState):
314 | """
315 | Node that calls coding_agent.run(...) to produce code & execution results.
316 | """
317 | code_impl, code_out = self.coding_agent.run(
318 | state["messages"]
319 | )
320 | state["code_impl"] = code_impl
321 | state["code_out"] = code_out
322 | state["messages"] += [
323 | (
324 | "assistant",
325 | f"{code_impl.prefix} \n Imports: {code_impl.imports} \n Code: {code_impl.code}",
326 | )
327 | ]
328 | return state
329 |
330 | def evaluate_feedback(state: SelfRefineState):
331 | """
332 | Node that calls feedback_agent.run(...) to see if code is valid.
333 | """
334 | feedback = self.feedback_agent.run(
335 | self.test_specification,
336 | self.data,
337 | state["code_impl"],
338 | state["code_out"]
339 | )
340 | # Suppose feedback has structure:
341 | # {
342 | # "is_valid": "Yes" or "No",
343 | # "p_value": "...",
344 | # "feedback": "..."
345 | # }
346 | # We store these in state
347 | state["feedback_valid"] = (feedback.is_valid.lower() == "yes")
348 | if feedback.is_valid.lower() == "yes":
349 | # attempt to parse p-value
350 | try:
351 | pval = float(feedback.p_value) if feedback.p_value else None
352 |                 except (TypeError, ValueError):
353 | pval = None
354 | state["p_value"] = pval
355 | else:
356 | state["p_value"] = None
357 | state["feedback_text"] = feedback.feedback
358 | state["messages"] += [
359 | ("user", feedback.feedback)
360 | ]
361 | return state
362 |
363 |         def decide_if_done(state: SelfRefineState) -> Literal["end", "check_max_iteration"]:
364 |             """
365 |             Route on the feedback: end if the test was judged valid, otherwise check the iteration budget.
366 |             """
367 | if state["feedback_valid"]:
368 | state["done"] = True
369 | return "end"
370 | else:
371 | return "check_max_iteration"
372 |
373 |         def check_max_iteration(state: SelfRefineState) -> SelfRefineState:
374 |             """Count this attempt and mark the run done once max_iterations is exceeded."""
375 | state["iteration"] += 1
376 | if state["iteration"] > self.max_iterations:
377 | state["done"] = True
378 | return state
379 |
380 | def maybe_retry(state: SelfRefineState) -> Literal["generate_and_run_code", "end"]:
381 |             if state["done"]:
382 | return "end"
383 | return "generate_and_run_code"
384 |
385 | graph_builder = StateGraph(SelfRefineState)
386 |
387 | graph_builder.add_node("generate_and_run_code", generate_and_run_code)
388 | graph_builder.add_node("evaluate_feedback", evaluate_feedback)
389 | graph_builder.add_node("check_max_iteration", check_max_iteration)
390 |
391 | graph_builder.add_edge(START, "generate_and_run_code")
392 | graph_builder.add_edge("generate_and_run_code", "evaluate_feedback")
393 | graph_builder.add_conditional_edges(
394 | "evaluate_feedback",
395 | decide_if_done,
396 | {
397 | "end": END,
398 | "check_max_iteration": "check_max_iteration",
399 | }
400 | )
401 | graph_builder.add_conditional_edges(
402 | "check_max_iteration",
403 | maybe_retry,
404 | {
405 | "generate_and_run_code": "generate_and_run_code",
406 | "end": END,
407 | }
408 | )
409 |
410 | return graph_builder.compile()
411 |
412 | def generate(
413 | self,
414 | query: str,
415 | data_loader: Any,
416 | ) -> Dict[str, Any]:
417 | """
418 | Returns:
419 | Dict with keys:
420 | - 'output': bool (True if p-value < threshold, False otherwise)
421 | - 'feedback': feedback text from the final iteration
422 | - 'p_value': (optional) p-value from the final iteration
423 | """
424 | self.test_specification = query
425 | self.data = data_loader.data_desc
426 | self.coding_agent._set_globals(data_loader.table_dict)
427 | messages = [
428 |             ("user", "Here is the hypothesis to falsify: " + self.test_specification + "\n\n" + "And here are the available data relevant to the hypothesis:\n" + self.data + "\n\nEach of these dataframes has already been loaded into the global namespace. You may access each dataframe **directly as a variable**. Make sure to use the **EXACT** dataframe names as shown above.")
429 | ]
430 |
431 | state_dict = SelfRefineState(
432 | iteration=1,
433 | messages=messages,
434 | code_impl=None,
435 | code_out=None,
436 | feedback_valid=False,
437 | p_value=None,
438 | feedback_text=None,
439 | done=False
440 | )
441 | # run the graph
442 | result = self.graph.invoke(state_dict)
443 |
444 | # interpret final results
445 | if result["p_value"] is not None and result["p_value"] < self.p_value_threshold:
446 | return {
447 | "output": True,
448 | "feedback": result["feedback_text"],
449 | "p_value": result["p_value"],
450 | }
451 | elif result["p_value"] is not None:
452 | return {
453 | "output": False,
454 | "feedback": result["feedback_text"],
455 | "p_value": result["p_value"],
456 | }
457 | else:
458 | # no valid p-value or max iterations exhausted
459 | return {
460 | "output": False,
461 | "feedback": "No valid p-value found or max iteration limit reached."
462 | }
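463 | 
464 | 
465 | if __name__ == "__main__":
466 |     # Usage sketch (illustrative only). SelfRefineAgent expects a data loader
467 |     # exposing `data_desc` (a text description of the available tables) and
468 |     # `table_dict` (a dict mapping variable names to DataFrames), mirroring how
469 |     # generate() uses it above. The toy loader, table, and hypothesis here are
470 |     # made up for demonstration and are not part of the benchmark.
471 |     import pandas as pd
472 | 
473 |     class _ToyDataLoader:
474 |         def __init__(self):
475 |             self.table_dict = {
476 |                 "df_measurements": pd.DataFrame(
477 |                     {"group": ["a"] * 5 + ["b"] * 5, "value": list(range(10))}
478 |                 )
479 |             }
480 |             self.data_desc = "df_measurements: columns `group` ('a'/'b') and `value` (numeric)."
481 | 
482 |     agent = SelfRefineAgent(llm="claude-3-5-sonnet-20241022", max_iterations=3)
483 |     result = agent.generate(
484 |         query="Group 'b' has higher `value` than group 'a'.",
485 |         data_loader=_ToyDataLoader(),
486 |     )
487 |     # result is a dict with 'output' (bool), 'feedback' (str), and, when a valid
488 |     # test was produced, 'p_value'.
489 |     print(result)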
--------------------------------------------------------------------------------