├── popper ├── llm │ ├── __init__.py │ ├── utils.py │ ├── prompt_utils.py │ └── custom_model.py ├── __init__.py ├── version.py ├── react_agent.py ├── benchmark.py ├── popper.py ├── react_utils.py ├── utils.py └── prompt_utils.py ├── baseline_agents ├── __init__.py ├── utils │ └── dv_log.py ├── react_agent.py ├── coder_agent.py ├── coder_utils.py ├── react_utils.py └── self_refine_agent.py ├── .pre-commit-config.yaml ├── benchmark_scripts ├── __init__.py ├── run_discoverybench.sh ├── run_discoverybench_docker.sh ├── run_discovery_bench_baseline.py ├── run_targetval.sh ├── run_targetval_baseline.py ├── run_targetval_benchmark.py └── run_discovery_bench.py ├── MANIFEST.in ├── setup.cfg ├── figs └── popper_agent_illustration.png ├── requirements.txt ├── Dockerfile ├── setup.py ├── .gitignore └── README.md /popper/llm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: [] -------------------------------------------------------------------------------- /benchmark_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /popper/__init__.py: -------------------------------------------------------------------------------- 1 | from .popper import Popper -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /figs/popper_agent_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/POPPER/HEAD/figs/popper_agent_illustration.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | tqdm 3 | pandas 4 | gradio 5 | pydantic==2.9.2 6 | openai==1.61.0 7 | langchain_core==0.3.22 8 | langchain_openai==0.2.11 9 | langchain_anthropic==0.2.3 10 | numpy==1.26.4 11 | scipy==1.13.1 12 | langchain==0.3.7 13 | langgraph==0.2.39 14 | langchain_experimental==0.3.3 15 | ipython -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | WORKDIR /app 4 | 5 | COPY requirements.txt . 
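# Copying requirements.txt by itself before the source tree lets Docker cache the pip-install layer, so code-only changes do not re-download dependencies.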
6 | 7 | RUN pip install --no-cache-dir -r requirements.txt 8 | 9 | COPY benchmark_scripts /app/benchmark_scripts 10 | 11 | COPY popper /app/popper 12 | 13 | COPY baseline_agents /app/baseline_agents 14 | 15 | # Create directories for storing output and logs 16 | RUN mkdir -p /app/data /app/.logs && chmod -R 777 /app/data /app/.logs 17 | 18 | RUN useradd -m nonrootuser 19 | USER nonrootuser 20 | 21 | ENTRYPOINT ["python"] 22 | -------------------------------------------------------------------------------- /benchmark_scripts/run_discoverybench.sh: -------------------------------------------------------------------------------- 1 | python benchmark_scripts/run_discovery_bench.py \ 2 | --exp_name discovery_bench --num_tests 5 --samples 300 --permute --e_value --react --relevance_checker \ 3 | --path /dfs/scratch0/lansong/discoverybench &> .logs/discovery_bench.log 4 | 5 | python benchmark_scripts/run_discovery_bench_baseline.py \ 6 | --exp_name discovery_bench_baseline_react --agent_type react --samples 50 --log_file .logs/discovery_bench_baseline_react.log --permute \ 7 | --path /dfs/scratch0/lansong/discoverybench &> .logs/discovery_bench_baseline_react_stdout.log 8 | 9 | -------------------------------------------------------------------------------- /popper/version.py: -------------------------------------------------------------------------------- 1 | 2 | # Based on NiLearn package 3 | # License: simplified BSD 4 | 5 | # PEP0440 compatible formatted version, see: 6 | # https://www.python.org/dev/peps/pep-0440/ 7 | # 8 | # Generic release markers: 9 | # X.Y 10 | # X.Y.Z # For bug fix releases 11 | # 12 | # Admissible pre-release markers: 13 | # X.YaN # Alpha release 14 | # X.YbN # Beta release 15 | # X.YrcN # Release Candidate 16 | # X.Y # Final release 17 | # 18 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 19 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev' 20 | # 21 | __version__ = '0.0.5' # pragma: no cover 22 | -------------------------------------------------------------------------------- /benchmark_scripts/run_discoverybench_docker.sh: -------------------------------------------------------------------------------- 1 | # docker build --no-cache -t discovery_benchmark . 2 | docker build -t discovery_benchmark . 
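# The .env file passed to docker run below via --env-file is expected to supply the model API keys the agents read from the environment (e.g. ANTHROPIC_API_KEY, OPENAI_API_KEY).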
3 | 4 | 5 | docker run \ 6 | -v /dfs/scratch0/lansong/discoverybench:/dfs/scratch0/lansong/discoverybench \ 7 | -v /dfs/scratch0/lansong/data:/app/data \ 8 | --name discovery_bench \ 9 | --env-file .env \ 10 | discovery_benchmark benchmark_scripts/run_discovery_bench.py \ 11 | --exp_name discovery_bench_v3 --num_tests 5 --samples 100 --permute --e_value --react --relevance_checker \ 12 | --path /dfs/scratch0/lansong/discoverybench &> .logs/discovery_bench_evalue_react_v3.log 13 | 14 | 15 | docker wait discovery_bench 16 | 17 | docker rm discovery_bench 18 | -------------------------------------------------------------------------------- /baseline_agents/utils/dv_log.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | 5 | class JSONFormatter(logging.Formatter): 6 | def format(self, record): 7 | log_data = { 8 | 'timestamp': self.formatTime(record), 9 | 'level': record.levelname, 10 | 'logger_name': record.name, 11 | 'module': record.module, 12 | 'function': record.funcName, 13 | 'message': record.getMessage() 14 | } 15 | return json.dumps(log_data) 16 | 17 | 18 | class DVLogger(): 19 | def __init__(self, logger_name, log_filename): 20 | self.file_handler = logging.FileHandler(log_filename) 21 | self.file_handler.setLevel(logging.INFO) 22 | self.json_formatter = JSONFormatter() 23 | self.file_handler.setFormatter(self.json_formatter) 24 | self.logger = logging.getLogger(logger_name) 25 | self.logger.setLevel(logging.INFO) 26 | self.logger.addHandler(self.file_handler) 27 | 28 | # To log string 29 | def log(self, message): 30 | self.logger.info(message) 31 | 32 | # To log dictionary (agent reponse) 33 | def log_json(self, message): 34 | self.logger.info(json.dumps(message)) 35 | 36 | def close(self): 37 | self.file_handler.close() 38 | self.logger.removeHandler(self.file_handler) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # read the contents of README file 4 | from os import path 5 | from io import open # for Python 2 and 3 compatibility 6 | 7 | # get __version__ from _version.py 8 | ver_file = path.join('popper', 'version.py') 9 | with open(ver_file) as f: 10 | exec(f.read()) 11 | 12 | this_directory = path.abspath(path.dirname(__file__)) 13 | 14 | 15 | # read the contents of README.md 16 | def readme(): 17 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 18 | return f.read() 19 | 20 | 21 | # read the contents of requirements.txt 22 | with open(path.join(this_directory, 'requirements.txt'), 23 | encoding='utf-8') as f: 24 | requirements = f.read().splitlines() 25 | 26 | setup(name='popper_agent', 27 | version=__version__, 28 | license='MIT', 29 | description='POPPER', 30 | long_description=readme(), 31 | long_description_content_type='text/markdown', 32 | url='https://github.com/snap-stanford/POPPER', 33 | author='POPPER Team', 34 | author_email='kexinh@cs.stanford.edu', 35 | packages=find_packages(exclude=['test']), 36 | zip_safe=False, 37 | include_package_data=True, 38 | install_requires=requirements, 39 | setup_requires=['setuptools>=38.6.0'] 40 | ) 41 | -------------------------------------------------------------------------------- /popper/llm/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import traceback 3 | import logging 4 | import 
uuid 5 | import re 6 | 7 | def clean_json_string(json_string: str) -> str: 8 | # Remove trailing commas before closing braces/brackets in JSON structures 9 | cleaned_json_string = re.sub(r',\s*(\}|\])', r'\1', json_string) 10 | return cleaned_json_string 11 | 12 | def parse_llm_output(llm_output: str): 13 | # Regular expressions to match the JSON structures 14 | tool_calls_pattern = r'\{\s*"type":\s*"tool_calls",\s*"content":\s*\[.*?\]\s*\}' 15 | text_message_pattern = r'\{\s*"type":\s*"text_message",\s*"content":\s*".*?"\s*\}' 16 | 17 | # Initial text that appears before the JSON 18 | final_text_before_json = "" 19 | 20 | # Try to find tool calls JSON 21 | tool_calls_match = re.search(tool_calls_pattern, llm_output, re.DOTALL) 22 | if tool_calls_match: 23 | tool_calls_json = tool_calls_match.group(0) 24 | final_text_before_json = llm_output[:tool_calls_match.start()] 25 | # Clean up the JSON string (remove trailing commas) 26 | tool_calls_json = clean_json_string(tool_calls_json) 27 | print(tool_calls_json) 28 | return final_text_before_json, json.loads(tool_calls_json) 29 | 30 | # Try to find text message JSON 31 | text_message_match = re.search(text_message_pattern, llm_output) 32 | if text_message_match: 33 | text_message_json = text_message_match.group(0) 34 | final_text_before_json = llm_output[:text_message_match.start()] 35 | return final_text_before_json, json.loads(text_message_json) 36 | 37 | # If neither pattern is found, treat the entire output as text message 38 | return "", { 39 | "type": "text_message", 40 | "content": llm_output, 41 | } -------------------------------------------------------------------------------- /popper/llm/prompt_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def bind_tools_to_system_prompt(system_prompt, tools, tool_choice = None): 4 | if not tool_choice or tool_choice == 'none': 5 | return f'''You are an intelligent agent capable of calling tools to complete user-assigned tasks. 6 | Here are the instructions specified by the user: 7 | """{system_prompt}""" 8 | 9 | In addition, you have access to the following tools: 10 | {json.dumps(tools, indent=4)} 11 | 12 | You may output any intermediate thoughts or reasonings before delivering your final response. 13 | Your final response must either be at least one tool call or a response message to the user. 14 | 15 | To make one or more tool calls, wrap your final response in the following JSON format: 16 | {{ 17 | "type": "tool_calls", 18 | "content": [ 19 | {{ 20 | "name": "name of the function to call", 21 | "id": "an unique id for this tool call", 22 | "arguments": {{ 23 | "argument1": value1, 24 | "argument2": value2, 25 | ... 26 | }} 27 | }}, 28 | ... 29 | ] 30 | }} 31 | 32 | To send a direct response message to the user, wrap your final response in the following JSON format: 33 | {{ 34 | "type": "text_message", 35 | "content": "content of the message according to the user instructions" 36 | }} 37 | 38 | You must choose either to send tool calls or a direct response message. Be sure to format the final response properly according to the given JSON specs. 39 | 40 | DO NOT put anything after the final response JSON object. 41 | ''' 42 | 43 | # tool choice! 44 | system_prompt = f'''You are an intelligent agent capable of calling tools to complete user-assigned tasks. 
45 | Here are the instructions specified by the user: 46 | """{system_prompt}""" 47 | 48 | In addition, you have access to the following tools: 49 | {json.dumps(tools, indent=4)} 50 | 51 | You may output any intermediate thoughts or reasonings before delivering your final response. 52 | Your final response MUST BE one or more tool calls. 53 | 54 | To make one or more tool calls, wrap your final response in the following JSON format: 55 | {{ 56 | "type": "tool_calls", 57 | "content": [ 58 | {{ 59 | "name": "name of the function to call", 60 | "id": "an unique id for this tool call", 61 | "arguments": {{ 62 | "argument1": value1, 63 | "argument2": value2, 64 | ... 65 | }} 66 | }}, 67 | ... 68 | ] 69 | }} 70 | 71 | 72 | You MUST wrap your response as a tool call formatted in the above JSON schema. 73 | 74 | DO NOT put anything after the tool call. 75 | ''' 76 | return system_prompt -------------------------------------------------------------------------------- /benchmark_scripts/run_discovery_bench_baseline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append('../') 4 | 5 | from baseline_agents.coder_agent import BaseAgent 6 | from baseline_agents.react_agent import ReactAgent 7 | from baseline_agents.self_refine_agent import SelfRefineAgent 8 | from baseline_agents.coder_utils import load_data_to_coder_globals 9 | from baseline_agents.react_utils import load_data_to_react_globals 10 | from popper.benchmark import discovery_bench_hypothesis 11 | 12 | from sklearn.metrics import accuracy_score, average_precision_score 13 | 14 | from tqdm import tqdm 15 | import argparse 16 | import traceback 17 | import time 18 | import json 19 | 20 | argparser = argparse.ArgumentParser() 21 | argparser.add_argument("--exp_name", type=str, default="discovery_bench_baseline") 22 | argparser.add_argument("--model_name", type=str, default="claude-3-5-sonnet") 23 | argparser.add_argument("--agent_type", type=str, choices=['coder', 'react', 'self_refine'], default="coder") 24 | argparser.add_argument('--samples', type=int, default=50) 25 | argparser.add_argument('--starts_from', type=int, default=0) 26 | argparser.add_argument("--log_file", type=str, default=".logs/baseline_log.log") 27 | argparser.add_argument('--permute', action='store_true', default=False) 28 | argparser.add_argument('--path', type=str, required = True) 29 | 30 | args = argparser.parse_args() 31 | 32 | exp_name = args.exp_name 33 | exp_name += f"_{args.agent_type}" 34 | if args.permute: 35 | # data_loader.permute_selected_columns() 36 | exp_name+='_permuted' 37 | 38 | samples = args.samples 39 | bm = discovery_bench_hypothesis(num_samples = samples, path = args.path) 40 | predictions = [] 41 | targets = [] 42 | 43 | agent = None 44 | load_to_globals = None 45 | if args.agent_type == 'coder' or args.agent_type == 'react': 46 | load_to_globals = load_data_to_coder_globals if args.agent_type == 'coder' else load_data_to_react_globals 47 | 48 | start = time.time() 49 | 50 | for i, example in tqdm(enumerate(bm.get_iterator()), total=samples, desc="Processing"): 51 | if i < args.starts_from: 52 | print(f"Skipping example {i}") 53 | continue 54 | 55 | if args.agent_type == 'coder': 56 | agent = BaseAgent( 57 | model_name=args.model_name, 58 | api_config="baseline_agents/config/api_config.json", 59 | model_config="baseline_agents/config/model_config.json", 60 | log_file=args.log_file 61 | ) 62 | elif args.agent_type == 'react': 63 | agent = ReactAgent( 64 | 
model_name=args.model_name, 65 | api_config="baseline_agents/config/api_config.json", 66 | model_config="baseline_agents/config/model_config.json", 67 | log_file=args.log_file 68 | ) 69 | elif args.agent_type == 'self_refine': 70 | agent = SelfRefineAgent( 71 | llm="claude-3-5-sonnet-20241022" 72 | ) 73 | else: 74 | raise ValueError(f"Agent {args.agent_type} not found") 75 | 76 | try: 77 | data_loader = example["data_loader"] 78 | if load_to_globals is not None: 79 | load_to_globals(data_loader) 80 | 81 | query = example['prompt'] 82 | output = agent.generate(data_loader=data_loader, query=query) 83 | if output is not None: 84 | predictions.append((0.0, output["output"])) 85 | else: 86 | predictions.append((0.0, False)) 87 | targets.append(example['answer']) 88 | except: 89 | print(traceback.format_exc()) 90 | 91 | 92 | print("------------------------------------") 93 | 94 | end = time.time() 95 | print(f"Total elapsed time: {end - start}") 96 | 97 | print("Benchmark Results:") 98 | print(f"Predictions: {predictions}") 99 | print(f"Targets: {targets}") 100 | eval_results = bm.evaluate(predictions, targets) 101 | 102 | for metric in eval_results: 103 | print(f"{metric}: {eval_results[metric]}") -------------------------------------------------------------------------------- /benchmark_scripts/run_targetval.sh: -------------------------------------------------------------------------------- 1 | ## TargetVal benchmarks 2 | data_path="/dfs/user/kexinh/popper_data_processed" 3 | for seed in 1 2 3 4 5 4 | do 5 | datasets=("IFNG" "IL2") 6 | for dataset in "${datasets[@]}" 7 | do 8 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path 9 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path 10 | done 11 | done 12 | 13 | ### For Popper-codegen: 14 | python run_targetval_benchmark.py --e_value --relevance_checker --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path 15 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path 16 | 17 | 18 | ### For Popper-NoRelevanceChecker: 19 | python run_targetval_benchmark.py --e_value --react --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path 20 | python run_targetval_benchmark.py --e_value --permute --react --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path 21 | 22 | ### For Popper-Fisher Combined Test: 23 | python run_targetval_benchmark.py --relevance_checker --react --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path 24 | python run_targetval_benchmark.py --relevance_checker --permute --react --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path 25 | 26 | ### For Popper-LLM estiamted likelihood ratio: 27 | python run_targetval_benchmark.py --llm_approx --relevance_checker --seed $seed --use_full_data --samples 20 --dataset $dataset --path $data_path 28 | python run_targetval_benchmark.py --llm_approx --relevance_checker --permute --seed $seed --use_full_data --samples 50 --dataset $dataset --path $data_path 29 | 30 | ### For other LLMs: 31 | # GPT-4o 32 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --model gpt-4o --samples 20 --dataset $dataset --path $data_path 33 | 
python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --model gpt-4o --samples 50 --dataset $dataset --path $data_path 34 | 35 | # o1 36 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --model o1-2024-12-17 --samples 20 --dataset $dataset --path $data_path 37 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --model o1-2024-12-17 --samples 50 --dataset $dataset --path $data_path 38 | 39 | # Haiku 40 | python run_targetval_benchmark.py --e_value --relevance_checker --react --seed $seed --use_full_data --model claude-3-5-haiku-20241022 --samples 20 --dataset $dataset --path $data_path 41 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --model claude-3-5-haiku-20241022 --samples 50 --dataset $dataset --path $data_path 42 | 43 | ### For user study type I error genes: 44 | python run_targetval_benchmark.py --e_value --relevance_checker --permute --react --seed $seed --use_full_data --user_study_neg_genes --dataset IL2 --path $data_path 45 | 46 | 47 | ### Baselines 48 | # Coder-o1 49 | python run_targetval_baseline.py --agent_type coder --model o1-2024-12-17 --permute --samples 50 --seed $seed --dataset IL2 --path $data_path 50 | python run_targetval_baseline.py --agent_type coder --model o1-2024-12-17 --samples 20 --seed $seed --dataset IL2 --path $data_path 51 | 52 | # Coder 53 | python run_targetval_baseline.py --agent_type coder --permute --samples 50 --seed $seed --dataset IL2 --path $data_path 54 | python run_targetval_baseline.py --agent_type coder --samples 20 --seed $seed --dataset IL2 --path $data_path 55 | 56 | # Self-refine 57 | python run_targetval_baseline.py --agent_type self_refine --use_full_data --permute --samples 50 --seed $seed --dataset IL2 --path $data_path 58 | python run_targetval_baseline.py --agent_type self_refine --use_full_data --samples 20 --seed $seed --dataset IL2 --path $data_path 59 | 60 | # React 61 | python run_targetval_baseline.py --use_other_claude_api --use_simple_template --agent_type react --permute --samples 50 --seed $seed --dataset IL2 --path $data_path 62 | python run_targetval_baseline.py --use_other_claude_api --use_simple_template --agent_type react --samples 20 --seed $seed --dataset IL2 --path $data_path -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | demo_test.ipynb 3 | .gradio/ 4 | test.ipynb 5 | *xlsx 6 | notebooks/figures/* 7 | notebooks/tmp_dir/* 8 | notebooks/tmp_directory/* 9 | notebooks/save_folder/* 10 | notebooks/cache 11 | notebooks/\[draft* 12 | notebooks/\[scratch* 13 | scripts/save_folder/* 14 | scripts/\[draft* 15 | scripts/tmp_dir/* 16 | scripts/tmp_directory/* 17 | bioagentos/agent/popper_agent.py 18 | bioagentos/tool/data_tool/old_data_tool/* 19 | examples/*.png 20 | examples/*.csv 21 | examples/*.fa 22 | scratch/ 23 | packages/ 24 | data 25 | data_lake 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | res/ 30 | examples/tmp* 31 | 32 | # C extensions 33 | *.so 34 | 35 | # Distribution / packaging 36 | .Python 37 | build/ 38 | develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | .eggs/ 43 | lib/ 44 | lib64/ 45 | parts/ 46 | sdist/ 47 | var/ 48 | wheels/ 49 | share/python-wheels/ 50 | *.egg-info/ 51 | 
.installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .nox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | *.py,cover 76 | .hypothesis/ 77 | .pytest_cache/ 78 | cover/ 79 | 80 | # Translations 81 | *.mo 82 | *.pot 83 | 84 | # Django stuff: 85 | *.log 86 | local_settings.py 87 | db.sqlite3 88 | db.sqlite3-journal 89 | 90 | # Flask stuff: 91 | instance/ 92 | .webassets-cache 93 | 94 | # Scrapy stuff: 95 | .scrapy 96 | 97 | # Sphinx documentation 98 | docs/_build/ 99 | 100 | # PyBuilder 101 | .pybuilder/ 102 | target/ 103 | 104 | # Jupyter Notebook 105 | .ipynb_checkpoints 106 | 107 | # IPython 108 | profile_default/ 109 | ipython_config.py 110 | 111 | # pyenv 112 | # For a library or package, you might want to ignore these files since the code is 113 | # intended to run in multiple environments; otherwise, check them in: 114 | # .python-version 115 | 116 | # pipenv 117 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 118 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 119 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 120 | # install all needed dependencies. 121 | #Pipfile.lock 122 | 123 | # poetry 124 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 125 | # This is especially recommended for binary packages to ensure reproducibility, and is more 126 | # commonly ignored for libraries. 127 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 128 | #poetry.lock 129 | 130 | # pdm 131 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 132 | #pdm.lock 133 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 134 | # in version control. 135 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 136 | .pdm.toml 137 | .pdm-python 138 | .pdm-build/ 139 | 140 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 141 | __pypackages__/ 142 | 143 | # Celery stuff 144 | celerybeat-schedule 145 | celerybeat.pid 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # Environments 151 | .env 152 | .venv 153 | env/ 154 | venv/ 155 | ENV/ 156 | env.bak/ 157 | venv.bak/ 158 | 159 | # Spyder project settings 160 | .spyderproject 161 | .spyproject 162 | 163 | # Rope project settings 164 | .ropeproject 165 | 166 | # mkdocs documentation 167 | /site 168 | 169 | # mypy 170 | .mypy_cache/ 171 | .dmypy.json 172 | dmypy.json 173 | 174 | # Pyre type checker 175 | .pyre/ 176 | 177 | # pytype static type analyzer 178 | .pytype/ 179 | 180 | # Cython debug symbols 181 | cython_debug/ 182 | 183 | # PyCharm 184 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 185 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 186 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 187 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 188 | #.idea/ 189 | -------------------------------------------------------------------------------- /benchmark_scripts/run_targetval_baseline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append('../') 4 | 5 | from baseline_agents.coder_agent import BaseAgent 6 | from baseline_agents.react_agent import ReactAgent 7 | from baseline_agents.self_refine_agent import SelfRefineAgent 8 | from baseline_agents.coder_utils import load_data_to_coder_globals 9 | from baseline_agents.react_utils import load_data_to_react_globals 10 | from popper.benchmark import gene_perturb_hypothesis 11 | from popper.utils import ExperimentalDataLoader 12 | from langchain_core.prompts import ChatPromptTemplate 13 | from popper.utils import get_llm 14 | from sklearn.metrics import accuracy_score, average_precision_score 15 | from pydantic import BaseModel, Field 16 | from typing import ( 17 | Optional, List, Tuple, Union, Literal, Dict, TypedDict, Annotated 18 | ) 19 | from tqdm import tqdm 20 | import argparse 21 | import traceback 22 | import time 23 | import json 24 | 25 | argparser = argparse.ArgumentParser() 26 | argparser.add_argument("--exp_name", type=str, default="gene_baseline") 27 | argparser.add_argument("--model_name", type=str, default="claude-3-5-sonnet") 28 | argparser.add_argument("--agent_type", type=str, choices=['coder', 'react', 'self_refine'], default="coder") 29 | argparser.add_argument('--samples', type=int, default=25) 30 | argparser.add_argument('--starts_from', type=int, default=0) 31 | argparser.add_argument("--log_file", type=str, default=".logs/baseline_log.log") 32 | argparser.add_argument('--permute', action='store_true', default=False) 33 | argparser.add_argument('--use_full_data', action='store_true', default=False) 34 | argparser.add_argument("--dataset", type=str, default="IL2") 35 | argparser.add_argument('--seed', type=int, default=-1) 36 | argparser.add_argument('--use_simple_template', action='store_true', default=False) 37 | argparser.add_argument('--path', type=str, required = True) 38 | 39 | args = argparser.parse_args() 40 | 41 | data_path = args.path 42 | if args.use_full_data: 43 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'all_bio') 44 | else: 45 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'default') 46 | 47 | exp_name = args.dataset 48 | exp_name += f"_{args.agent_type}" 49 | if args.permute: 50 | data_loader.permute_selected_columns() 51 | exp_name+='_permuted' 52 | 53 | if args.use_full_data: 54 | exp_name+='_full_data' 55 | 56 | if args.seed != -1: 57 | exp_name+=f'_seed_{args.seed}' 58 | 59 | if args.use_simple_template: 60 | exp_name+='_simple_template' 61 | 62 | samples = args.samples 63 | bm = gene_perturb_hypothesis(num_of_samples = samples, 64 | permuted=args.permute, dataset = args.dataset, path = args.path) 65 | predictions = [] 66 | targets = [] 67 | 68 | agent = None 69 | load_to_globals = None 70 | if args.agent_type == 'coder' or args.agent_type == 'react': 71 | load_to_globals = load_data_to_coder_globals if args.agent_type == 'coder' else load_data_to_react_globals 72 | 73 | start = time.time() 74 | for i, example in tqdm(enumerate(bm.get_iterator()), total=samples, desc="Processing"): 75 | if i < args.starts_from: 76 | print(f"Skipping example {i}") 77 | continue 78 | 79 | if args.agent_type == 
'coder': 80 | agent = BaseAgent( 81 | model_name=args.model_name, 82 | api_config="baseline_agents/config/api_config.json", 83 | model_config="baseline_agents/config/model_config.json", 84 | log_file=args.log_file, 85 | simple_prompt=args.use_simple_template 86 | ) 87 | elif args.agent_type == 'react': 88 | agent = ReactAgent( 89 | model_name=args.model_name, 90 | api_config="baseline_agents/config/api_config.json", 91 | model_config="baseline_agents/config/model_config.json", 92 | log_file=args.log_file, 93 | simple_prompt=args.use_simple_template 94 | ) 95 | elif args.agent_type == 'self_refine': 96 | agent = SelfRefineAgent( 97 | llm="claude-3-5-sonnet-20241022" 98 | ) 99 | else: 100 | raise ValueError(f"Agent {args.agent_type} not found") 101 | 102 | try: 103 | if load_to_globals is not None: 104 | load_to_globals(data_loader) 105 | 106 | query = example['prompt'] 107 | datasets = data_loader.data_desc 108 | if args.agent_type == 'coder': 109 | output = agent.generate(data_loader=data_loader, query=query) 110 | elif args.agent_type in ['react', 'self_refine']: 111 | output = agent.generate(data_loader=data_loader, query=query) 112 | 113 | print(output) 114 | if output is not None: 115 | predictions.append((0.0, output["output"], example['gene'])) 116 | else: 117 | predictions.append((0.0, False, example['gene'])) 118 | targets.append(example['binary_answer']) 119 | except: 120 | print(traceback.format_exc()) 121 | 122 | 123 | print("------------------------------------") 124 | 125 | end = time.time() 126 | print(f"Total elapsed time: {end - start}") 127 | 128 | print("Benchmark Results:") 129 | print(f"Predictions: {predictions}") 130 | print(f"Targets: {targets}") 131 | eval_results = bm.evaluate(predictions, targets) 132 | 133 | for metric in eval_results: 134 | print(f"{metric}: {eval_results[metric]}") 135 | 136 | 137 | import pickle 138 | import os 139 | os.makedirs(args.path + '/res', exist_ok=True) 140 | with open(args.path + '/res/' + exp_name + '_res_final.pkl', 'wb') as f: 141 | pickle.dump({'predictions': predictions, 'targets': targets, 'eval_results': eval_results}, f) -------------------------------------------------------------------------------- /benchmark_scripts/run_targetval_benchmark.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append('../') 4 | 5 | from popper.benchmark import gene_perturb_hypothesis 6 | from popper.agent import SequentialFalsificationTest 7 | from popper.utils import ExperimentalDataLoader 8 | 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | argparser = argparse.ArgumentParser() 13 | argparser.add_argument('--permute', action='store_true', default=False) 14 | argparser.add_argument('--samples', type=int, default=25) 15 | argparser.add_argument('--llm_approx', action='store_true', default=False) 16 | argparser.add_argument('--e_value', action='store_true', default=False) 17 | argparser.add_argument('--relevance_checker', action='store_true', default=False) 18 | argparser.add_argument('--react', action='store_true', default=False) 19 | argparser.add_argument('--use_full_data', action='store_true', default=False) 20 | argparser.add_argument('--max_num_of_tests', type=int, default=5) 21 | argparser.add_argument('--alpha', type=float, default=0.1) 22 | argparser.add_argument('--seed', type=int, default=-1) 23 | argparser.add_argument("--model", type=str, default="claude-3-5-sonnet-20241022") 24 | argparser.add_argument("--dataset", type=str, default="IL2") 25 | argparser.add_argument('--data_sampling', type=int, default=-1) 26 | 
argparser.add_argument('--user_study_neg_genes', action='store_true', default=False) 27 | argparser.add_argument('--is_locally_served', action='store_true', default=False) 28 | argparser.add_argument('--server_port', type=int, required=False) 29 | argparser.add_argument("--api_key", type=str, default="EMPTY") 30 | argparser.add_argument('--path', type=str, required = True) 31 | 32 | args = argparser.parse_args() 33 | 34 | data_path = args.path 35 | 36 | if args.data_sampling != -1: 37 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'all_bio', data_sampling = args.data_sampling) 38 | elif args.use_full_data: 39 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'all_bio') 40 | else: 41 | data_loader = ExperimentalDataLoader(data_path, table_dict_selection = 'default') 42 | 43 | exp_name = args.dataset 44 | if args.permute: 45 | data_loader.permute_selected_columns() 46 | exp_name+='_permuted' 47 | 48 | if args.max_num_of_tests != 5: 49 | exp_name+=f'_max_{args.max_num_of_tests}' 50 | 51 | if args.alpha != 0.1: 52 | exp_name+=f'_alpha_{args.alpha}' 53 | 54 | if args.llm_approx: 55 | exp_name+='_llm_approx' 56 | 57 | if args.use_full_data: 58 | exp_name+='_full_data' 59 | 60 | if args.react: 61 | exp_name+="_react_v2" 62 | #args.e_value = True 63 | 64 | if args.e_value: 65 | exp_name+='_e_value' 66 | 67 | if args.relevance_checker: 68 | exp_name+='_relevance_checker' 69 | 70 | if args.seed != -1: 71 | exp_name+=f'_seed_{args.seed}' 72 | 73 | if args.model[:6] != 'claude': 74 | exp_name+=f'_{args.model}' 75 | else: 76 | if 'haiku' in args.model: 77 | exp_name+='_haiku' 78 | 79 | if args.data_sampling != -1: 80 | exp_name+=f'_sampling_{args.data_sampling}' 81 | 82 | if args.user_study_neg_genes: 83 | exp_name+='_user_study_neg_genes' 84 | 85 | print('Experiment name: ' + exp_name) 86 | res = {} 87 | samples = args.samples 88 | bm = gene_perturb_hypothesis(num_of_samples = samples, permuted=args.permute, 89 | dataset = args.dataset, user_study_neg_genes= args.user_study_neg_genes, path = args.path) 90 | #response = [] 91 | for example in tqdm(bm.get_iterator(), total=samples, desc="Processing"): 92 | import traceback 93 | try: 94 | agent = SequentialFalsificationTest(llm = args.model, is_local=args.is_locally_served, port=args.server_port, api_key=args.api_key) 95 | if args.llm_approx: 96 | agent.configure(data = data_loader, 97 | alpha = args.alpha, beta = 0.1, 98 | aggregate_test = 'LLM_approx', 99 | max_num_of_tests = 5, 100 | max_retry = args.max_num_of_tests, time_limit = 2, 101 | llm_approx = True, 102 | relevance_checker = args.relevance_checker) 103 | else: 104 | if args.e_value: 105 | agent.configure(data = data_loader, alpha = args.alpha, 106 | beta = 0.1, aggregate_test = 'E-value', 107 | max_num_of_tests = args.max_num_of_tests, max_retry = 5, time_limit = 2, 108 | relevance_checker = args.relevance_checker, use_react_agent=args.react) 109 | else: 110 | agent.configure(data = data_loader, alpha = args.alpha, beta = 0.1, 111 | aggregate_test = 'Fisher', max_num_of_tests = args.max_num_of_tests, 112 | max_retry = args.max_num_of_tests, time_limit = 2, 113 | relevance_checker = args.relevance_checker, use_react_agent=args.react) 114 | 115 | log, last_message, parsed_result = agent.go(example['prompt']) 116 | res[example['gene']] = (log, last_message, parsed_result, agent.res_stat) 117 | except Exception as e: 118 | print(f"Error for prompt '{example['prompt']}': {e}") 119 | print(traceback.format_exc()) # Print the full traceback for 
debugging 120 | res[example['gene']] = ('Error', traceback.format_exc()) 121 | continue 122 | 123 | import pickle 124 | import os 125 | os.makedirs(args.path + '/res', exist_ok=True) 126 | with open(args.path + '/res/' + exp_name + '_res_final.pkl', 'wb') as f: 127 | pickle.dump(res, f) -------------------------------------------------------------------------------- /baseline_agents/react_agent.py: -------------------------------------------------------------------------------- 1 | from baseline_agents.react_utils import create_agent 2 | from langchain_anthropic import ChatAnthropic 3 | from langchain_together import Together 4 | from langchain_openai import ChatOpenAI 5 | from langchain_google_genai import ChatGoogleGenerativeAI 6 | import os 7 | import json 8 | import langchain 9 | from langchain_core.callbacks import FileCallbackHandler, StdOutCallbackHandler 10 | from baseline_agents.utils.dv_log import DVLogger 11 | import uuid 12 | import traceback 13 | 14 | 15 | # uncomment the following line to enable debug mode 16 | # langchain.debug = True 17 | 18 | def get_prompt_data( 19 | prompt_config: str = None 20 | ): 21 | if prompt_config is None and os.environ.get("PROMPT_CONFIG") is None: 22 | raise ValueError("PROMPT_CONFIG not set and prompt_config not provided") 23 | else: 24 | prompt_config = prompt_config or os.environ.get("PROMPT_CONFIG") 25 | 26 | with open(prompt_config, "r") as file: 27 | return json.load(file) 28 | 29 | 30 | class ReactAgent(): 31 | def __init__( 32 | self, 33 | model_config: str = None, 34 | api_config: str = None, 35 | model_name: str = "claude-3-5-sonnet", 36 | log_file: str = "base_react_agent.log", 37 | max_iterations: int = 25, 38 | simple_prompt: bool = False 39 | ): 40 | self.logfile = log_file 41 | self.logger = DVLogger(f"{model_name}_{uuid.uuid4()}", log_file) 42 | # logger.add(log_file, format="{time} {level} {message}", level="INFO") 43 | self.file_handler = FileCallbackHandler(self.logfile) 44 | self.stdout_handler = StdOutCallbackHandler() 45 | 46 | # set max iterations 47 | self.max_iterations = max_iterations 48 | 49 | # check if model config is provided 50 | if model_config is None and os.environ.get("MODEL_CONFIG") is None: 51 | raise ValueError("MODEL_CONFIG not set and model_config not provided") 52 | else: 53 | # override environment variable config path if model_config is provided 54 | model_config = model_config or os.environ.get("MODEL_CONFIG") 55 | 56 | # do a similar check for api_config 57 | if api_config is None and os.environ.get("API_CONFIG") is None: 58 | raise ValueError("API_CONFIG not set and api_config not provided") 59 | else: 60 | api_config = api_config or os.environ.get("API_CONFIG") 61 | 62 | 63 | # load model config 64 | self.model_name = model_name 65 | with open(model_config, "r") as file: 66 | self.model_config = json.load(file) 67 | 68 | # load api config 69 | with open(api_config, "r") as file: 70 | self.api_config = json.load(file) 71 | 72 | try: 73 | # get full model name and type 74 | self.full_model_name = self.model_config['models'][self.model_name]['model_name'] 75 | self.model_type = self.model_config['models'][self.model_name]['model_type'] 76 | except KeyError: 77 | raise ValueError(f"Model {model_name} not found in model config") 78 | 79 | try: 80 | # get api key using model type 81 | self.api_key = self.api_config[self.model_type] 82 | except KeyError: 83 | raise ValueError(f"API key not found for {self.model_type}") 84 | 85 | # get the model 86 | self.llm = self.get_model( 87 | api=self.model_type, 88 
| model=self.full_model_name, 89 | api_key=self.api_key 90 | ) 91 | 92 | # create agent 93 | self.agent = create_agent( 94 | llm=self.llm, 95 | handlers=[self.file_handler, self.stdout_handler], 96 | max_iterations=self.max_iterations, 97 | simple_template=simple_prompt 98 | ) 99 | 100 | def get_model( 101 | self, 102 | api, 103 | api_key, 104 | model, 105 | **kwargs 106 | ): 107 | llm = None 108 | if (api == "together"): 109 | llm = Together( 110 | model=model, 111 | together_api_key=api_key, 112 | **kwargs 113 | ) 114 | elif (api == "anthropic"): 115 | llm = ChatAnthropic( 116 | model=model, 117 | api_key=api_key, 118 | **kwargs 119 | ) 120 | elif (api == "openai"): 121 | llm = ChatOpenAI( 122 | model=model, 123 | api_key=api_key, 124 | **kwargs 125 | ) 126 | elif (api == "google"): 127 | llm = ChatGoogleGenerativeAI( 128 | model=model, 129 | google_api_key=api_key, 130 | **kwargs 131 | ) 132 | else: 133 | raise ValueError(f"Invalid API: {api}") 134 | return llm 135 | 136 | def generate(self, data_loader, query): 137 | try: 138 | table_dict = data_loader.table_dict 139 | dataset_desc = data_loader.data_desc 140 | self.agent.tools[0]._set_globals(table_dict) 141 | output = self.agent.invoke(input={ 142 | "system_prompt": "You are a scientific agent who can plan and execute python code multiple times to answer a query based on one or more datasets. All datasets have already been loaded into the global namespace as pandas dataframes.", 143 | "input": f"""Question: Is the following hypothesis true or false? {query} 144 | Datasets: {dataset_desc}""" 145 | }) 146 | self.logger.log_json(output) 147 | return output 148 | except Exception as e: 149 | print("Execution Stopped due to : ", e) 150 | print(traceback.format_exc()) 151 | self.logger.logger.error(f"Execution Stopped due to : {e}") 152 | self.logger.close() -------------------------------------------------------------------------------- /benchmark_scripts/run_discovery_bench.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.getcwd()) 4 | 5 | from popper.benchmark import discovery_bench_hypothesis 6 | from popper.agent import SequentialFalsificationTest 7 | from sklearn.metrics import accuracy_score, average_precision_score 8 | 9 | from tqdm import tqdm 10 | import argparse 11 | import traceback 12 | import time 13 | import json 14 | import pickle 15 | import numpy as np 16 | 17 | 18 | argparser = argparse.ArgumentParser() 19 | argparser.add_argument("--exp_name", type=str, default="discovery_bench") 20 | argparser.add_argument("--model", type=str, default="claude-3-5-sonnet-20241022") 21 | argparser.add_argument('--samples', type=int, default=50) 22 | argparser.add_argument('--num_tests', type=int, default=3) 23 | argparser.add_argument('--starts_from', type=int, default=0) 24 | argparser.add_argument('--permute', action='store_true', default=False) 25 | argparser.add_argument('--llm_approx', action='store_true', default=False) 26 | argparser.add_argument('--e_value', action='store_true', default=False) 27 | argparser.add_argument('--react', action='store_true', default=False) 28 | argparser.add_argument('--relevance_checker', action='store_true', default=False) 29 | argparser.add_argument('--is_locally_served', action='store_true', default=False) 30 | argparser.add_argument('--server_port', type=int, required=False) 31 | argparser.add_argument("--api_key", type=str, default="EMPTY") 32 | argparser.add_argument('--path', type=str, required = True) 33 | 
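# Example invocation with these flags (mirroring benchmark_scripts/run_discoverybench.sh): python benchmark_scripts/run_discovery_bench.py --exp_name discovery_bench --num_tests 5 --samples 300 --permute --e_value --react --relevance_checker --path <path-to-discoverybench-data>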
34 | args = argparser.parse_args() 35 | 36 | 37 | exp_name = args.exp_name 38 | 39 | exp_name += args.model 40 | 41 | exp_name += f'_{args.num_tests}tests' 42 | 43 | if args.permute: 44 | # data_loader.permute_selected_columns() 45 | exp_name+='_permuted' 46 | 47 | if args.llm_approx: 48 | exp_name+='_llm_approx' 49 | 50 | if args.e_value: 51 | exp_name+='_e_value' 52 | 53 | if args.react: 54 | exp_name+="_react" 55 | 56 | if args.relevance_checker: 57 | exp_name+="_relevance_checker" 58 | 59 | print(f"Running {exp_name}") 60 | 61 | res = [] 62 | samples = args.samples 63 | bm = discovery_bench_hypothesis(num_samples = samples, path=args.path) 64 | predictions = [] 65 | targets = [] 66 | 67 | start = time.time() 68 | 69 | #response = [] 70 | for i, example in tqdm(enumerate(bm.get_iterator()), total=samples, desc="Processing"): 71 | if i < args.starts_from: 72 | print(f"Skipping example {i}") 73 | continue 74 | try: 75 | data_loader = example["data_loader"] 76 | 77 | permuted = "not permuted" if example["answer"] else "permuted" 78 | print("======================================================") 79 | print(f'Processing {(example["task"], example["metadataid"], example["query_id"], permuted)}') 80 | for name, df in data_loader.table_dict.items(): 81 | print(name) 82 | print(df.head()) 83 | print("------------------------------------") 84 | 85 | agent = SequentialFalsificationTest(llm = args.model, is_local=args.is_locally_served, port=args.server_port, api_key=args.api_key) 86 | if args.llm_approx: 87 | agent.configure(data = data_loader, alpha = 0.1, beta = 0.1, aggregate_test = 'LLM_approx', max_num_of_tests = args.num_tests, max_retry = 5, time_limit = 2, llm_approx = True, domain=example['domain'], relevance_checker=args.relevance_checker) 88 | else: 89 | if args.e_value: 90 | agent.configure(data = data_loader, alpha = 0.1, beta = 0.1, aggregate_test = 'E-value', max_num_of_tests = args.num_tests, max_retry = 5, time_limit = 2, domain=example['domain'], relevance_checker=args.relevance_checker, use_react_agent=args.react, max_failed_tests=args.num_tests) 91 | else: 92 | agent.configure(data = data_loader, alpha = 0.1, beta = 0.1, aggregate_test = 'Fisher', max_num_of_tests = args.num_tests, max_retry = 5, time_limit = 2, domain=example['domain'], relevance_checker=args.relevance_checker, max_failed_tests=args.num_tests) 93 | 94 | log, last_message, parsed_result = agent.go(example['prompt']) 95 | predictions.append((agent.res_stat, agent.res)) 96 | targets.append(example['answer']) 97 | res.append({ 98 | "task": example["task"], 99 | "metadataid": example["metadataid"], 100 | "query_id": example["query_id"], 101 | "log": log, 102 | "last_message": last_message, 103 | "parsed_result": parsed_result, 104 | "res": agent.res, 105 | "res_stat": agent.res_stat, 106 | "answer": example["answer"] 107 | }) 108 | # res[(example["task"], example["metadataid"], example["query_id"])] = (log, last_message, parsed_result, agent.res, agent.res_stat, example['answer']) 109 | except Exception as e: 110 | print(f"Error for prompt '{example['prompt']}': {e}") 111 | print(traceback.format_exc()) # Print the full traceback for debugging 112 | res.append({ 113 | "task": example["task"], 114 | "metadataid": example["metadataid"], 115 | "query_id": example["query_id"], 116 | "error": traceback.format_exc() 117 | }) 118 | # res[(example["task"], example["metadataid"], example["query_id"])] = ('Error', traceback.format_exc()) 119 | # predictions.append((0.0, False)) 120 | continue 121 | 122 | output_path = 
os.path.join(os.getcwd(), 'data/' + exp_name + '.json') 123 | 124 | # with open(os.path.join(os.getcwd(), 'data/' + exp_name + '.pkl'), 'wb') as f: 125 | with open(output_path, 'w') as f: 126 | # pickle.dump(res, f) 127 | json.dump(res, f, indent=4) 128 | 129 | print(f"Results saved to {output_path}") 130 | print("------------------------------------") 131 | print("------------------------------------") 132 | 133 | end = time.time() 134 | print(f"Total elapsed time: {end - start}") 135 | 136 | print("Benchmark Results:") 137 | print(f"Predictions: {predictions}") 138 | print(f"Targets: {targets}") 139 | eval_results = bm.evaluate(predictions, targets) 140 | 141 | for metric in eval_results: 142 | print(f"{metric}: {eval_results[metric]}") -------------------------------------------------------------------------------- /baseline_agents/coder_agent.py: -------------------------------------------------------------------------------- 1 | from baseline_agents.coder_utils import create_agent 2 | from baseline_agents.utils.dv_log import DVLogger 3 | from langchain_anthropic import ChatAnthropic 4 | from langchain_together import Together 5 | from langchain_openai import ChatOpenAI 6 | from langchain_google_genai import ChatGoogleGenerativeAI 7 | from langchain_core.callbacks import FileCallbackHandler, StdOutCallbackHandler 8 | 9 | import os 10 | import json 11 | import uuid 12 | import traceback 13 | 14 | # uncomment the following line to enable debug mode 15 | # langchain.debug = True 16 | 17 | 18 | def get_prompt_data( 19 | prompt_config: str = None 20 | ): 21 | if prompt_config is None and os.environ.get("PROMPT_CONFIG") is None: 22 | raise ValueError("PROMPT_CONFIG not set and prompt_config not provided") 23 | else: 24 | prompt_config = prompt_config or os.environ.get("PROMPT_CONFIG") 25 | 26 | with open(prompt_config, "r") as file: 27 | return json.load(file) 28 | 29 | 30 | class BaseAgent(): 31 | def __init__( 32 | self, 33 | model_config: str = None, 34 | api_config: str = None, 35 | model_name: str = "claude-3-5-sonnet", 36 | log_file: str = "base_coder_agent.log", 37 | max_iterations: int = 5, 38 | simple_prompt: bool = False 39 | ): 40 | self.simple_prompt = simple_prompt 41 | self.logfile = log_file 42 | # logger.add(log_file, format="{time} {level} {message}", level="INFO") 43 | self.logger = DVLogger(f"{model_name}_{uuid.uuid4()}", log_file) 44 | self.file_handler = FileCallbackHandler(self.logfile) 45 | self.stdout_handler = StdOutCallbackHandler() 46 | self.max_iterations = max_iterations 47 | 48 | # check if model config is provided 49 | if model_config is None and os.environ.get("MODEL_CONFIG") is None: 50 | raise ValueError("MODEL_CONFIG not set and model_config not provided") 51 | else: 52 | # override environment variable config path if model_config is provided 53 | model_config = model_config or os.environ.get("MODEL_CONFIG") 54 | 55 | # do a similar check for api_config 56 | if api_config is None and os.environ.get("API_CONFIG") is None: 57 | raise ValueError("API_CONFIG not set and api_config not provided") 58 | else: 59 | api_config = api_config or os.environ.get("API_CONFIG") 60 | 61 | # load model config 62 | self.model_name = model_name 63 | with open(model_config, "r") as file: 64 | self.model_config = json.load(file) 65 | 66 | # load api config 67 | with open(api_config, "r") as file: 68 | self.api_config = json.load(file) 69 | 70 | # get full model name and type 71 | try: 72 | self.full_model_name = self.model_config['models'][self.model_name]['model_name'] 73 | 
self.model_type = self.model_config['models'][self.model_name]['model_type'] 74 | except KeyError: 75 | raise ValueError(f"Model {model_name} not found in model config") 76 | 77 | try: 78 | # get api key using model type 79 | self.api_key = self.api_config[self.model_type] 80 | if self.api_key == "": 81 | raise ValueError(f"API key for {self.model_type} in api config is empty") 82 | except KeyError: 83 | raise ValueError(f"API key for {self.model_type} not found in api config") 84 | 85 | # get the model 86 | self.llm = self.get_model( 87 | api=self.model_type, 88 | model=self.full_model_name, 89 | api_key=self.api_key 90 | ) 91 | 92 | # create agent 93 | self.agent = create_agent( 94 | llm=self.llm, 95 | handlers=[self.file_handler, self.stdout_handler], 96 | max_iterations=self.max_iterations, 97 | simple_template=self.simple_prompt 98 | ) 99 | 100 | def get_model( 101 | self, 102 | api, 103 | api_key, 104 | model, 105 | **kwargs 106 | ): 107 | llm = None 108 | if (api == "together"): 109 | llm = Together( 110 | model=model, 111 | together_api_key=api_key, 112 | **kwargs 113 | ) 114 | elif (api == "anthropic"): 115 | llm = ChatAnthropic( 116 | model=model, 117 | api_key=api_key, 118 | **kwargs 119 | ) 120 | elif (api == "openai"): 121 | llm = ChatOpenAI( 122 | model=model, 123 | api_key=api_key, 124 | **kwargs 125 | ) 126 | elif (api == "google"): 127 | llm = ChatGoogleGenerativeAI( 128 | model=model, 129 | google_api_key=api_key, 130 | **kwargs 131 | ) 132 | else: 133 | raise ValueError(f"Invalid API: {api}") 134 | return llm 135 | 136 | def generate(self, data_loader, query): 137 | dataset_desc = data_loader.data_desc 138 | 139 | if self.simple_prompt: 140 | system_prompt = "You are a scientific agent who can execute python codes to answer a query based on one or more datasets. All datasets have already been loaded into the global namespace as pandas dataframes." 141 | else: 142 | system_prompt = "You are a scientific agent who can execute a python code only once to answer a query based on one or more datasets. All datasets have already been loaded into the global namespace as pandas dataframes." 143 | 144 | try: 145 | output = self.agent.invoke(input={ 146 | "system_prompt": system_prompt, 147 | "input": f"""Question: Is the following hypothesis true or false? 
{query} 148 | Datasets: {dataset_desc}""" 149 | }) 150 | self.logger.log_json(output) 151 | return output 152 | except Exception as e: 153 | print("Execution Stopped due to : ", e) 154 | print(traceback.format_exc()) 155 | self.logger.logger.error(f"Execution Stopped due to : {e}") 156 | self.logger.close() 157 | -------------------------------------------------------------------------------- /popper/react_agent.py: -------------------------------------------------------------------------------- 1 | from popper.react_utils import create_agent 2 | from popper.prompt_utils import get_react_coding_agent_system_prompt 3 | from popper.llm.custom_model import CustomChatModel 4 | from langchain_anthropic import ChatAnthropic 5 | from langchain_openai import ChatOpenAI 6 | import os 7 | import json 8 | import langchain 9 | import openai 10 | from langchain_core.callbacks import FileCallbackHandler, StdOutCallbackHandler 11 | import logging 12 | import uuid 13 | import traceback 14 | import io 15 | import contextlib 16 | import sys 17 | import re 18 | 19 | # uncomment the following line to enable debug mode 20 | # langchain.debug = True 21 | 22 | class LiveLogger: 23 | """Custom stdout handler that logs in real-time while also printing output.""" 24 | def __init__(self, log): 25 | self.original_stdout = sys.stdout # Store original stdout 26 | self.log = log # Log dictionary 27 | self.current_buffer = [] # Store intermediate logs 28 | 29 | def clean_message(self, message): 30 | """Remove ANSI escape codes and filter out unnecessary logs.""" 31 | # Remove ANSI escape codes (color formatting) 32 | message = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', message) 33 | 34 | # Filter out specific unwanted messages 35 | unwanted_logs = [ 36 | "> Entering new AgentExecutor chain...", 37 | "> Finished chain." 
38 | ] 39 | if any(unwanted in message for unwanted in unwanted_logs): 40 | return None # Skip logging this message 41 | 42 | return message.strip() if message.strip() else None 43 | 44 | def write(self, message): 45 | cleaned_message = self.clean_message(message) 46 | if cleaned_message: 47 | self.original_stdout.write(cleaned_message + "\n") # Print to console 48 | self.original_stdout.flush() # Ensure immediate output 49 | 50 | # Append each new log update separately in Markdown format 51 | self.current_buffer.append(cleaned_message) 52 | self.log['executor'].append(f"```\n{cleaned_message}\n```") # Markdown formatting 53 | 54 | def flush(self): 55 | self.original_stdout.flush() 56 | 57 | 58 | def get_prompt_data( 59 | prompt_config: str = None 60 | ): 61 | if prompt_config is None and os.environ.get("PROMPT_CONFIG") is None: 62 | raise ValueError("PROMPT_CONFIG not set and prompt_config not provided") 63 | else: 64 | prompt_config = prompt_config or os.environ.get("PROMPT_CONFIG") 65 | 66 | with open(prompt_config, "r") as file: 67 | return json.load(file) 68 | 69 | 70 | class ReactAgent(): 71 | def __init__( 72 | self, 73 | model_name: str = "claude-3-5-sonnet-20241022", 74 | model_config: str = None, 75 | api_config: str = None, 76 | max_iterations: int = 25, 77 | prompt_revision: bool = False, 78 | port=None, 79 | api_key="EMPTY", 80 | ): 81 | self.prompt_revision = prompt_revision 82 | self.api = "custom" 83 | if model_name[:7] == 'claude-': 84 | self.api = 'anthropic' 85 | elif model_name[:4] == 'gpt-': 86 | self.api = 'openai' 87 | else: 88 | self.api = 'local' 89 | 90 | # logger.add(log_file, format="{time} {level} {message}", level="INFO") 91 | self.stdout_handler = StdOutCallbackHandler() 92 | 93 | # set max iterations 94 | self.max_iterations = max_iterations 95 | 96 | # load model config 97 | self.model_name = model_name 98 | 99 | # get the model 100 | self.llm = self.get_model( 101 | api=self.api, 102 | model=self.model_name, 103 | port=port, 104 | api_key=api_key 105 | ) 106 | 107 | # create agent 108 | self.agent = create_agent( 109 | llm=self.llm, 110 | handlers=[self.stdout_handler], 111 | max_iterations=self.max_iterations 112 | ) 113 | 114 | def get_model( 115 | self, 116 | api, 117 | model, 118 | port=None, 119 | api_key=None, 120 | **kwargs 121 | ): 122 | llm = None 123 | if (api == "anthropic"): 124 | llm = ChatAnthropic( 125 | model=model, 126 | api_key=os.environ["ANTHROPIC_API_KEY"], 127 | **kwargs 128 | ) 129 | elif (api == "openai"): 130 | llm = ChatOpenAI( 131 | model=model, 132 | api_key=os.environ["OPENAI_API_KEY"], 133 | **kwargs 134 | ) 135 | # elif (api == 'llama'): 136 | # llm = CustomChatModel( 137 | # model=model, 138 | # model_type='custom', 139 | # **kwargs 140 | # ) 141 | # llm.client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY").chat.completions 142 | else: 143 | # Llama or other locally-served models 144 | assert port is not None, "Port must be specified for local models" 145 | llm = CustomChatModel( 146 | model=model, 147 | model_type='custom', 148 | **kwargs 149 | ) 150 | api_key = "EMPTY" if api_key is None else api_key 151 | llm.client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key=api_key).chat.completions 152 | return llm 153 | 154 | def generate(self, data_loader, test_spec, domain, log=None): 155 | try: 156 | self.agent.tools[0]._set_globals(data_loader.table_dict) 157 | dataset_desc = data_loader.data_desc 158 | 159 | # Use LiveLogger only if a log is provided 160 | logger = LiveLogger(log) if 
log is not None else sys.stdout 161 | 162 | # Redirect stdout to capture real-time logs 163 | sys.stdout = logger 164 | try: 165 | output = self.agent.invoke(input={ 166 | "system_prompt": get_react_coding_agent_system_prompt(domain=domain, prompt_revision=self.prompt_revision), 167 | "input": f"""Falsification Test: {test_spec} 168 | Datasets: {dataset_desc} 169 | Thought:""" 170 | }) 171 | finally: 172 | sys.stdout = logger.original_stdout # Restore stdout 173 | 174 | return output['output'] 175 | 176 | except Exception as e: 177 | error_message = f"Execution Stopped due to: {e}\n{traceback.format_exc()}" 178 | print(error_message) 179 | if log is not None: 180 | log['executor'].append(f"```\n{error_message}\n```") # Markdown format 181 | return None 182 | ''' 183 | def generate(self, data_loader, test_spec, domain, log = None): 184 | try: 185 | self.agent.tools[0]._set_globals(data_loader.table_dict) 186 | dataset_desc = data_loader.data_desc 187 | output = self.agent.invoke(input={ 188 | "system_prompt": get_react_coding_agent_system_prompt(domain=domain, prompt_revision=self.prompt_revision), 189 | "input": f"""Falsification Test: {test_spec} 190 | Datasets: {dataset_desc} 191 | Thought:""" 192 | }) 193 | return output['output'] 194 | except Exception as e: 195 | print("Execution Stopped due to : ", e) 196 | print(traceback.format_exc()) 197 | return None 198 | ''' 199 | -------------------------------------------------------------------------------- /baseline_agents/coder_utils.py: -------------------------------------------------------------------------------- 1 | # Set up the base template 2 | from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser 3 | from langchain.prompts import StringPromptTemplate 4 | from langchain.tools import BaseTool 5 | from langchain.chains.llm import LLMChain 6 | from langchain_experimental.tools.python.tool import PythonAstREPLTool 7 | from langchain.schema import AgentAction, AgentFinish 8 | from pydantic import Field 9 | from typing import List, Union 10 | import contextlib 11 | import io 12 | import logging 13 | import re 14 | 15 | 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | 19 | # set the maximum number of python code blocks that can be run 20 | MAX_TURNS = 1 21 | 22 | template = """{system_prompt} 23 | 24 | You have access to the following tools: 25 | {tools} 26 | 27 | Use the following format: 28 | 29 | Question: an input hypothesis that you must decide if it is True or False 30 | Datasets: the names and descriptions of datasets relevant to the input hypothesis 31 | Action: the action to take, should be one of [{tool_names}] 32 | Action Input: the input to the action 33 | Observation: the result of the action 34 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer 35 | Final Answer: True/False. Please output True if the input hypothesis is valid (e.g., you are able to reject the null hypothesis with statistical significance) and False if the input hypothesis is invalid (e.g., if you fail to reject the null hypothesis). 36 | 37 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY: 38 | 39 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section. 
40 | 41 | NOTE: You will be able to execute the python code ONLY ONCE. So you will need to generate the complete code to solve the query in one go. Please provide the final answer after that. 42 | 43 | Begin! 44 | 45 | {input} 46 | {agent_scratchpad}""" 47 | 48 | 49 | template_v2 = """{system_prompt} 50 | 51 | You have access to the following tools: 52 | {tools} 53 | 54 | Use the following format: 55 | 56 | Question: an input hypothesis that you must decide if it is True or False 57 | Datasets: the names and descriptions of datasets relevant to the input hypothesis 58 | Action: the action to take, should be one of [{tool_names}] 59 | Action Input: the input to the action 60 | Observation: the result of the action 61 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer 62 | Final Answer: True/False. Please output True if you believe the input hypothesis is correct and False if the input hypothesis is not based on your analysis. 63 | 64 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY: 65 | 66 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section. 67 | 68 | NOTE: You will be able to execute the python code ONLY ONCE. So you will need to generate the complete code to solve the query in one go. Please provide the final answer after that. 69 | 70 | Begin! 71 | 72 | {input} 73 | {agent_scratchpad}""" 74 | 75 | 76 | def load_data_to_coder_globals(data_loader): 77 | for name, df in data_loader.table_dict.items(): 78 | globals()[name] = df 79 | 80 | 81 | # Set up a prompt template 82 | class CustomPromptTemplate(StringPromptTemplate): 83 | # The template to use 84 | template: str 85 | # The list of tools available 86 | tools: List[BaseTool] 87 | 88 | def format(self, **kwargs) -> str: 89 | # Get the intermediate steps (AgentAction, Observation tuples) 90 | # Format them in a particular way 91 | intermediate_steps = kwargs.pop("intermediate_steps") 92 | thoughts = "" 93 | for action, observation in intermediate_steps: 94 | thoughts += action.log 95 | thoughts += f"\nObservation: {observation}\nThought: " 96 | # Set the agent_scratchpad variable to that value 97 | kwargs["agent_scratchpad"] = thoughts 98 | # Create a tools variable from the list of tools provided 99 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) 100 | # Create a list of tool names for the tools provided 101 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) 102 | return self.template.format(**kwargs) 103 | 104 | 105 | # CustomOutputParser to parse the output of the LLM and execute actions 106 | class CustomOutputParser(AgentOutputParser): 107 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]: 108 | # Check if agent should finish 109 | if "Final Answer:" in llm_output: 110 | output = llm_output.split("Final Answer:")[-1].split()[0].strip().lower() 111 | if output not in ["true", "false", "yes", "no", "y", "n"]: 112 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 113 | return AgentFinish( 114 | return_values={"output": output in ["true", "yes", 'y']}, 115 | log=llm_output, 116 | ) 117 | # Parse out the action and action input 118 | regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" 119 | match = 
re.search(regex, llm_output, re.DOTALL) 120 | if not match: 121 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 122 | action = match.group(1).strip() 123 | action_input = match.group(2) 124 | # Return the action and action input 125 | return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output) 126 | 127 | 128 | class CustomPythonAstREPLTool(PythonAstREPLTool): 129 | max_turns: int = Field(0, exclude=True) 130 | 131 | def _run(self, query: str, run_manager=None): 132 | if self.max_turns >= MAX_TURNS: 133 | return 'You cannot run the code more than once - you have already run it earlier. Please provide the "Final Answer:" immediately after "Thought:", based on whatever information you got till now. Do not attempt to output an "Action:" or run the code again.' 134 | self.max_turns += 1 135 | 136 | code_match = re.search(r"```(.*?)```", query, re.DOTALL) 137 | if code_match: 138 | # Extract code within backticks 139 | code = code_match.group(1) 140 | else: 141 | code = query 142 | code = code.strip() 143 | if code.startswith('"""'): 144 | code = code[3:].lstrip() 145 | if code.endswith('"""'): 146 | code = code[:-3].rstrip() 147 | if code.startswith("python"): 148 | code = code[len("python"):].lstrip() 149 | 150 | code = "import pandas as pd\n" + code 151 | 152 | exec_globals = globals().copy() 153 | exec_globals.update(__builtins__) 154 | 155 | output_capture = io.StringIO() 156 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture): 157 | logging.getLogger().handlers[0].stream = output_capture 158 | try: 159 | exec(code, exec_globals) 160 | except Exception as e: 161 | return str(e) 162 | 163 | # Retrieve the output and return it 164 | output = output_capture.getvalue() 165 | return output if output else "Execution completed without output." 166 | 167 | def create_agent( 168 | llm, 169 | handlers, 170 | max_iterations = None, 171 | early_stopping_method: str = "force", 172 | simple_template = False 173 | ): 174 | output_parser = CustomOutputParser() 175 | python_tool = CustomPythonAstREPLTool(callbacks=handlers) 176 | tools = [python_tool] 177 | tool_names = [tool.name for tool in tools] 178 | 179 | if simple_template: 180 | use_template = template_v2 181 | else: 182 | use_template = template 183 | 184 | print('Creating agent with template:', use_template) 185 | 186 | prompt = CustomPromptTemplate( 187 | template=use_template, 188 | tools=tools, 189 | input_variables=["system_prompt", "input", "intermediate_steps"] 190 | ) 191 | llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=handlers) 192 | 193 | agent = LLMSingleActionAgent( 194 | llm_chain=llm_chain, 195 | output_parser=output_parser, 196 | stop=["\nObservation:"], 197 | allowed_tools=tool_names 198 | ) 199 | 200 | return AgentExecutor.from_agent_and_tools( 201 | agent=agent, 202 | tools=tools, 203 | verbose=True, 204 | max_iterations=max_iterations, 205 | callbacks=handlers, 206 | early_stopping_method=early_stopping_method 207 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # POPPER: Automated Hypothesis Validation with Agentic Sequential Falsifications 2 | 3 | This repository hosts the code base for the paper 4 | 5 | **Automated Hypothesis Validation with Agentic Sequential Falsifications** 6 | 7 | Kexin Huang*, Ying Jin*, Ryan Li*, Michael Y. 
Li, Emmanuel Candès, Jure Leskovec\ 8 | [Link to Paper](https://arxiv.org/abs/2502.09858) 9 | 10 | 11 | If you find this work useful, please consider citing: 12 | 13 | ``` 14 | @misc{popper, 15 | title={Automated Hypothesis Validation with Agentic Sequential Falsifications}, 16 | author={Kexin Huang and Ying Jin and Ryan Li and Michael Y. Li and Emmanuel Candès and Jure Leskovec}, 17 | year={2025}, 18 | eprint={2502.09858}, 19 | archivePrefix={arXiv} 20 | } 21 | ``` 22 | 23 | 24 | ### Overview 25 | Hypotheses are central to information acquisition, decision-making, and discovery. However, many real-world hypotheses are abstract, high-level statements that are difficult to validate directly. 26 | This challenge is further intensified by the rise of hypothesis generation from Large Language Models (LLMs), which are prone to hallucination and produce hypotheses in volumes that make manual validation impractical. Here we propose Popper, an agentic framework for rigorous automated validation of free-form hypotheses. 27 | Guided by Karl Popper's principle of falsification, Popper validates a hypothesis using LLM agents that design and execute falsification experiments targeting its measurable implications. A novel sequential testing framework ensures strict Type-I error control while actively gathering evidence from diverse observations, whether drawn from existing data or newly conducted procedures. 28 | We demonstrate Popper on six domains including biology, economics, and sociology. Popper delivers robust error control, high power, and scalability. Furthermore, compared to human scientists, Popper achieved comparable performance in validating complex biological hypotheses while reducing time 10-fold, providing a scalable, rigorous solution for hypothesis validation. 29 | 30 | 31 |
<p align="center"><img src="figs/popper_agent_illustration.png" alt="logo"></p>
32 | 33 | 34 | ## Installation 35 | 36 | We highly recommend using a virtual environment to manage the dependencies. 37 | 38 | ```bash 39 | conda create -n popper_env python=3.10 40 | conda activate popper_env 41 | ``` 42 | 43 | For direct usage of Popper, you can install the package via pip: 44 | ```bash 45 | pip install popper_agent 46 | ``` 47 | 48 | For source code development, you can clone the repository and install the package: 49 | ```bash 50 | git clone https://github.com/snap-stanford/POPPER.git 51 | cd POPPER 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | Add the OpenAI/Anthropic API key to the environment variables: 56 | ```bash 57 | export OPENAI_API_KEY="YOUR_API_KEY" 58 | export ANTHROPIC_API_KEY="YOUR_API_KEY" 59 | ``` 60 | 61 | Datasets will be automatically downloaded to the specified data folder when you run the code. 62 | 63 | ## Demo 64 | 65 | A demo is provided [here](demo.ipynb) to show how to use the Popper agent to validate a hypothesis and to illustrate its basic functionalities. 66 | 67 | ## Core API Usage 68 | 69 | ```python 70 | from popper import Popper 71 | 72 | # Initialize the Popper agent 73 | agent = Popper(llm="claude-3-5-sonnet-20240620") 74 | 75 | # Register data for hypothesis testing; 76 | # for bio/discoverybench data in the paper, 77 | # it will be automatically downloaded to your specified data_path 78 | agent.register_data(data_path='path/to/data', loader_type='bio') 79 | 80 | # Configure the agent with custom parameters 81 | agent.configure( 82 | alpha=0.1, 83 | max_num_of_tests=5, 84 | max_retry=3, 85 | time_limit=2, 86 | aggregate_test='E-value', 87 | relevance_checker=True, 88 | use_react_agent=True 89 | ) 90 | 91 | # Validate a hypothesis 92 | results = agent.validate(hypothesis="Your hypothesis here") 93 | 94 | # Print the results 95 | print(results) 96 | ``` 97 | 98 | ## Running locally-served LLM with OpenAI-Compatible API 99 | **Popper** supports inference with local LLM servers such as vLLM, SGLang, and llama.cpp, as long as they support an OpenAI-compatible API. 
Here are some usage examples with locally hosted LLMs: 100 | 101 | Using [SGLang](https://github.com/sgl-project/sglang/tree/main): 102 | ```bash 103 | # mistral large 2 with SGLang, using 4 GPUs with 8-bit quantization 104 | python -m sglang.launch_server --model-path mistralai/Mistral-Large-Instruct-2411 --port 40000 --host 0.0.0.0 --tp 4 --quantization fp8 --mem-fraction-static 0.8 --trust-remote-code 105 | ``` 106 | ```python 107 | from popper import Popper 108 | agent = Popper(llm="mistralai/Mistral-Large-Instruct-2411", is_locally_served=True, server_port=40000) 109 | agent.register_data(data_path='path/to/data', loader_type='bio') 110 | agent.configure(alpha=0.1) 111 | agent.validate(hypothesis = 'YOUR HYPOTHESIS') 112 | ``` 113 | 114 | Using [vLLM](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html): 115 | ```bash 116 | vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 117 | ``` 118 | ```python 119 | from popper import Popper 120 | agent = Popper(llm="NousResearch/Meta-Llama-3-8B-Instruct", is_locally_served=True, server_port=8000, api_key="token-abc123") 121 | ``` 122 | 123 | Using [llama.cpp](https://github.com/ggml-org/llama.cpp): 124 | ```bash 125 | llama-server -m model.gguf --port 8080 126 | ``` 127 | ```python 128 | from popper import Popper 129 | agent = Popper(llm="qwen2 1.5B", is_locally_served=True, server_port=8080) 130 | ``` 131 | 132 | ## Run on your own hypothesis and database 133 | 134 | You can simply dump in a set of datasets in your domain (e.g. business, economics, political science, etc.) and run Popper on your own hypothesis. 135 | We only require that each file is in CSV or PKL format. 136 | 137 | ```python 138 | from popper import Popper 139 | 140 | agent = Popper(llm="claude-3-5-sonnet-20240620") 141 | agent.register_data(data_path='path/to/data', loader_type='custom') 142 | agent.configure(alpha = 0.1) 143 | agent.validate(hypothesis = 'YOUR HYPOTHESIS') 144 | ``` 145 | 146 | ## Hypothesis in Popper 147 | 148 | You can define any free-form hypothesis. In the paper, we provide two types of hypotheses: biological hypotheses and discovery-bench hypotheses. 149 | 150 | You can load the biological hypotheses with: 151 | 152 | ```python 153 | from popper.benchmark import gene_perturb_hypothesis 154 | bm = gene_perturb_hypothesis(num_of_samples = samples, permuted=False, dataset = 'IL2', path = path) 155 | example = bm.get_example(0) 156 | ``` 157 | It will return something like: 158 | 159 | ``` 160 | {'prompt': 'Gene VAV1 regulates the production of Interleukin-2 (IL-2).', 161 | 'gene': 'VAV1', 162 | 'answer': 2.916, 163 | 'binary_answer': True} 164 | ``` 165 | 166 | `num_of_samples` is the number of samples you want to generate, `permuted` is whether you want to permute the dataset for Type-I error estimation, and `dataset` is the dataset you want to use; you can choose from `IL2` and `IFNG`. 
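For example, to estimate the Type-I error on `TargetVal`, you can construct a permuted (null) version of the benchmark, where the expected answer for every hypothesis is `False`. The snippet below is a minimal sketch that reuses the placeholder `samples` and `path` variables from the snippet above:

```python
from popper.benchmark import gene_perturb_hypothesis

# Permuted (null) benchmark: hypotheses are sampled from a shuffled ground truth,
# so `binary_answer` is False for every example.
bm_null = gene_perturb_hypothesis(num_of_samples=samples, permuted=True, dataset='IFNG', path=path)

for example in bm_null.get_iterator():
    print(example['prompt'], example['binary_answer'])  # always False for the permuted benchmark
```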
167 | 168 | For discovery-bench, you can load the hypotheses with: 169 | 170 | ```python 171 | from popper.benchmark import discovery_bench_hypothesis 172 | bm = discovery_bench_hypothesis(num_samples = samples, path = path) 173 | example = bm.get_example(0) 174 | ``` 175 | 176 | It will return something like: 177 | 178 | ``` 179 | {'task': 'archaeology', 180 | 'domain': 'humanities', 181 | 'metadataid': 5, 182 | 'query_id': 0, 183 | 'prompt': 'From 1700 BCE onwards, the use of hatchets and swords increased while the use of daggers decreased.', 184 | 'data_loader': <popper.utils.DiscoveryBenchDataLoader object at 0x...>, 185 | 'answer': True} 186 | ``` 187 | 188 | As each hypothesis in discoverybench has its own associated dataset, each example also carries its own dataset in the `data_loader` field. 189 | 190 | 191 | ## Run benchmarks in the paper 192 | 193 | Bash scripts for reproducing the paper are provided in `benchmark_scripts/run_targetval.sh` for the `TargetVal` benchmark and `benchmark_scripts/run_discoverybench.sh` for the `DiscoveryBench` benchmark. 194 | 195 | **Note:** the Popper agent can read or write files to your filesystem. We recommend running the benchmark scripts inside a containerized environment. We have provided a working `Dockerfile` and an example script to launch a Docker container and execute the scripts in `benchmark_scripts/run_discoverybench_docker.sh`. 196 | 197 | **To run paper benchmarks with locally-served models,** you can simply pass in the extra parameters to the benchmark script, e.g., 198 | ```bash 199 | python benchmark_scripts/run_discovery_bench.py --exp_name discovery_bench --model llama-3.3-70b --num_tests 5 --samples 100 --permute --e_value --react --relevance_checker --is_locally_served --server_port 30000 --path PATH_TO_YOUR_DATASET 200 | ``` 201 | 202 | ## UI interface 203 | You can deploy a simple UI with one line of code using your own datasets or our bio dataset - a Gradio UI will be generated, and you can interact with it to validate your hypothesis. 204 | 205 | ```python 206 | agent.launch_UI() 207 | ``` 208 | 209 | An interface like this will pop up: 210 | 211 | [![demo](https://img.youtube.com/vi/jYFEeP2mEY8/0.jpg)](https://www.youtube.com/watch?v=jYFEeP2mEY8) 212 | 213 | ## Acknowledgement 214 | The DiscoveryBench benchmark and some of the baseline agents are built on top of [allenai/discoverybench](https://github.com/allenai/discoverybench). Thanks for their awesome work! 215 | 216 | ## Contact 217 | 218 | For any questions, please raise an issue on GitHub or contact Kexin Huang (kexinh@cs.stanford.edu). 219 | -------------------------------------------------------------------------------- /popper/benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import traceback 5 | import pandas as pd 6 | import numpy as np 7 | 8 | from popper.utils import DiscoveryBenchDataLoader 9 | 10 | class gene_perturb_hypothesis: 11 | def __init__(self, dataset='IL2', 12 | num_of_samples = 50, 13 | permuted = False, 14 | user_study_neg_genes = False, 15 | path = None): 16 | 17 | path = os.path.join(path, 'benchmark/targetval/') 18 | if dataset == 'IL2': 19 | self.prompt = "Gene {gene} regulates the production of Interleukin-2 (IL-2)." 20 | ground_truth_path = path + 'ground_truth_IL2.csv' 21 | elif dataset == 'IFNG': 22 | self.prompt = "Gene {gene} regulates the production of Interferon-gamma (IFN-g)." 
23 | ground_truth_path = path + 'ground_truth_IFNG.csv' 24 | 25 | self.ground_truth = pd.read_csv(ground_truth_path, index_col=0) 26 | 27 | self.query = [] 28 | self.ground_truth['abs_score'] = self.ground_truth.Score.abs() 29 | self.ground_truth = self.ground_truth.sort_values('abs_score') 30 | self.hypothesis2score = self.ground_truth.abs_score.to_dict() 31 | 32 | if not permuted: 33 | self.query += self.ground_truth.iloc[-num_of_samples:].index.values.tolist() 34 | self.answer = np.array([True] * num_of_samples) 35 | else: 36 | self.query += self.ground_truth.sample(frac = 1, random_state = 42).iloc[:num_of_samples].index.values.tolist() 37 | self.answer = np.array([False] * num_of_samples) 38 | 39 | if user_study_neg_genes: 40 | self.query = ['CD28', 'CD2', 'CD3D', 'MAK16', 'RAC2', 'CD3E', 'VAV1', 'CD247', 'ZAP70', 'CD3G', 'LCP2'] 41 | self.answer = np.concatenate([self.answer, np.array([False] * 10)]) 42 | 43 | 44 | def get_example(self, index = None): 45 | if index is None: 46 | index = np.random.randint(len(self.query)) 47 | 48 | q = self.query[index] 49 | a = self.hypothesis2score[q] 50 | return {"prompt": self.prompt.format(gene = q), 51 | "gene": q, 52 | "answer": a, 53 | "binary_answer": self.answer[index] 54 | } 55 | 56 | def get_iterator(self): 57 | for i in range(len(self.query)): 58 | yield self.get_example(i) 59 | 60 | def output_class(self): 61 | from langchain_core.pydantic_v1 import BaseModel, Field 62 | from typing import Optional 63 | class Output(BaseModel): 64 | """Whether or not this hypothesis is considered true/false.""" 65 | 66 | hypothesis_test_result: Optional[bool] = Field( 67 | description="Whether or not this hypothesis is considered true/false." 68 | ) 69 | return Output 70 | 71 | def evaluate(self, response, answers=None): 72 | from sklearn.metrics import accuracy_score, average_precision_score, f1_score 73 | predicted = np.array([i[1] for i in response]) 74 | if not answers: 75 | answers = np.array([exp["binary_answer"] for exp in self.examples]) 76 | 77 | res_stats = np.array([i[0] for i in response]) 78 | return { 79 | 'accuracy': accuracy_score(answers, predicted), 80 | 'power': np.sum((predicted == True) & (answers == True)) / np.sum((answers == True)), 81 | 'false discovery rate': np.sum((predicted == True) & (answers == False)) / np.sum((answers == False)), 82 | 'f1': f1_score(answers, predicted) 83 | } 84 | 85 | class discovery_bench_hypothesis: 86 | def __init__(self, split="test", synthetic=False, num_samples=50, seed=1234, path = None): 87 | 88 | print("----------Loading Discovery Bench------------") 89 | 90 | if split != "test": 91 | raise NotImplementedError 92 | 93 | random.seed(seed) 94 | if path is None: 95 | raise ValueError 96 | 97 | root_path = path 98 | self.split = split 99 | self.synthetic = synthetic 100 | self.data_path = os.path.join(root_path, "discoverybench", "synthetic" if self.synthetic else "real", self.split) 101 | 102 | ground_truth_path = os.path.join(root_path, "answer_key/answer_key_synth.csv") if self.synthetic else os.path.join(root_path, "answer_key/answer_key_real_cleaned_1.csv") 103 | self.ground_truth = pd.read_csv(ground_truth_path) 104 | 105 | self.examples = [] 106 | for task_dir in os.listdir(self.data_path): 107 | task_path = os.path.join(self.data_path, task_dir) 108 | 109 | for file in os.listdir(task_path): 110 | if file.endswith(".json"): 111 | file_path = os.path.join(task_path, file) 112 | with open(file_path, 'r', encoding='utf-8', errors='replace') as f: 113 | metadata = json.load(f) 114 | metadata_id 
= int(file.split(".")[0].split("_")[1]) 115 | data_loader = DiscoveryBenchDataLoader(task_path, metadata) 116 | 117 | # permute the data loader for negative example 118 | permuted_dataloader = DiscoveryBenchDataLoader(task_path, metadata) 119 | permuted_dataloader.permute_selected_columns() 120 | 121 | 122 | for query_list in metadata["queries"]: 123 | for query in query_list: 124 | try: 125 | # print(file_path) 126 | hypothesis = self.ground_truth.loc[ 127 | (self.ground_truth['dataset'] == task_dir) & 128 | (self.ground_truth['metadataid'] == metadata_id) & 129 | (self.ground_truth['query_id'] == query["qid"]), 130 | 'gold_hypo' 131 | ].iloc[0] 132 | 133 | if 'non-trivially falsifiable' in self.ground_truth.columns and self.ground_truth.loc[ 134 | (self.ground_truth['dataset'] == task_dir) & 135 | (self.ground_truth['metadataid'] == metadata_id) & 136 | (self.ground_truth['query_id'] == query["qid"]), 137 | 'non-trivially falsifiable' 138 | ].iloc[0] == 0: 139 | continue 140 | 141 | self.examples.append({ 142 | "task": task_dir, 143 | "domain": metadata["domain"], 144 | "metadataid": metadata_id, 145 | "query_id": query["qid"], 146 | "prompt": hypothesis, 147 | # "metadata": metadata, 148 | "data_loader": data_loader, 149 | "answer": True, 150 | }) 151 | 152 | self.examples.append({ 153 | "task": task_dir, 154 | "domain": metadata["domain"], 155 | "metadataid": metadata_id, 156 | "query_id": query["qid"], 157 | "prompt": hypothesis, 158 | # "metadata": metadata, 159 | "data_loader": permuted_dataloader, 160 | "answer": False, 161 | }) 162 | except Exception as e: 163 | # print(e) 164 | # print(traceback.format_exc()) 165 | pass 166 | 167 | 168 | if num_samples < len(self.examples): 169 | random.shuffle(self.examples) 170 | self.examples = self.examples[:num_samples] 171 | 172 | self.num_samples = len(self.examples) 173 | print(f"Loaded {self.num_samples} hypotheses") 174 | print("--------------------------------------") 175 | 176 | def get_example(self, index = None): 177 | return self.examples[index] 178 | 179 | def get_iterator(self): 180 | for i in range(self.num_samples): 181 | yield self.get_example(i) 182 | 183 | def output_class(self): 184 | from langchain_core.pydantic_v1 import BaseModel, Field 185 | from typing import Optional 186 | class Output(BaseModel): 187 | """Whether or not this hypothesis is considered true/false.""" 188 | 189 | hypothesis_test_result: Optional[bool] = Field( 190 | description="Whether or not this hypothesis is considered true/false." 
191 | ) 192 | return Output 193 | 194 | def evaluate(self, response, answers=None): 195 | ## expected [(res_stat, conclusion)] following the order of the query 196 | from sklearn.metrics import accuracy_score, average_precision_score, f1_score 197 | predicted = np.array([i[1] for i in response]) 198 | if not answers: 199 | answers = np.array([exp["answer"] for exp in self.examples]) 200 | else: 201 | answers = np.array(answers) 202 | 203 | res_stats = np.array([i[0] for i in response]) 204 | return { 205 | 'accuracy': accuracy_score(answers, predicted), 206 | 'power': np.sum((predicted == True) & (answers == True)) / np.sum((answers == True)), 207 | 'false discovery rate': np.sum((predicted == True) & (answers == False)) / np.sum((answers == False)), 208 | 'f1': f1_score(answers, predicted) 209 | 210 | # 'auprc': average_precision_score(answers, res_stats), 211 | # 'stat_pos': res_stats[np.where(answers & res_stats)[0]].mean(), 212 | # 'stat_neg': res_stats[np.where(answers == False & res_stats)[0]].mean() 213 | } -------------------------------------------------------------------------------- /popper/popper.py: -------------------------------------------------------------------------------- 1 | from popper.utils import ExperimentalDataLoader, CustomDataLoader, DiscoveryBenchDataLoader 2 | from popper.agent import SequentialFalsificationTest 3 | from typing import Optional, Dict, Any 4 | import os 5 | import requests 6 | import zipfile 7 | import urllib 8 | from tqdm import tqdm 9 | import tarfile 10 | import subprocess 11 | import shutil 12 | 13 | class Popper: 14 | """Wrapper class for hypothesis validation using sequential falsification testing.""" 15 | 16 | def __init__(self, llm: str = "claude-3-5-sonnet-20240620", is_locally_served = False, server_port = None, api_key = "EMPTY", **kwargs): 17 | """Initialize Popper. 18 | 19 | Args: 20 | llm (str): Name of the LLM model to use 21 | **kwargs: Additional arguments to pass to SequentialFalsificationTest 22 | """ 23 | self.llm = llm 24 | self.agent = None 25 | self.data_loader = None 26 | self.is_local = is_locally_served 27 | self.port = server_port 28 | self.api_key = api_key 29 | self.kwargs = kwargs 30 | 31 | def register_data(self, data_path: str, data_sampling: int = -1, loader_type: str = 'bio', metadata: Optional[Dict] = None): 32 | """Register data for hypothesis testing. 
33 | 34 | Args: 35 | data_path (str): Path to data directory 36 | data_sampling (int): Number of datasets to sample (-1 for all) 37 | loader_type (str): Type of data loader to use ('bio', 'custom', or 'discovery_bench') 38 | metadata (Optional[Dict]): Metadata required for DiscoveryBenchDataLoader 39 | """ 40 | if not os.path.exists(data_path): 41 | os.makedirs(data_path) 42 | 43 | self.data_path = data_path 44 | if not os.path.exists(os.path.join(data_path, 'bio_database')): 45 | print('It will take a few minutes to download the data for the first time...') 46 | self.download_all_data() 47 | else: 48 | print('Data already exists, loading...') 49 | 50 | if loader_type == 'bio': 51 | self.data_loader = ExperimentalDataLoader( 52 | data_path=data_path, 53 | table_dict_selection='all_bio', 54 | data_sampling=data_sampling 55 | ) 56 | elif loader_type == 'bio_selected': 57 | self.data_loader = ExperimentalDataLoader( 58 | data_path=data_path, 59 | table_dict_selection='default', 60 | data_sampling=data_sampling 61 | ) 62 | elif loader_type == 'custom': 63 | self.data_loader = CustomDataLoader(data_folder=data_path) 64 | elif loader_type == 'discovery_bench': 65 | if metadata is None: 66 | raise ValueError("Metadata must be provided for DiscoveryBenchDataLoader") 67 | self.data_loader = DiscoveryBenchDataLoader(data_path=data_path, metadata=metadata) 68 | else: 69 | raise ValueError(f"Unknown loader_type: {loader_type}") 70 | 71 | 72 | def configure(self, 73 | alpha: float = 0.1, 74 | aggregate_test: str = 'E-value', 75 | max_num_of_tests: int = 5, 76 | max_retry: int = 5, 77 | time_limit: int = 2, 78 | relevance_checker: bool = True, 79 | use_react_agent: bool = True): 80 | """Configure the sequential falsification test parameters. 81 | 82 | Args: 83 | alpha (float): Significance level 84 | aggregate_test (str): Test aggregation method 85 | max_num_of_tests (int): Maximum number of tests to run 86 | max_retry (int): Maximum number of retries for failed tests 87 | time_limit (int): Time limit in hours 88 | relevance_checker (bool): Whether to use relevance checker 89 | use_react_agent (bool): Whether to use ReAct agent 90 | """ 91 | if self.data_loader is None: 92 | raise ValueError("Please register data first using register_data()") 93 | 94 | self.agent = SequentialFalsificationTest(llm=self.llm, is_local=self.is_local, port=self.port, api_key=self.api_key) 95 | self.agent.configure( 96 | data=self.data_loader, 97 | alpha=alpha, 98 | aggregate_test=aggregate_test, 99 | max_num_of_tests=max_num_of_tests, 100 | max_retry=max_retry, 101 | time_limit=time_limit, 102 | relevance_checker=relevance_checker, 103 | use_react_agent=use_react_agent, 104 | **self.kwargs 105 | ) 106 | 107 | def validate(self, hypothesis: str) -> Dict[str, Any]: 108 | """Validate a scientific hypothesis using sequential falsification testing. 
109 | 110 | Args: 111 | hypothesis (str): The scientific hypothesis to test 112 | 113 | Returns: 114 | Dict containing the test results including logs, final message, and parsed results 115 | """ 116 | if self.agent is None: 117 | raise ValueError("Please configure the agent first using configure()") 118 | 119 | log, last_message, parsed_result = self.agent.go(hypothesis) 120 | 121 | return { 122 | "log": log, 123 | "last_message": last_message, 124 | "parsed_result": parsed_result 125 | } 126 | 127 | def _setup_default_agent(self): 128 | """Set up agent with default configuration if not already configured.""" 129 | self.configure( 130 | alpha=0.1, 131 | aggregate_test='E-value', 132 | max_num_of_tests=5, 133 | max_retry=5, 134 | time_limit=2, 135 | relevance_checker=True, 136 | use_react_agent=True 137 | ) 138 | 139 | def launch_UI(self): 140 | import gradio as gr 141 | from gradio import ChatMessage 142 | from time import time 143 | import asyncio 144 | import copy 145 | 146 | async def generate_response(prompt, 147 | designer_history=[], 148 | executor_history=[], 149 | relevance_checker_history=[], 150 | error_control_history=[], 151 | summarizer_history=[]): 152 | 153 | #designer_history.append(ChatMessage(role="user", content=prompt)) 154 | #yield designer_history, executor_history, relevance_checker_history, error_control_history, summarizer_history 155 | 156 | # Initialize log tracking 157 | prev_log = copy.deepcopy(self.agent.log) # Store initial log state 158 | 159 | # Run the agent asynchronously 160 | task = asyncio.create_task(asyncio.to_thread(self.agent.go, prompt)) 161 | 162 | while not task.done(): # Check while the agent is still running 163 | #print("Checking for new log messages...") 164 | await asyncio.sleep(1) # Wait for 1 second 165 | 166 | # Check if log has changed 167 | if self.agent.log != prev_log: 168 | prev_log = copy.deepcopy(self.agent.log) # Update previous log state 169 | 170 | # Convert new log messages to ChatMessage format 171 | designer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['designer']] 172 | executor_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['executor']] 173 | relevance_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['relevance_checker']] 174 | sequential_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['sequential_testing']] 175 | summarizer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['summarizer']] 176 | 177 | yield designer_msgs, executor_msgs, relevance_msgs, sequential_msgs, summarizer_msgs 178 | 179 | # Ensure final result is captured 180 | result = await task 181 | 182 | # Convert final logs to ChatMessage format before yielding 183 | designer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['designer']] 184 | executor_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['executor']] 185 | relevance_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['relevance_checker']] 186 | sequential_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['sequential_testing']] 187 | summarizer_msgs = [ChatMessage(role="assistant", content=msg) for msg in self.agent.log['summarizer']] 188 | 189 | yield designer_msgs, executor_msgs, relevance_msgs, sequential_msgs, summarizer_msgs 190 | 191 | 192 | def like(evt: gr.LikeData): 193 | print("User liked the response") 194 | print(evt.index, evt.liked, 
evt.value) 195 | 196 | with gr.Blocks() as demo: 197 | with gr.Row(): 198 | with gr.Column(scale=1): 199 | designer_chatbot = gr.Chatbot(label="Popper Experiment Designer", 200 | type="messages", height=600, 201 | show_copy_button=True, 202 | show_share_button = True, 203 | group_consecutive_messages = False, 204 | show_copy_all_button = True, 205 | ) 206 | relevance_checker_chatbot = gr.Chatbot(label="Relevance Checker", 207 | type="messages", height=300, 208 | show_copy_button=True, 209 | show_share_button = True, 210 | group_consecutive_messages = False, 211 | show_copy_all_button = True 212 | ) 213 | 214 | with gr.Column(scale=1): 215 | executor_chatbot = gr.Chatbot(label="Popper Experiment Executor", 216 | type="messages", height=600, 217 | show_copy_button=True, 218 | show_share_button = True, 219 | group_consecutive_messages = False, 220 | show_copy_all_button = True, 221 | ) 222 | error_control_chatbot = gr.Chatbot(label="Sequential Error Control", 223 | type="messages", height=300, 224 | show_copy_button=True, 225 | show_share_button = True, 226 | group_consecutive_messages = False, 227 | show_copy_all_button = True 228 | ) 229 | 230 | with gr.Row(): 231 | summarizer = gr.Chatbot(label="Popper Summarizer", 232 | type="messages", height=300, 233 | show_copy_button=True, 234 | show_share_button = True, 235 | group_consecutive_messages = False, 236 | show_copy_all_button = True, 237 | ) 238 | with gr.Row(): 239 | # Textbox on the left, and Button with an icon on the right 240 | prompt_input = gr.Textbox(show_label = False, placeholder="What is your hypothesis?", scale=8) 241 | button = gr.Button("Validate") 242 | 243 | button.click(lambda: gr.update(value=""), inputs=None, outputs=prompt_input) 244 | # Bind button click to generate_response function, feeding results to both chatbots 245 | button.click(generate_response, inputs=[prompt_input, designer_chatbot, executor_chatbot, relevance_checker_chatbot, error_control_chatbot, summarizer], outputs=[designer_chatbot, executor_chatbot, relevance_checker_chatbot, error_control_chatbot, summarizer]) 246 | 247 | demo.launch(share = True) 248 | 249 | 250 | def download_all_data(self): 251 | url = "https://dataverse.harvard.edu/api/access/datafile/10888484" 252 | file_name = 'popper_data_processed' 253 | self._download_and_extract_data(url, file_name) 254 | 255 | def _download_and_extract_data(self, url, file_name): 256 | """Download, extract, and merge directories using rsync.""" 257 | tar_file_path = os.path.join(self.data_path, f"{file_name}.tar.gz") 258 | 259 | if not os.path.exists(tar_file_path): 260 | # Download the file 261 | print(f"Downloading {file_name}.tar.gz...") 262 | self._download_with_progress(url, tar_file_path) 263 | print("Download complete.") 264 | 265 | # Extract the tar.gz file 266 | print("Extracting files...") 267 | with tarfile.open(tar_file_path, 'r:gz') as tar: 268 | for member in tqdm(tar.getmembers(), desc="Extracting: "): 269 | member.name = member.name.split('popper_data_processed/')[-1] # Strip directory structure 270 | tar.extract(member, self.data_path) 271 | print("Extraction complete.") 272 | 273 | def _download_with_progress(self, url, file_path): 274 | """Download a file with a progress bar.""" 275 | request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) 276 | response = urllib.request.urlopen(request) 277 | total_size = int(response.getheader('Content-Length').strip()) 278 | block_size = 1024 # 1 KB 279 | 280 | with open(file_path, 'wb') as file, tqdm( 281 | total=total_size, 
unit='B', unit_scale=True, desc="Downloading" 282 | ) as pbar: 283 | while True: 284 | buffer = response.read(block_size) 285 | if not buffer: 286 | break 287 | file.write(buffer) 288 | pbar.update(len(buffer)) -------------------------------------------------------------------------------- /popper/react_utils.py: -------------------------------------------------------------------------------- 1 | # Set up the base template 2 | from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser, create_react_agent 3 | from langchain.prompts import StringPromptTemplate 4 | from langchain.tools import BaseTool 5 | from langchain.chains.llm import LLMChain 6 | from langchain_experimental.tools.python.tool import PythonAstREPLTool 7 | from langchain.schema import AgentAction, AgentFinish 8 | from pydantic import Field, PrivateAttr 9 | from typing import List, Union, Dict 10 | import contextlib 11 | import io 12 | import logging 13 | import re 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | template = """{system_prompt} 18 | 19 | You have access to the following tools: 20 | {tools} 21 | 22 | Use the following format: 23 | 24 | Falsification Test: description of a hypothesis falsification test that you need to implement 25 | Datasets: the names and descriptions of datasets relevant to the input falsification test 26 | Thought: you should always think about what to do 27 | Action: the action to take, should be one of [{tool_names}] 28 | Action Input: the input to the action 29 | Observation: the result of the action 30 | ... (this Thought/Action/Action Input/Observation can repeat N times) 31 | Thought: I now know the final answer 32 | Final Answer: the final output from the falsification test (i.e., whether you are able to reject the null hypothesis with statistical significance). Make sure to also include the p-value of the statistical test written in scientific notations. 33 | 34 | IMPORTANT: Please make sure the Final Answer includes the p-value of the falsification test regardless if you are able to reject the null hypothesis. **Only return the Final Answer if you have obtained a non-zero p-value**. When printing p-values, please use scientific notations instead of the raw number. 35 | 36 | IMPORTANT: Please avoid p-hacking! Be fair and rigorous. 37 | 38 | Note: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section. 39 | 40 | -------------------------------------------- 41 | Example 42 | Falsification Test: 43 | {{ 44 | "Falsification Test name": "Body Length Evolution and Speciation Rate Relationship Test", 45 | "Falsification Test description": "Testing for a significant positive relationship between maximum body length evolution rate and spatial variation in speciation rates.", 46 | "Falsification Test Null hypothesis": "There is no statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates.", 47 | "Falsification Test Alternate hypothesis": "There is a statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates." 
48 | }} 49 | Datasets: 50 | {{ 51 | "name": "df_body-size-evolution-in-south-american-freshwater-fishes", 52 | "description": "Data on body size evolution in South American freshwater fishes, including speciation and extinction rates", 53 | "columns": {{ 54 | "raw": [ 55 | {{ 56 | "name": "HYBAS_ID", 57 | "description": "Unique identifier for each hydrological basin" 58 | }}, 59 | {{ 60 | "name": "long", 61 | "description": "Longitude of the basin location" 62 | }}, 63 | {{ 64 | "name": "lat", 65 | "description": "Latitude of the basin location" 66 | }}, 67 | {{ 68 | "name": "BAMM_speciation", 69 | "description": "Rate of speciation as calculated by the BAMM method" 70 | }}, 71 | {{ 72 | "name": "BAMM_extinction", 73 | "description": "Rate of extinction as calculated by the BAMM method" 74 | }}, 75 | {{ 76 | "name": "BAMM_NetDiv", 77 | "description": "Net diversification rate, calculated as speciation minus extinction" 78 | }}, 79 | {{ 80 | "name": "aet", 81 | "description": "Mean annual evapotranspiration for each basin" 82 | }}, 83 | {{ 84 | "name": "Elevation", 85 | "description": "Average elevation of the basin" 86 | }}, 87 | {{ 88 | "name": "sgr", 89 | "description": "Species growth rate in each basin" 90 | }}, 91 | {{ 92 | "name": "soil_div", 93 | "description": "Soil diversity index for each basin" 94 | }}, 95 | {{ 96 | "name": "area", 97 | "description": "Total area of the basin in square kilometers" 98 | }}, 99 | {{ 100 | "name": "diversity", 101 | "description": "Diversity index for the species in each basin" 102 | }} 103 | ] 104 | }} 105 | }} 106 | Thought: First, I need to load the dataset from the global namespace in Python and inspect the data to identify the relevant columns for this hypothesis test. 107 | Action: python_repl_ast 108 | Action Input: import pandas as pd\n\ndf = df_body-size-evolution-in-south-american-freshwater-fishes\ndf.head() 109 | Observation: 110 | HYBAS_ID long lat BAMM_speciation BAMM_extinction BAMM_NetDiv ... aet Elevation sgr soil_div area diversity 111 | 0 6050000010 -76.477422 7.742693 0.137392 0.026807 0.110585 ... 1387 330.150088 166 0.482402 72363.7 68 112 | 1 6050000740 -74.628725 9.803586 0.117235 0.025796 0.091438 ... 1082 69.475294 23 0.457436 17944.3 35 113 | 2 6050068100 -75.295995 8.448815 0.119381 0.023826 0.095555 ... 1312 143.032178 74 0.378793 17105.5 44 114 | 3 6050068110 -74.608408 8.922863 0.132477 0.027777 0.104700 ... 1445 14.724138 3 0.468328 610.1 48 115 | 4 6050070260 -75.591588 5.770093 0.120127 0.022940 0.097187 ... 1371 1378.729945 421 0.158870 61901.9 81 116 | [5 rows x 21 columns] 117 | Thought: Now that the dataset is loaded and I can see the columns, I need to perform a statistical test to assess the significance of the relationship between 'BAMM_speciation' and 'BAMM_NetDiv'. 118 | Action: python_repl_ast 119 | Action Input: from scipy.stats import linregress\n\n# Perform linear regression to test for a statistically significant relationship\nresult = linregress(df['BAMM_speciation'], df['BAMM_NetDiv'])\ncoefficient = result.slope\np_value = result.pvalue\ncoefficient, "{{:.2e}}".format(p_value) 120 | Observation: (0.5175306498596297, 3.50e-03) 121 | Thought: The linear regression analysis provides a coefficient of approximately 0.518, indicating a positive relationship, and the p-value is 3.50e-03, which is statistically significant at the 0.05 level. Based on this, I can reject the null hypothesis in the falsification test. 122 | Final Answer: Falsification test passes. 
The null hypothesis is rejected with a p-value of 3.50e-03. 123 | -------------------------------------------- 124 | 125 | Remember, your output should always **exactly** follow the aforementioned format: 126 | Falsification Test: description of a hypothesis falsification test that you need to implement 127 | Datasets: the names and descriptions of datasets relevant to the input falsification test 128 | Thought: you should always think about what to do 129 | Action: the action to take, should be one of [{tool_names}] 130 | Action Input: the input to the action 131 | Observation: the result of the action 132 | ... (this Thought/Action/Action Input/Observation can repeat N times) 133 | Thought: I now know the final answer 134 | Final Answer: the final output from the falsification test (i.e., whether you are able to reject the null hypothesis with statistical significance). Make sure to also include the p-value of the statistical test written in scientific notations. 135 | 136 | **IMPORTANT** 137 | You should ALWAYS report a p-value EXACTLY AS IT IS. If a p-value is 4.2e-01, report 4.2e-01, DO NOT REPORT 4.2e-02! 138 | BE CAREFUL WHEN READING THE P-VALUE RESULTS, MISREPORTING A P-VALUE IS WORSE THAN HAVING NO P-VALUE AT ALL. 139 | When reading p-values, make sure the sample sizes and the statistical test is valid. 140 | Please make sure to always return ONE valid p-value. If there are multiple p-values produced by the test, aggregate them in a meaningful and rigorous way. 141 | ** Always make sure the returned p-value matches your conclusion for the falsification test. For example, if you reject H0 but finds out that H1 is also incorrect (e.g., the suggested shape or relationship is wrong), you SHOULD NOT return a p-value < 0.05. 142 | If you think it's impossible to find a valid p-value for the falsification test, return a p-value of 1.00e+00. 143 | DO NOT perform p-hacking. 144 | 145 | Begin! 
146 | 147 | {input} {agent_scratchpad}""" 148 | 149 | 150 | def load_data_to_react_globals(data_loader): 151 | for name, df in data_loader.table_dict.items(): 152 | globals()[name] = df 153 | 154 | 155 | # Set up a prompt template 156 | class CustomPromptTemplate(StringPromptTemplate): 157 | # The template to use 158 | template: str 159 | # The list of tools available 160 | tools: List[BaseTool] 161 | 162 | def format(self, **kwargs) -> str: 163 | # Get the intermediate steps (AgentAction, Observation tuples) 164 | # Format them in a particular way 165 | intermediate_steps = kwargs.pop("intermediate_steps") 166 | thoughts = "" 167 | for action, observation in intermediate_steps: 168 | thoughts += action.log 169 | thoughts += f"\nObservation: {observation}\nThought: " 170 | # Set the agent_scratchpad variable to that value 171 | kwargs["agent_scratchpad"] = thoughts 172 | # Create a tools variable from the list of tools provided 173 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) 174 | # Create a list of tool names for the tools provided 175 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) 176 | prompt = self.template.format(**kwargs) 177 | # print([prompt]) 178 | return prompt 179 | 180 | # CustomOutputParser to parse the output of the LLM and execute actions 181 | class CustomOutputParser(AgentOutputParser): 182 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]: 183 | # Check if agent should finish 184 | if "Final Answer:" in llm_output: 185 | return AgentFinish( 186 | return_values={"output": llm_output}, 187 | log=llm_output, 188 | ) 189 | # Parse out the action and action input 190 | regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" 191 | match = re.search(regex, llm_output, re.DOTALL) 192 | if not match: 193 | # raise ValueError(f"Could not parse LLM output: `{llm_output}`") 194 | print(f"Warning: could not parse LLM output: `{llm_output}`, finishing chain...") 195 | return AgentFinish( 196 | return_values={"output": llm_output}, 197 | log=llm_output, 198 | ) 199 | action = match.group(1).strip() 200 | action_input = match.group(2) 201 | # Return the action and action input 202 | return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output) 203 | 204 | 205 | class CustomPythonAstREPLTool(PythonAstREPLTool): 206 | _exec_globals:Dict = PrivateAttr() 207 | def __init__(self, *args, **kwargs): 208 | super().__init__(*args, **kwargs) 209 | # Initialize a persistent global namespace for code execution 210 | self._exec_globals = {} 211 | self._exec_globals.update(__builtins__) 212 | 213 | def _set_globals(self, table_dict=None): 214 | self._exec_globals = {} 215 | self._exec_globals.update(__builtins__) 216 | 217 | if table_dict: 218 | self._exec_globals.update(table_dict) 219 | 220 | def _run(self, query: str, run_manager=None): 221 | code_match = re.search(r"```(.*?)```", query, re.DOTALL) 222 | if code_match: 223 | # Extract code within backticks 224 | code = code_match.group(1) 225 | else: 226 | code = query 227 | code = code.strip() 228 | if code.startswith("python"): 229 | code = code[len("python"):].lstrip() 230 | 231 | if code.endswith("Observation"): 232 | code = code[:-len("Observation")].rstrip() 233 | 234 | code_lines = code.strip().split('\n') 235 | code = '\n'.join(code_lines[:-1]) # avoid printing the last line twice 236 | last_line = code_lines[-1] 237 | 238 | output_capture = io.StringIO() 239 | with 
contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture): 240 | logging.getLogger().handlers[0].stream = output_capture 241 | try: 242 | exec(code, self._exec_globals) 243 | try: 244 | result = eval(last_line, self._exec_globals) 245 | if result is not None: 246 | print(result, file=output_capture) 247 | except: 248 | pass 249 | except Exception as e: 250 | return str(e) 251 | 252 | # Retrieve the output and return it 253 | output = output_capture.getvalue() 254 | return output if output else "Execution completed without output." 255 | 256 | 257 | def create_agent( 258 | llm, 259 | handlers, 260 | max_iterations = 50, 261 | early_stopping_method: str = "force", 262 | ): 263 | output_parser = CustomOutputParser() 264 | python_tool = CustomPythonAstREPLTool(callbacks=handlers) 265 | tools = [python_tool] 266 | tool_names = [tool.name for tool in tools] 267 | 268 | prompt = CustomPromptTemplate( 269 | template=template, 270 | tools=tools, 271 | input_variables=["system_prompt", "input", "intermediate_steps", "tool_names", "tools", "agent_scratchpad"] 272 | ) 273 | # llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=handlers) 274 | 275 | # agent = LLMSingleActionAgent( 276 | # llm_chain=llm_chain, 277 | # output_parser=output_parser, 278 | # stop=["\nObservation:"], 279 | # allowed_tools=tool_names 280 | # ) 281 | agent = create_react_agent( 282 | llm=llm, 283 | tools=tools, 284 | prompt=prompt, 285 | output_parser=output_parser, 286 | stop_sequence=["\nObservation:"], 287 | ) 288 | 289 | return AgentExecutor.from_agent_and_tools( 290 | agent=agent, 291 | tools=tools, 292 | verbose=True, 293 | max_iterations=max_iterations, 294 | callbacks=handlers, 295 | early_stopping_method=early_stopping_method 296 | ) -------------------------------------------------------------------------------- /popper/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import json 5 | import time 6 | import openai 7 | from glob import glob 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from popper.llm.custom_model import CustomChatModel 12 | from langchain_anthropic import ChatAnthropic 13 | from langchain_openai import ChatOpenAI 14 | from langchain_core.messages.base import get_msg_title_repr 15 | from langchain_core.utils.interactive_env import is_interactive_env 16 | 17 | def get_llm(model = 'claude-3-5-sonnet-20240620', temperature=0.7, port=30000, api_key = "EMPTY", **kwargs): 18 | source = "Local" 19 | if model[:7] == 'claude-': 20 | source = 'Anthropic' 21 | elif model[:4] == 'gpt-' or model.startswith("o1"): 22 | source = 'OpenAI' 23 | # elif model.startswith('llama'): 24 | # source = "Llama" 25 | # if source not in ['OpenAI', 'Anthropic']: 26 | # raise ValueError('Invalid source') 27 | if source == 'OpenAI': 28 | if model.startswith("o1"): 29 | return ChatOpenAI(model = model, temperature = -1, **kwargs) 30 | return ChatOpenAI(model = model, temperature = temperature, **kwargs) 31 | elif source == 'Anthropic': 32 | return ChatAnthropic(model = model, 33 | temperature = temperature, 34 | max_tokens = 4096, 35 | **kwargs) 36 | else: 37 | # assuming a locally-served model 38 | assert port is not None, f"Model {model} is not supported, please provide a local port if it is a locally-served model." 
39 | llm = CustomChatModel(model = model, model_type=source, temperature = temperature) 40 | llm.client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key=api_key).chat.completions 41 | return llm 42 | 43 | class ExperimentalDataLoader: 44 | def __init__(self, data_path, table_dict_selection='default', data_sampling=-1): 45 | self.data_path = os.path.join(data_path, 'bio_database') 46 | self.available_datasets = [ 47 | "gtex_tissue_gene_tpm", 48 | "gwas_catalog", 49 | "gene_info", 50 | "genetic_interaction", 51 | "genebass_synonymous_filtered", 52 | "genebass_missense_LC_filtered", 53 | "genebass_pLoF_filtered", 54 | "affinity_capture_ms", 55 | "two_hybrid", 56 | "synthetic_growth_defect", 57 | "affinity_capture_rna", 58 | "co_fractionation", 59 | "synthetic_lethality", 60 | "dosage_growth_defect", 61 | "proximity_label_ms", 62 | "synthetic_rescue", 63 | "reconstituted_complex", 64 | "eqtl_ukbb", 65 | "pqtl_ukbb", 66 | "sqtl_ukbb", 67 | "variant_table", 68 | "trait" 69 | ] 70 | 71 | self.permute_columns = { 72 | 'gtex_tissue_gene_tpm': ['Gene'], 73 | 'gwas_catalog': ['REPORTED GENE(S)', 'MAPPED_GENE', 'UPSTREAM_GENE_ID', 'DOWNSTREAM_GENE_ID', 'SNP_GENE_IDS'], 74 | } 75 | 76 | # Load datasets based on user input (default or all_bio) 77 | if table_dict_selection == 'default': 78 | self.datasets_to_load = [ 79 | "gtex_tissue_gene_tpm", 80 | "gwas_catalog", 81 | "gene_info" 82 | ] 83 | elif table_dict_selection == 'all_bio': 84 | self.datasets_to_load = self.available_datasets # Load all datasets 85 | 86 | if data_sampling != -1: 87 | all_datasets = self.available_datasets 88 | np.random.seed(42) 89 | np.random.shuffle(all_datasets) 90 | self.datasets_to_load = all_datasets[:data_sampling] 91 | print(f"Sampled datasets: {self.datasets_to_load}") 92 | 93 | # Load the selected datasets into the table_dict 94 | self.table_dict_selection = table_dict_selection 95 | self.table_dict = self._load_selected_datasets() 96 | self.data_desc = self._generate_data_description() 97 | 98 | def _load_selected_datasets(self): 99 | """Loads only the selected datasets and returns a dictionary.""" 100 | table_dict = {} 101 | for dataset in self.datasets_to_load: 102 | df_name = f"df_{dataset}" 103 | table_dict[df_name] = self._load_data(f"{dataset}.pkl") 104 | return table_dict 105 | 106 | def _load_data(self, file_name): 107 | """Helper method to load data from a pickle file.""" 108 | try: 109 | return pd.read_pickle(os.path.join(self.data_path, file_name)) 110 | except FileNotFoundError: 111 | print(f"File {file_name} not found in path {self.data_path}") 112 | return None 113 | 114 | def _generate_data_description(self): 115 | """Generates a description of each dataset's columns and the first row of data.""" 116 | desc = "" 117 | for name, df in self.table_dict.items(): 118 | if df is not None: 119 | desc += f"{name}:\n{dict(zip(df.columns.values, df.iloc[0].values))}\n\n" 120 | return desc 121 | 122 | def get_data(self, table_name): 123 | """Returns the requested DataFrame.""" 124 | return self.table_dict.get(table_name, None) 125 | 126 | def load_into_globals(self): 127 | """Loads each dataset into the global namespace.""" 128 | for name, df in self.table_dict.items(): 129 | if df is not None: 130 | globals()[name] = df 131 | 132 | def display_data_description(self): 133 | """Prints the data description.""" 134 | print(self.data_desc) 135 | 136 | def permute_selected_columns(self, random_seed = 42): 137 | 138 | if self.table_dict_selection == 'default': 139 | self.random_seed = random_seed 140 | 
"""Permutes the specified columns together for each dataset in permute_columns.""" 141 | for dataset_name, columns_to_permute in self.permute_columns.items(): 142 | df_name = f"df_{dataset_name}" 143 | df = self.table_dict.get(df_name, None) 144 | 145 | if df is not None and all(col in df.columns for col in columns_to_permute): 146 | # Set the random seed for reproducibility 147 | if self.random_seed is not None: 148 | np.random.seed(self.random_seed) 149 | 150 | # Permute rows of the selected columns together 151 | # Shuffle the DataFrame rows and then reassign the columns 152 | permuted_df = df[columns_to_permute].sample(frac=1, random_state=self.random_seed).reset_index(drop=True) 153 | df[columns_to_permute] = permuted_df 154 | 155 | # Update the table_dict with the permuted DataFrame 156 | self.table_dict[df_name] = df 157 | print(f"Permuted columns {columns_to_permute} in dataset {df_name}") 158 | else: 159 | print(f"Columns {columns_to_permute} not found in dataset {df_name} or dataset does not exist.") 160 | elif self.table_dict_selection == 'all_bio': 161 | for df_name in self.table_dict: 162 | df = self.table_dict[df_name] 163 | permuted_df = df.copy() 164 | for col in df.columns: 165 | permuted_df[col] = df[col].sample(frac=1, random_state=random_seed).reset_index(drop=True) 166 | random_seed = random_seed * 2 % (2 ** 31) 167 | self.table_dict[df_name] = permuted_df 168 | 169 | 170 | def load_file_dynamic(filepath): 171 | # Read the first few bytes to infer the delimiter 172 | with open(filepath, 'r') as file: 173 | first_line = file.readline() 174 | 175 | # Check for delimiter 176 | if '\t' in first_line: 177 | delimiter = '\t' 178 | elif ',' in first_line: 179 | delimiter = ',' 180 | else: 181 | raise ValueError("Unknown delimiter. File is neither CSV nor TSV.") 182 | 183 | # Load the DataFrame using the detected delimiter 184 | df = pd.read_csv(filepath, delimiter=delimiter) 185 | return df 186 | 187 | 188 | class DiscoveryBenchDataLoader: 189 | def __init__(self, data_path, metadata): 190 | self.data_path = data_path 191 | self.metadata = metadata # dictionary/json object 192 | self.available_datasets = metadata["datasets"] 193 | self.table_dict = self._load_datasets() 194 | self.data_desc = self._generate_data_description() 195 | 196 | def _load_datasets(self): 197 | table_dict = {} 198 | for entry in self.available_datasets: 199 | table_path = os.path.join(self.data_path, entry["name"]) 200 | table_dict[f'df_{entry["name"].split(".")[0]}'] = load_file_dynamic(table_path) 201 | return table_dict 202 | 203 | def _generate_data_description(self): 204 | desc = {} 205 | for entry in self.available_datasets: 206 | # print(entry["name"]) 207 | table_name = f'df_{entry["name"].split(".")[0]}' 208 | df = self.table_dict[table_name] 209 | 210 | columns = entry["columns"]["raw"] 211 | for column in columns: 212 | value = df[column['name']].iloc[0] 213 | if isinstance(value, np.generic): 214 | value = value.item() 215 | column["example_value"] = value 216 | desc_entry = { 217 | "description": entry["description"], 218 | "columns": json.dumps(columns), 219 | } 220 | desc[table_name] = desc_entry 221 | 222 | return json.dumps(desc, indent=4) 223 | 224 | def load_into_globals(self): 225 | """Loads each dataset into the global namespace.""" 226 | for name, df in self.table_dict.items(): 227 | if df is not None: 228 | globals()[name] = df 229 | 230 | def display_data_description(self): 231 | """Prints the data description.""" 232 | print(self.data_desc) 233 | 234 | def 
permute_selected_columns(self, columns="all", random_seed = 42): 235 | # permute all columns 236 | np.random.seed(random_seed) 237 | for df_name in self.table_dict: 238 | df = self.table_dict[df_name] 239 | self.table_dict[df_name] = df.apply(np.random.permutation) 240 | 241 | # for df_name in self.table_dict: 242 | # df = self.table_dict[df_name] 243 | # permuted_df = df.copy() 244 | # for col in df.columns: 245 | # permuted_df[col] = df[col].sample(frac=1, random_state=random_seed).reset_index(drop=True) 246 | # random_seed = random_seed * 2 % (2 ** 31) 247 | # self.table_dict[df_name] = permuted_df 248 | 249 | def pretty_print(message, printout = True): 250 | if isinstance(message, tuple): 251 | title = message 252 | else: 253 | if isinstance(message.content, list): 254 | title = get_msg_title_repr(message.type.title().upper() + " Message", bold=is_interactive_env()) 255 | if message.name is not None: 256 | title += f"\nName: {message.name}" 257 | 258 | for i in message.content: 259 | if i['type'] == 'text': 260 | title += f"\n{i['text']}\n" 261 | elif i['type'] == 'tool_use': 262 | title += f"\nTool: {i['name']}" 263 | title += f"\nInput: {i['input']}" 264 | if printout: 265 | print(f"{title}") 266 | else: 267 | title = get_msg_title_repr(message.type.title() + " Message", bold=is_interactive_env()) 268 | if message.name is not None: 269 | title += f"\nName: {message.name}" 270 | title += f"\n\n{message.content}" 271 | if printout: 272 | print(f"{title}") 273 | return title 274 | 275 | class CustomDataLoader: 276 | def __init__(self, data_folder, random_seed=42): 277 | """Initialize data loader with path to data folder. 278 | 279 | Args: 280 | data_folder (str): Path to folder containing pickle files 281 | random_seed (int): Random seed for permutations 282 | """ 283 | self.data_path = data_folder 284 | self.random_seed = random_seed 285 | self.table_dict = {} 286 | self._load_all_datasets() 287 | self.data_desc = self._generate_data_description() 288 | 289 | def _load_all_datasets(self): 290 | """Automatically loads all pickle and CSV files found in data_path.""" 291 | pickle_files = glob(os.path.join(self.data_path, "*.pkl")) 292 | csv_files = glob(os.path.join(self.data_path, "*.csv")) 293 | 294 | if not pickle_files and not csv_files: 295 | raise ValueError(f"No pickle or CSV files found in {self.data_path}") 296 | 297 | for file_path in pickle_files + csv_files: 298 | file_name = os.path.basename(file_path) 299 | dataset_name = os.path.splitext(file_name)[0] 300 | df_name = f"df_{dataset_name}" 301 | if file_path.endswith('.pkl'): 302 | self.table_dict[df_name] = self._load_data(file_name) 303 | elif file_path.endswith('.csv'): 304 | self.table_dict[df_name] = pd.read_csv(file_path) 305 | 306 | def _load_data(self, file_name): 307 | """Helper method to load data from a pickle file.""" 308 | try: 309 | df = pd.read_pickle(os.path.join(self.data_path, file_name)) 310 | if not isinstance(df, pd.DataFrame): 311 | raise ValueError(f"File {file_name} does not contain a pandas DataFrame") 312 | return df 313 | except Exception as e: 314 | print(f"Error loading {file_name}: {str(e)}") 315 | return None 316 | 317 | def _generate_data_description(self): 318 | """Generates a description of each dataset's columns and the first row.""" 319 | desc = "" 320 | for name, df in self.table_dict.items(): 321 | if df is not None: 322 | desc += f"{name}:\nColumns: {df.columns.tolist()}\n" 323 | desc += f"Sample row: {dict(zip(df.columns, df.iloc[0]))}\n\n" 324 | return desc 325 | 326 | def 
get_data(self, table_name): 327 | """Returns the requested DataFrame.""" 328 | return self.table_dict.get(table_name, None) 329 | 330 | def load_into_globals(self): 331 | """Loads each dataset into the global namespace.""" 332 | for name, df in self.table_dict.items(): 333 | if df is not None: 334 | globals()[name] = df 335 | 336 | def display_data_description(self): 337 | """Prints the data description.""" 338 | print(self.data_desc) 339 | 340 | def permute_columns(self, dataset_name, columns_to_permute): 341 | """Permutes specified columns in a dataset. 342 | 343 | Args: 344 | dataset_name (str): Name of dataset (without 'df_' prefix) 345 | columns_to_permute (list): List of column names to permute 346 | """ 347 | df_name = f"df_{dataset_name}" 348 | df = self.table_dict.get(df_name, None) 349 | 350 | if df is None: 351 | raise ValueError(f"Dataset {df_name} not found") 352 | 353 | if not all(col in df.columns for col in columns_to_permute): 354 | raise ValueError(f"Not all columns {columns_to_permute} found in {df_name}") 355 | 356 | np.random.seed(self.random_seed) 357 | permuted_df = df[columns_to_permute].sample(frac=1).reset_index(drop=True) 358 | df[columns_to_permute] = permuted_df 359 | self.table_dict[df_name] = df -------------------------------------------------------------------------------- /popper/prompt_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | CODING_AGENT_SYSTEM_PROMPT_APPROX = '''You are an expert statistician specialized in the field of {domain}. You are tasked with validating a {domain} hypothesis (H) by collecting evidence supporting both the alternative hypothesis (h1) and the null hypothesis (h0). 4 | 5 | You should write code to gather, process, and analyze the available data, collecting evidence favoring both h1 and h0. 6 | The goal is to structure the evidence in a way that allows for a thorough and interpretable comparison, enabling an LLM to estimate the likelihood under both h1 and h0. 7 | 8 | The code should: 9 | - The output should be data/evidence, instead of test statistics. 10 | - Organize the evidence for h1 and h0 in a structured format, including metrics and qualitative descriptors. 11 | - Provide outputs that are interpretable, enabling easy comparison between the likelihoods of h1 and h0. 12 | 13 | You have access to the following pandas dataframe tables, where each table shows precise column names and an example row: 14 | 15 | {{context}} 16 | 17 | Write code based on the user’s request. Ensure any code provided is self-contained, executable, and includes all necessary imports and variable definitions. 18 | 19 | Structure your output as follows: 20 | 1) A brief summary of the approach, 21 | 2) The required imports, 22 | 3) The complete code. 23 | 24 | Include progress bars for lengthy processes where appropriate and optimize for time efficiency by using a small number of permutations (e.g., <1000) where relevant. 25 | Do not use placeholders. Each output should directly relate to the evidence under h1 and h0. 26 | 27 | The output should provide a comparison-ready format for h1 and h0. You should print the output in the code. 28 | -------------------------------------------- 29 | 30 | Here is the user-requested falsification test specification:''' 31 | 32 | 33 | CODING_AGENT_SYSTEM_PROMPT = """You are an expert statistician specialized in the field of {domain}. 
You are tasked with rigorously validating whether a {domain} hypothesis H is true by implementing a falsification test proposed by the user.
34 | 
35 | You should write code to implement the falsification test.
36 | The test should be relevant to the main hypothesis and aim to falsify it.
37 | The test should use the available data described below, applying data processing, extraction, and statistical analysis to produce a p-value that measures the falsification of the main hypothesis.
38 | The test should be extremely rigorous. The p-value should be theoretically grounded.
39 | The code should be clear, concise, and efficient. Add a progress bar when necessary. The code will have a time limit, so please be efficient. For example, if possible, you can set the number of permutations to be small (e.g. <1000).
40 | The code should be self-contained and should not require additional modifications from the user.
41 | 
42 | You have access to the following pandas dataframe tables, where each table shows the precise column names and a preview of column values:
43 | 
44 | {{context}}
45 | 
46 | Each of these dataframes has already been loaded into the global namespace. You may access each dataframe **directly as a variable**. Make sure to use the **EXACT** dataframe names as shown above.
47 | 
48 | Create code from the user request. Ensure any code you provide can be executed with all required imports and variables defined.
49 | Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block.
50 | Invoke the code tool to structure the output correctly.
51 | NEVER PRODUCE ANY PLACEHOLDER IN ANY FUNCTION. A PLACEHOLDER IS WORSE THAN FAILING TO PRODUCE CODE.
52 | Placeholders include made-up genes, names, IDs, functions, p-values, or any other invented values.
53 | The output should be a single p-value. If the test produces multiple p-values, you should aggregate them in a meaningful and rigorous way.
54 | When printing p-values, please use scientific notation (e.g. 3.50e-03) instead of the raw number.
55 | For querying biological IDs, write code to look directly at raw datasets to map the exact ID, avoiding the use of LLMs to generate or infer gene names or IDs. Additionally, if the dataset includes p-values in its columns, refrain from using them as direct outputs of the falsification test; instead, process or contextualize them appropriately to maintain analytical rigor.
56 | -------------------------------------------------------
57 | 
58 | Here is the user-requested falsification test specification:"""
59 | 
60 | 
61 | def get_coding_agent_system_prompt(llm_approx, domain="biology"):
62 |     if llm_approx:
63 |         return CODING_AGENT_SYSTEM_PROMPT_APPROX.format(domain=domain)
64 |     return CODING_AGENT_SYSTEM_PROMPT.format(domain=domain)
65 | 
66 | 
67 | REACT_CODING_AGENT_SYSTEM_PROMPT = """You are an expert statistician specialized in the field of {domain}. Given a Falsification Test, your task is to determine if you can reject the null hypothesis via rigorous data analysis and statistical testing.
68 | 
69 | You have access to multiple datasets relevant to the hypothesis, as well as a Python code execution environment to run your falsification test. The code execution environment has a persistent global namespace, meaning that states and variable names will persist through multiple rounds of code executions.
Be sure to take advantage of this by developing your falsification test incrementally and reflecting on the intermediate observations at each step, instead of coding up everything in one go. All datasets have already been loaded into the global namespace as pandas dataframes."""
70 | 
71 | PROMPT_REVISION = """
72 | For querying biological IDs, write code to look directly at raw datasets to map the exact ID, avoiding the use of LLMs to generate or infer gene names or IDs. Additionally, if the dataset includes p-values in its columns, refrain from using them as direct outputs of the falsification test; instead, process or contextualize them appropriately to maintain analytical rigor.
73 | """
74 | 
75 | def get_react_coding_agent_system_prompt(domain="biology", prompt_revision=False):
76 |     if prompt_revision:
77 |         return REACT_CODING_AGENT_SYSTEM_PROMPT.format(domain=domain) + PROMPT_REVISION
78 |     else:
79 |         return REACT_CODING_AGENT_SYSTEM_PROMPT.format(domain=domain)
80 | 
81 | 
82 | LIKELIHOOD_ESTIMATION_AGENT_PROMPT = """Given a scientific hypothesis H, you have designed a sub-hypothesis test h to falsify the main hypothesis. You have also collected evidence from data for the null hypothesis (h0) and the alternative hypothesis (h1).
83 | 
84 | Your goal is to:
85 | 1. Estimate the probability of this evidence under the alternative hypothesis, P(data|h1).
86 | 2. Estimate the probability of this evidence under the null hypothesis, P(data|h0).
87 | 
88 | Follow this rigorous rubric to evaluate estimation precision, focusing on both theoretical grounding and accuracy in likelihood estimation:
89 | 
90 | - **0.1**: Extremely poor estimate, lacks theoretical grounding; estimation is inconsistent with evidence and does not consider hypothesis structure.
91 | - **0.2**: Poor estimate; limited theoretical basis, fails to account for evidence specifics, and overlooks key elements of hypothesis testing.
92 | - **0.3**: Weak estimate, marginally considers evidence but lacks appropriate statistical measures or fails to apply probability theory accurately.
93 | - **0.4**: Below average; applies some basic probability theory but lacks rigor, poorly models the relationship between evidence and hypothesis.
94 | - **0.5**: Average estimate; applies probability theory minimally, captures some evidence but with limited specificity to the hypothesis context.
95 | - **0.6**: Above average; uses sound statistical principles, somewhat models the evidence-hypothesis relationship, but with notable gaps or simplifications.
96 | - **0.7**: Good estimate; well-grounded in theory, evidence is modeled with reasonable accuracy but lacks precision or depth in interpretation.
97 | - **0.8**: Very good estimate; rigorous application of probability theory, models evidence in the context of hypothesis well, with minor limitations in capturing uncertainty or alternative explanations.
98 | - **0.9**: Excellent estimate; highly accurate, theoretically sound, robustly interprets evidence under hypothesis, addressing key uncertainties and incorporating evidence nuances.
99 | - **1.0**: Perfect estimate; fully grounded in advanced probability theory, comprehensive and precise, accurately modeling all aspects of evidence given the hypothesis, leaving no uncertainties unaddressed.
100 | 
101 | ---
102 | **Process**:
103 | - First, produce an initial estimate proposal.
104 | - In each round i, perform the following steps:
105 |   1. **Critique**: Evaluate the estimation’s reasonableness, theoretical rigor, and alignment with this rubric.
106 |   2. 
**Reflect**: Identify specific improvements to enhance accuracy and theoretical grounding based on critique. 107 | - If the estimation achieves a rigorous standard (e.g., reaching 0.9 or 1.0), return the final estimates: 108 | - P(data|h1) = [final value] 109 | - P(data|h0) = [final value] 110 | - If refinement is needed, improve or propose a new estimation, then proceed to the next round. 111 | 112 | --- 113 | **Information**: 114 | - Main Scientific Hypothesis H: 115 | {main_hypothesis} 116 | 117 | - Falsification Test Sub-Hypothesis h: 118 | {falsification_test} 119 | 120 | - Evidence: 121 | {data} 122 | """ 123 | 124 | def get_likelihood_estimation_agent_prompt(main_hypothesis, falsification_test, data): 125 | return LIKELIHOOD_ESTIMATION_AGENT_PROMPT.format(main_hypothesis=main_hypothesis, falsification_test=falsification_test, data=data) 126 | 127 | 128 | TEST_PROPOSAL_AGENT_SYSTEM_PROMPT = """You are an expert statistician specialized in the field of {domain}.""" 129 | 130 | TEST_PROPOSAL_AGENT_USER_PROMPT = ''' 131 | Given a {domain} hypothesis "{main_hypothesis}", your goal is to propose a novel falsification test given the available {domain} data sources. 132 | A falsification test is a test that can potentially falsify the main hypothesis. 133 | The outcome of the falsification test is to return a p-value that measures the evidence to falsify the main hypothesis. 134 | 135 | Notably, the falsification test should satisfy the following property: if the main hypotheiss is null, then the falsification sub-hypothesis should also be null. 136 | 137 | Here are the list of available data sources, and you can directly call the dataframe as it has already been loaded; no need to load from file path. Each is a pandas dataframe with columns and example rows: 138 | 139 | {data} 140 | 141 | For the final test, return 142 | (1) Name: name of the test 143 | (2) Test description: be clear and concise. Describe the falsification outcomes. 144 | (3) Null sub-hypothesis h_0: what is the statistical null sub-hypothesis does this falsification test aim to test? 145 | (4) Alternate sub-hypothesis h_1: what is the statistical alternative sub-hypothesis does this falsification test aim to test? 146 | 147 | Here are the falsification tests that you've created in the previous rounds and their corresponding test results: 148 | 149 | """ 150 | {existing_falsification_test} 151 | """ 152 | 153 | You may use these information to formulate your next subhypothesis and falsification test, but make sure the proposed falsification test is non-redundant with any of the existing tests. 154 | 155 | The proposed test should also avoid these failed falsification tests in the previous rounds: 156 | 157 | """ 158 | {failed_falsification_test} 159 | """ 160 | 161 | A good falsification test should serve as a strong evidence for the main hypothesis. However, make sure it is answerable with the given available data sources. 162 | You should aim to maximize the implication strength of the proposed falsification test using the relevant parts of the provided data. 163 | 164 | ---- 165 | First produce an initial falsification test proposal. 166 | 167 | Then, in each round i, you will do the following: 168 | (1) critic: ask if the main hypothesis is null, is this test also null? be rigorous. this is super important, otherwise, the test is invalid. Is it redundant on capabilities with existing tests? Is it overlapping with failed tests? Can this be answered and implemented based on the given data? 
169 | (2) reflect: how to improve this test definition. 170 | 171 | If you think the test definition is good enough, return the final test definition to the user. 172 | If not, either refine the test definition that is better than the previous one or propose a new test definition, then go to the next round. 173 | ''' 174 | 175 | def get_test_proposal_agent_system_prompt(domain): 176 | return TEST_PROPOSAL_AGENT_SYSTEM_PROMPT.format(domain=domain) 177 | 178 | def get_test_proposal_agent_user_prompt(domain, main_hypothesis, data, existing_tests, failed_tests): 179 | return TEST_PROPOSAL_AGENT_USER_PROMPT.format(domain = domain, main_hypothesis = main_hypothesis, data = data, existing_falsification_test = existing_tests, failed_falsification_test = failed_tests) 180 | 181 | 182 | SUMMARIZER_SYSTEM_PROMPT = """You are a helpful assistant trained to help scientists summarize their experiment observations. 183 | You have observed a sequential falsification test procedure of a scientific hypothesis and your goal is to accurately summarize and extract insights to present to a human scientist. 184 | For the observed list of falsification tests, each test includes the test description and its test results. 185 | 186 | The final output should state the following: 187 | (1) The main scientific hypothesis under study 188 | (2) The result of the sequential falsification test 189 | (3) Reasoning, summarizing, and analyzing these results 190 | (4) Your conclusion on whether or not this hypothesis is true or false; just return True/False 191 | (5) Rationale of the conclusion 192 | 193 | Remember, your MUST STRICTLY ADHERE to the experiment observations WITHOUT your personal bias or interpretations. For example, if the experiments fail to reject the null hypothesis, you MUST output the conclusion as False EVEN IF YOU BELIEVE THE STATEMENT IS TRUE. 194 | """ 195 | 196 | def get_summarizer_system_prompt(): 197 | return SUMMARIZER_SYSTEM_PROMPT 198 | 199 | 200 | RELEVANCE_PROMPT = """ 201 | Given a main hypothesis and a proposed sub-hypothesis test, assess the relevance of this sub-hypothesis test to the main hypothesis. 202 | Use the following rubric to guide your response, providing a score from 0.1 to 1.0 and a brief justification for the score. 203 | Each score level represents a different degree of relevance based on evidence strength, mechanistic connection, and predictive value of the test results. 204 | 205 | Rubric: 206 | 207 | 1.0 - Highly Relevant: The sub-hypothesis provides direct evidence or a clear mechanistic insight that strongly supports or refutes the main hypothesis. The test is specific to variables or mechanisms involved in the main hypothesis, with significant predictive value. 208 | 0.8 - Strongly Relevant: The test addresses a major component of the main hypothesis, providing substantial supporting or refuting evidence, and shows strong mechanistic alignment. The results would significantly impact the confidence in the main hypothesis. 209 | 0.6 - Moderately Relevant: The test examines elements supporting the main hypothesis without direct mechanistic insight. Some aspects align with the main hypothesis, offering moderate predictive value. 210 | 0.4 - Slightly Relevant: The test is related to the main hypothesis but provides limited direct evidence. It explores loosely associated variables and has minimal predictive value. 
211 | 0.2 - Barely Relevant: The test is tangentially related, providing minimal information that could impact the main hypothesis, with no clear mechanistic link and negligible predictive value. 212 | 0.1 - Irrelevant: The sub-hypothesis does not provide relevant evidence or mechanistic connection to the main hypothesis, with no predictive value. 213 | 214 | Instructions: 215 | 1. Read the main hypothesis and the sub-hypothesis test carefully. 216 | 2. Choose the relevance score from the rubric that best matches the relationship. 217 | 3. Explain your reasoning for selecting this score, referring to evidence strength, mechanistic connection, and predictive value of the sub-hypothesis test results. 218 | """ 219 | 220 | def get_relevance_prompt(): 221 | return RELEVANCE_PROMPT 222 | 223 | 224 | 225 | 226 | def bind_tools_to_system_prompt(system_prompt, tools): 227 | return f'''You are an intelligent agent capable of calling tools to complete user-assigned tasks. 228 | Here are the instructions specified by the user: 229 | """{system_prompt}""" 230 | 231 | In addition, you have access to the following tools: 232 | {json.dumps(tools, indent=4)} 233 | 234 | You may output any intermediate thoughts or reasonings before delivering your final response. 235 | Your final response must either be at least one tool call or a response message to the user. 236 | 237 | To make one or more tool calls, wrap your final response in the following JSON format: 238 | {{ 239 | "type": "tool_calls", 240 | "content": [ 241 | {{ 242 | "name": "name of the function to call", 243 | "id": "an unique id for this tool call", 244 | "arguments": {{ 245 | "argument1": value1, 246 | "argument2": value2, 247 | ... 248 | }} 249 | }}, 250 | ... 251 | ] 252 | }} 253 | 254 | To send a direct response message to the user, wrap your final response in the following JSON format: 255 | {{ 256 | "type": "text_message", 257 | "content": "content of the message according to the user instructions" 258 | }} 259 | 260 | You must choose either to send tool calls or a direct response message. Be sure to format the final response properly according to the given JSON specs. 261 | 262 | DO NOT put anything after the final response JSON object.''' -------------------------------------------------------------------------------- /baseline_agents/react_utils.py: -------------------------------------------------------------------------------- 1 | # Set up the base template 2 | from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser 3 | from langchain.prompts import StringPromptTemplate 4 | from langchain.tools import BaseTool 5 | from langchain.chains.llm import LLMChain 6 | from langchain_experimental.tools.python.tool import PythonAstREPLTool 7 | from typing import List, Union, Dict 8 | from langchain.schema import AgentAction, AgentFinish 9 | from pydantic import Field, PrivateAttr 10 | import contextlib 11 | import io 12 | import logging 13 | import re 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | template = """{system_prompt} 18 | 19 | You have access to the following tools: 20 | {tools} 21 | 22 | Use the following format: 23 | 24 | Question: an input hypothesis that you must decide if it is True or False 25 | Datasets: the names and descriptions of datasets relevant to the input hypothesis 26 | Thought: you should always think about what to do 27 | Action: the action to take, should be one of [{tool_names}] 28 | Action Input: the input to the action 29 | Observation: the result of the action 30 | ... 
(this Thought/Action/Action Input/Observation can repeat N times) 31 | Thought: I now know the final answer 32 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer 33 | Final Answer: True/False. Please output True if the input hypothesis is valid (e.g., you are able to reject the null hypothesis with statistical significance) and False if the input hypothesis is invalid (e.g., if you fail to reject the null hypothesis). 34 | 35 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as "WORKFLOW SUMMARY:" 36 | 37 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section. 38 | 39 | 40 | Example 41 | Question: Is the following hypothesis True or False? There is a statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates. 42 | Datasets: 43 | {{ 44 | "name": "df_body-size-evolution-in-south-american-freshwater-fishes", 45 | "description": "Data on body size evolution in South American freshwater fishes, including speciation and extinction rates", 46 | "columns": {{ 47 | "raw": [ 48 | {{ 49 | "name": "HYBAS_ID", 50 | "description": "Unique identifier for each hydrological basin" 51 | }}, 52 | {{ 53 | "name": "long", 54 | "description": "Longitude of the basin location" 55 | }}, 56 | {{ 57 | "name": "lat", 58 | "description": "Latitude of the basin location" 59 | }}, 60 | {{ 61 | "name": "BAMM_speciation", 62 | "description": "Rate of speciation as calculated by the BAMM method" 63 | }}, 64 | {{ 65 | "name": "BAMM_extinction", 66 | "description": "Rate of extinction as calculated by the BAMM method" 67 | }}, 68 | {{ 69 | "name": "BAMM_NetDiv", 70 | "description": "Net diversification rate, calculated as speciation minus extinction" 71 | }}, 72 | {{ 73 | "name": "aet", 74 | "description": "Mean annual evapotranspiration for each basin" 75 | }}, 76 | {{ 77 | "name": "Elevation", 78 | "description": "Average elevation of the basin" 79 | }}, 80 | {{ 81 | "name": "sgr", 82 | "description": "Species growth rate in each basin" 83 | }}, 84 | {{ 85 | "name": "soil_div", 86 | "description": "Soil diversity index for each basin" 87 | }}, 88 | {{ 89 | "name": "area", 90 | "description": "Total area of the basin in square kilometers" 91 | }}, 92 | {{ 93 | "name": "diversity", 94 | "description": "Diversity index for the species in each basin" 95 | }} 96 | ] 97 | }} 98 | }} 99 | Thought: First, I need to load the dataset from the global namespace in Python and inspect the data to identify the relevant columns for this hypothesis test. 100 | Action: python_repl_ast 101 | Action Input: import pandas as pd\n\ndf = df_body-size-evolution-in-south-american-freshwater-fishes\ndf.head() 102 | Observation: 103 | HYBAS_ID long lat BAMM_speciation BAMM_extinction BAMM_NetDiv ... aet Elevation sgr soil_div area diversity 104 | 0 6050000010 -76.477422 7.742693 0.137392 0.026807 0.110585 ... 1387 330.150088 166 0.482402 72363.7 68 105 | 1 6050000740 -74.628725 9.803586 0.117235 0.025796 0.091438 ... 1082 69.475294 23 0.457436 17944.3 35 106 | 2 6050068100 -75.295995 8.448815 0.119381 0.023826 0.095555 ... 1312 143.032178 74 0.378793 17105.5 44 107 | 3 6050068110 -74.608408 8.922863 0.132477 0.027777 0.104700 ... 
1445 14.724138 3 0.468328 610.1 48 108 | 4 6050070260 -75.591588 5.770093 0.120127 0.022940 0.097187 ... 1371 1378.729945 421 0.158870 61901.9 81 109 | [5 rows x 21 columns] 110 | Thought: Now that the dataset is loaded and I can see the columns, I need to perform a statistical test to assess the significance of the relationship between 'BAMM_speciation' and 'BAMM_NetDiv'. 111 | Action: python_repl_ast 112 | Action Input: from scipy.stats import linregress\n\n# Perform linear regression to test for a statistically significant relationship\nresult = linregress(df['BAMM_speciation'], df['BAMM_NetDiv'])\ncoefficient = result.slope\np_value = result.pvalue\ncoefficient, p_value 113 | Observation: (0.5175306498596297, 0.0035) 114 | Thought: The linear regression analysis provides a coefficient of approximately 0.518, indicating a positive relationship, and the p-value is 0.0035, which is statistically significant at the 0.05 level. Based on this, I can conclude that the hypothesis is true. 115 | WORKFLOW SUMMARY: 116 | 1. Data Loading: Loaded the dataset from the global namespace using Python. 117 | 2. Data Inspection: Displayed the first few rows of the dataset to confirm relevant columns. 118 | 3. Statistical Analysis: Performed a linear regression analysis between 'BAMM_speciation' (predictor) and 'BAMM_NetDiv' (response). The analysis yielded a positive coefficient and a p-value of 0.0035, indicating statistical significance. 119 | Final Answer: True 120 | 121 | 122 | Begin! 123 | 124 | {input} 125 | {agent_scratchpad}""" 126 | 127 | 128 | template_v2 = """{system_prompt} 129 | 130 | You have access to the following tools: 131 | {tools} 132 | 133 | Use the following format: 134 | 135 | Question: an input hypothesis that you must decide if it is True or False 136 | Datasets: the names and descriptions of datasets relevant to the input hypothesis 137 | Thought: you should always think about what to do 138 | Action: the action to take, should be one of [{tool_names}] 139 | Action Input: the input to the action 140 | Observation: the result of the action 141 | ... (this Thought/Action/Action Input/Observation can repeat N times) 142 | Thought: I now know the final answer 143 | WORKFLOW SUMMARY: this is the workflow that I used to find the final answer 144 | Final Answer: True/False. Please output True if you believe the input hypothesis is correct and False if the input hypothesis is not based on your analysis. 145 | 146 | Please make sure the Final Answer is either True or False. Also generate a summary of the full workflow starting from data loading that led to the final answer as "WORKFLOW SUMMARY:" 147 | 148 | IMPORTANT: all datasets have already been loaded into the global namespace as Pandas dataframes. You may access the data by referring to the EXACT dataframe names as provided in the "Datasets:" section. 149 | 150 | 151 | Example 152 | Question: Is the following hypothesis True or False? There is a statistically significant positive relationship between the rate of maximum body length evolution and spatial variation in speciation rates. 
153 | Datasets: 154 | {{ 155 | "name": "df_body-size-evolution-in-south-american-freshwater-fishes", 156 | "description": "Data on body size evolution in South American freshwater fishes, including speciation and extinction rates", 157 | "columns": {{ 158 | "raw": [ 159 | {{ 160 | "name": "HYBAS_ID", 161 | "description": "Unique identifier for each hydrological basin" 162 | }}, 163 | {{ 164 | "name": "long", 165 | "description": "Longitude of the basin location" 166 | }}, 167 | {{ 168 | "name": "lat", 169 | "description": "Latitude of the basin location" 170 | }}, 171 | {{ 172 | "name": "BAMM_speciation", 173 | "description": "Rate of speciation as calculated by the BAMM method" 174 | }}, 175 | {{ 176 | "name": "BAMM_extinction", 177 | "description": "Rate of extinction as calculated by the BAMM method" 178 | }}, 179 | {{ 180 | "name": "BAMM_NetDiv", 181 | "description": "Net diversification rate, calculated as speciation minus extinction" 182 | }}, 183 | {{ 184 | "name": "aet", 185 | "description": "Mean annual evapotranspiration for each basin" 186 | }}, 187 | {{ 188 | "name": "Elevation", 189 | "description": "Average elevation of the basin" 190 | }}, 191 | {{ 192 | "name": "sgr", 193 | "description": "Species growth rate in each basin" 194 | }}, 195 | {{ 196 | "name": "soil_div", 197 | "description": "Soil diversity index for each basin" 198 | }}, 199 | {{ 200 | "name": "area", 201 | "description": "Total area of the basin in square kilometers" 202 | }}, 203 | {{ 204 | "name": "diversity", 205 | "description": "Diversity index for the species in each basin" 206 | }} 207 | ] 208 | }} 209 | }} 210 | Thought: First, I need to load the dataset from the global namespace in Python and inspect the data to identify the relevant columns for this hypothesis test. 211 | Action: python_repl_ast 212 | Action Input: import pandas as pd\n\ndf = df_body-size-evolution-in-south-american-freshwater-fishes\ndf.head() 213 | Observation: 214 | HYBAS_ID long lat BAMM_speciation BAMM_extinction BAMM_NetDiv ... aet Elevation sgr soil_div area diversity 215 | 0 6050000010 -76.477422 7.742693 0.137392 0.026807 0.110585 ... 1387 330.150088 166 0.482402 72363.7 68 216 | 1 6050000740 -74.628725 9.803586 0.117235 0.025796 0.091438 ... 1082 69.475294 23 0.457436 17944.3 35 217 | 2 6050068100 -75.295995 8.448815 0.119381 0.023826 0.095555 ... 1312 143.032178 74 0.378793 17105.5 44 218 | 3 6050068110 -74.608408 8.922863 0.132477 0.027777 0.104700 ... 1445 14.724138 3 0.468328 610.1 48 219 | 4 6050070260 -75.591588 5.770093 0.120127 0.022940 0.097187 ... 1371 1378.729945 421 0.158870 61901.9 81 220 | [5 rows x 21 columns] 221 | Thought: Now that the dataset is loaded and I can see the columns, I need to perform a statistical test to assess the significance of the relationship between 'BAMM_speciation' and 'BAMM_NetDiv'. 222 | Action: python_repl_ast 223 | Action Input: from scipy.stats import linregress\n\n# Perform linear regression to test for a statistically significant relationship\nresult = linregress(df['BAMM_speciation'], df['BAMM_NetDiv'])\ncoefficient = result.slope\np_value = result.pvalue\ncoefficient, p_value 224 | Observation: (0.5175306498596297, 0.0035) 225 | Thought: The linear regression analysis provides a coefficient of approximately 0.518, indicating a positive relationship, and the p-value is 0.0035, which is statistically significant at the 0.05 level. Based on this, I can conclude that the hypothesis is true. 226 | WORKFLOW SUMMARY: 227 | 1. 
Data Loading: Loaded the dataset from the global namespace using Python. 228 | 2. Data Inspection: Displayed the first few rows of the dataset to confirm relevant columns. 229 | 3. Statistical Analysis: Performed a linear regression analysis between 'BAMM_speciation' (predictor) and 'BAMM_NetDiv' (response). The analysis yielded a positive coefficient and a p-value of 0.0035, indicating statistical significance. 230 | Final Answer: True 231 | 232 | 233 | Begin! 234 | 235 | {input} 236 | {agent_scratchpad}""" 237 | 238 | 239 | def load_data_to_react_globals(data_loader): 240 | for name, df in data_loader.table_dict.items(): 241 | globals()[name] = df 242 | 243 | 244 | # Set up a prompt template 245 | class CustomPromptTemplate(StringPromptTemplate): 246 | # The template to use 247 | template: str 248 | # The list of tools available 249 | tools: List[BaseTool] 250 | 251 | def format(self, **kwargs) -> str: 252 | # Get the intermediate steps (AgentAction, Observation tuples) 253 | # Format them in a particular way 254 | intermediate_steps = kwargs.pop("intermediate_steps") 255 | thoughts = "" 256 | for action, observation in intermediate_steps: 257 | thoughts += action.log 258 | thoughts += f"\nObservation: {observation}\nThought: " 259 | # Set the agent_scratchpad variable to that value 260 | kwargs["agent_scratchpad"] = thoughts 261 | # Create a tools variable from the list of tools provided 262 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) 263 | # Create a list of tool names for the tools provided 264 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) 265 | return self.template.format(**kwargs) 266 | 267 | # CustomOutputParser to parse the output of the LLM and execute actions 268 | class CustomOutputParser(AgentOutputParser): 269 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]: 270 | # Check if agent should finish 271 | if "Final Answer:" in llm_output: 272 | output = llm_output.split("Final Answer:")[-1].split()[0].strip().lower() 273 | if output not in ["true", "false", "yes", "no", "y", "n"]: 274 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 275 | return AgentFinish( 276 | return_values={"output": output in ["true", "yes", 'y']}, 277 | log=llm_output, 278 | ) 279 | # Parse out the action and action input 280 | regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" 281 | match = re.search(regex, llm_output, re.DOTALL) 282 | if not match: 283 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 284 | action = match.group(1).strip() 285 | action_input = match.group(2) 286 | # Return the action and action input 287 | return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output) 288 | 289 | 290 | class CustomPythonAstREPLTool(PythonAstREPLTool): 291 | _exec_globals:Dict = PrivateAttr() 292 | def __init__(self, *args, **kwargs): 293 | super().__init__(*args, **kwargs) 294 | # Initialize a persistent global namespace for code execution 295 | self._exec_globals = {} 296 | self._exec_globals.update(__builtins__) 297 | 298 | def _set_globals(self, table_dict=None): 299 | self._exec_globals = {} 300 | self._exec_globals.update(__builtins__) 301 | 302 | if table_dict: 303 | self._exec_globals.update(table_dict) 304 | 305 | def _run(self, query: str, run_manager=None): 306 | code_match = re.search(r"```(.*?)```", query, re.DOTALL) 307 | if code_match: 308 | # Extract code within backticks 309 | code = code_match.group(1) 310 | 
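            # The incoming query may be either a fenced block (```python ... ```) or raw
            # code; when backticks are present only their contents are kept. The last
            # line is later evaluated separately with `eval` so that a bare trailing
            # expression (e.g. `df.head()`) gets echoed like in a notebook, while `exec`
            # runs the preceding statements against the persistent global namespace.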
else: 311 | code = query 312 | code = code.strip() 313 | if code.startswith("python"): 314 | code = code[len("python"):].lstrip() 315 | 316 | code_lines = code.strip().split('\n') 317 | code = '\n'.join(code_lines[:-1]) # avoid printing the last line twice 318 | last_line = code_lines[-1] 319 | 320 | output_capture = io.StringIO() 321 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture): 322 | logging.getLogger().handlers[0].stream = output_capture 323 | try: 324 | exec(code, self._exec_globals) 325 | try: 326 | result = eval(last_line, self._exec_globals) 327 | if result is not None: 328 | print(result, file=output_capture) 329 | except: 330 | pass 331 | except Exception as e: 332 | return str(e) 333 | 334 | # Retrieve the output and return it 335 | output = output_capture.getvalue() 336 | return output if output else "Execution completed without output." 337 | 338 | 339 | def create_agent( 340 | llm, 341 | handlers, 342 | max_iterations = None, 343 | early_stopping_method: str = "force", 344 | simple_template = False 345 | ): 346 | output_parser = CustomOutputParser() 347 | python_tool = CustomPythonAstREPLTool(callbacks=handlers) 348 | tools = [python_tool] 349 | tool_names = [tool.name for tool in tools] 350 | 351 | if simple_template: 352 | use_template = template_v2 353 | else: 354 | use_template = template 355 | 356 | prompt = CustomPromptTemplate( 357 | template=use_template, 358 | tools=tools, 359 | input_variables=["system_prompt", "input", "intermediate_steps"] 360 | ) 361 | llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=handlers) 362 | 363 | agent = LLMSingleActionAgent( 364 | llm_chain=llm_chain, 365 | output_parser=output_parser, 366 | stop=["\nObservation:"], 367 | allowed_tools=tool_names 368 | ) 369 | 370 | return AgentExecutor.from_agent_and_tools( 371 | agent=agent, 372 | tools=tools, 373 | verbose=True, 374 | max_iterations=max_iterations, 375 | callbacks=handlers, 376 | early_stopping_method=early_stopping_method 377 | ) -------------------------------------------------------------------------------- /popper/llm/custom_model.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import sys 4 | import uuid 5 | import json 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Any, 9 | AsyncIterator, 10 | Callable, 11 | Dict, 12 | Iterator, 13 | List, 14 | Literal, 15 | Mapping, 16 | Optional, 17 | Sequence, 18 | Tuple, 19 | Type, 20 | Union, 21 | ) 22 | 23 | from langchain_core.callbacks import ( 24 | AsyncCallbackManagerForLLMRun, 25 | CallbackManagerForLLMRun, 26 | ) 27 | from langchain_core.messages import ( 28 | AIMessage, 29 | AIMessageChunk, 30 | BaseMessage, 31 | BaseMessageChunk, 32 | ChatMessage, 33 | ChatMessageChunk, 34 | FunctionMessage, 35 | FunctionMessageChunk, 36 | HumanMessage, 37 | HumanMessageChunk, 38 | InvalidToolCall, 39 | SystemMessage, 40 | SystemMessageChunk, 41 | ToolCall, 42 | ToolMessage, 43 | ToolMessageChunk, 44 | ) 45 | from langchain_core.language_models import BaseChatModel, SimpleChatModel, LanguageModelInput 46 | from langchain_core.language_models.llms import create_base_retry_decorator 47 | from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage 48 | from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult 49 | # from langchain_core.pydantic import BaseModel, Field 50 | from pydantic import BaseModel, Field, PrivateAttr 51 | from langchain_core.runnables import run_in_executor, 
Runnable 52 | from langchain_core.output_parsers.openai_tools import ( 53 | JsonOutputKeyToolsParser, 54 | PydanticToolsParser, 55 | make_invalid_tool_call, 56 | parse_tool_call, 57 | ) 58 | from langchain_core.utils import secret_from_env 59 | from langchain_core.utils.function_calling import convert_to_openai_tool 60 | from langchain_core.tools import BaseTool 61 | # from langchain_community.adapters.openai import ( 62 | # convert_dict_to_message, 63 | # convert_message_to_dict, 64 | # ) 65 | from langchain_openai import ChatOpenAI 66 | # from langchain_openai.chat_models import _convert_dict_to_message 67 | # from langchain_community.chat_models.openai import ChatOpenAI 68 | from popper.llm.prompt_utils import bind_tools_to_system_prompt 69 | from popper.llm.utils import parse_llm_output 70 | 71 | 72 | def _convert_message_to_dict(message: BaseMessage) -> dict: 73 | """Convert a LangChain message to a dictionary. 74 | 75 | Args: 76 | message: The LangChain message. 77 | 78 | Returns: 79 | The dictionary. 80 | """ 81 | message_dict: Dict[str, Any] 82 | if isinstance(message, ChatMessage): 83 | message_dict = {"role": message.role, "content": message.content} 84 | elif isinstance(message, HumanMessage): 85 | message_dict = {"role": "user", "content": message.content} 86 | elif isinstance(message, AIMessage): 87 | message_dict = {"role": "assistant", "content": message.content} 88 | if "function_call" in message.additional_kwargs: 89 | message_dict["function_call"] = message.additional_kwargs["function_call"] 90 | # If function call only, content is None not empty string 91 | if message_dict["content"] == "": 92 | message_dict["content"] = None 93 | if "tool_calls" in message.additional_kwargs: 94 | message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] 95 | # If tool calls only, content is None not empty string 96 | if message_dict["content"] == "": 97 | message_dict["content"] = None 98 | if "full_message" in message.additional_kwargs and message.additional_kwargs["full_message"]: 99 | # full_message is not empty, replace the current content with full message 100 | message_dict["content"] = message.additional_kwargs["full_message"] 101 | elif isinstance(message, SystemMessage): 102 | message_dict = {"role": "system", "content": message.content} 103 | elif isinstance(message, FunctionMessage): 104 | message_dict = { 105 | "role": "function", 106 | "content": message.content, 107 | "name": message.name, 108 | } 109 | elif isinstance(message, ToolMessage): 110 | run_results = { 111 | "name": message.name, 112 | "id": message.tool_call_id, 113 | "return_value": message.content 114 | } 115 | message_dict = { 116 | "role": "user", 117 | "content": "Tool call results:\n" + json.dumps(run_results, indent=4), 118 | "tool_call_id": message.tool_call_id, 119 | } 120 | else: 121 | raise TypeError(f"Got unknown type {message}") 122 | if "name" in message.additional_kwargs: 123 | message_dict["name"] = message.additional_kwargs["name"] 124 | return message_dict 125 | 126 | 127 | def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: 128 | """Convert a dictionary to a LangChain message. 129 | 130 | Args: 131 | _dict: The dictionary. 132 | 133 | Returns: 134 | The LangChain message. 
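
    Example (illustrative):
        A dict like {"role": "user", "content": "hi"} becomes a HumanMessage with
        content "hi"; an "assistant" dict carrying a "tool_calls" list becomes an
        AIMessage whose tool_calls are parsed into ToolCall entries.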
135 | """ 136 | role = _dict.get("role") 137 | name = _dict.get("name") 138 | id_ = _dict.get("id") 139 | if role == "user": 140 | return HumanMessage(content=_dict.get("content", ""), id=id_, name=name) 141 | elif role == "assistant": 142 | # Fix for azure 143 | # Also OpenAI returns None for tool invocations 144 | content = _dict.get("content", "") or "" 145 | additional_kwargs: Dict = {} 146 | if function_call := _dict.get("function_call"): 147 | additional_kwargs["function_call"] = dict(function_call) 148 | if full_message := _dict.get("full_message"): 149 | additional_kwargs["full_message"] = full_message 150 | tool_calls = [] 151 | invalid_tool_calls = [] 152 | if raw_tool_calls := _dict.get("tool_calls"): 153 | additional_kwargs["tool_calls"] = raw_tool_calls 154 | for raw_tool_call in raw_tool_calls: 155 | try: 156 | tool_calls.append(ToolCall(name=raw_tool_call['function']['name'], args=raw_tool_call['function']['arguments'], id=raw_tool_call['id'], type="tool_call")) 157 | except Exception as e: 158 | invalid_tool_calls.append( 159 | make_invalid_tool_call(str(raw_tool_call), str(e)) 160 | ) 161 | # print(tool_calls) 162 | # print() 163 | # print(invalid_tool_calls) 164 | return AIMessage( 165 | content=content, 166 | additional_kwargs=additional_kwargs, 167 | name=name, 168 | id=id_, 169 | tool_calls=tool_calls, 170 | invalid_tool_calls=invalid_tool_calls, 171 | ) 172 | elif role == "system": 173 | return SystemMessage(content=_dict.get("content", ""), name=name, id=id_) 174 | elif role == "function": 175 | return FunctionMessage( 176 | content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_ 177 | ) 178 | elif role == "tool": 179 | additional_kwargs = {} 180 | if "name" in _dict: 181 | additional_kwargs["name"] = _dict["name"] 182 | return ToolMessage( 183 | content=_dict.get("content", ""), 184 | tool_call_id=cast(str, _dict.get("tool_call_id")), 185 | additional_kwargs=additional_kwargs, 186 | name=name, 187 | id=id_, 188 | ) 189 | else: 190 | return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type] 191 | 192 | class CustomChatModel(ChatOpenAI): 193 | model_type: str = Field(default="custom-chat") 194 | tools: Optional[List[Any]] = Field(default=None) 195 | tool_choice: Optional[Union[dict, str, Literal["auto", "none", "required", "any"], bool]] = Field(default=None) 196 | 197 | @property 198 | def lc_secrets(self) -> Dict[str, str]: 199 | return {"openai_api_key": "EMPTY"} 200 | 201 | @property 202 | def _llm_type(self) -> str: 203 | """Return type of chat model.""" 204 | return self.model_type + "-chat" 205 | 206 | def bind_tools( 207 | self, 208 | tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], 209 | tool_choice: Optional[ 210 | Union[dict, str, Literal["auto", "none", "required", "any"], bool] 211 | ] = None, 212 | strict: Optional[bool] = None, 213 | **kwargs: Any, 214 | ) -> Runnable[LanguageModelInput, BaseMessage]: 215 | """Bind tool-like objects to this chat model. 216 | 217 | Assumes model is compatible with OpenAI tool-calling API. 218 | 219 | Args: 220 | tools: A list of tool definitions to bind to this chat model. 221 | Supports any tool definition handled by 222 | :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`. 223 | tool_choice: Which tool to require the model to call. 
224 | Must be the name of the single provided function or 225 | "auto" to automatically determine which function to call 226 | (if any), or a dict of the form: 227 | {"type": "function", "function": {"name": <>}}. 228 | kwargs: Any additional parameters are passed directly to 229 | ``self.bind(**kwargs)``. 230 | """ 231 | formatted_tools = [convert_to_openai_tool(tool, strict=strict) for tool in tools] 232 | 233 | if tool_choice: 234 | if isinstance(tool_choice, str): 235 | # tool_choice is a tool/function name 236 | if tool_choice not in ("auto", "none", "any", "required"): 237 | tool_choice = { 238 | "type": "function", 239 | "function": {"name": tool_choice}, 240 | } 241 | # 'any' is not natively supported by OpenAI API. 242 | # We support 'any' since other models use this instead of 'required'. 243 | if tool_choice == "any": 244 | tool_choice = "required" 245 | elif isinstance(tool_choice, bool): 246 | tool_choice = "required" 247 | elif isinstance(tool_choice, dict): 248 | tool_names = [ 249 | formatted_tool["function"]["name"] 250 | for formatted_tool in formatted_tools 251 | ] 252 | if not any( 253 | tool_name == tool_choice["function"]["name"] 254 | for tool_name in tool_names 255 | ): 256 | raise ValueError( 257 | f"Tool choice {tool_choice} was specified, but the only " 258 | f"provided tools were {tool_names}." 259 | ) 260 | else: 261 | raise ValueError( 262 | f"Unrecognized tool_choice type. Expected str, bool or dict. " 263 | f"Received: {tool_choice}" 264 | ) 265 | kwargs["tool_choice"] = tool_choice 266 | 267 | self.tools = formatted_tools 268 | self.tool_choice = tool_choice 269 | return super().bind(tools=formatted_tools, **kwargs) 270 | 271 | 272 | def bind_functions( 273 | self, 274 | functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], 275 | function_call: Optional[str] = None, 276 | **kwargs: Any, 277 | ) -> Runnable[LanguageModelInput, BaseMessage]: 278 | """Bind functions (and other objects) to this chat model. 279 | 280 | Args: 281 | functions: A list of function definitions to bind to this chat model. 282 | Can be a dictionary, pydantic model, or callable. Pydantic 283 | models and callables will be automatically converted to 284 | their schema dictionary representation. 285 | function_call: Which function to require the model to call. 286 | Must be the name of the single provided function or 287 | "auto" to automatically determine which function to call 288 | (if any). 289 | kwargs: Any additional parameters to pass to the 290 | :class:`~langchain.runnable.Runnable` constructor. 291 | """ 292 | from langchain.chains.openai_functions.base import convert_to_openai_function 293 | 294 | formatted_functions = [convert_to_openai_function(fn) for fn in functions] 295 | self.tools = formatted_tools 296 | if function_call is not None: 297 | if len(formatted_functions) != 1: 298 | raise ValueError( 299 | "When specifying `function_call`, you must provide exactly one " 300 | "function." 301 | ) 302 | if formatted_functions[0]["name"] != function_call: 303 | raise ValueError( 304 | f"Function call {function_call} was specified, but the only " 305 | f"provided function was {formatted_functions[0]['name']}." 
306 | ) 307 | function_call_ = {"name": function_call} 308 | kwargs = {**kwargs, "function_call": function_call_} 309 | return super().bind( 310 | functions=formatted_functions, 311 | **kwargs, 312 | ) 313 | 314 | 315 | def _get_request_payload( 316 | self, 317 | input_: LanguageModelInput, 318 | *, 319 | stop: Optional[List[str]] = None, 320 | **kwargs: Any, 321 | ) -> dict: 322 | 323 | if "tools" in kwargs: 324 | self.tools = kwargs['tools'] 325 | if "tool_choice" in kwargs: 326 | self.tool_choice = kwargs["tool_choice"] 327 | 328 | messages = self._convert_input(input_).to_messages() 329 | if stop is not None: 330 | kwargs["stop"] = stop 331 | 332 | message_dicts = [_convert_message_to_dict(m) for m in messages] 333 | if self.tools: 334 | has_system_prompt = False 335 | for msg in message_dicts: 336 | if msg['role'] == 'system': 337 | system_prompt = msg['content'] 338 | msg['content'] = bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice) 339 | has_system_prompt = True 340 | break 341 | if not has_system_prompt: 342 | system_prompt = "You are a helpful assistant" 343 | message_dicts = [{ 344 | 'role': 'system', 345 | 'content': bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice), 346 | }] + message_dicts 347 | 348 | if self.tool_choice is not None and message_dicts[-1]['role'] == 'user': 349 | last_user_message = message_dicts[-1]['content'] 350 | message_dicts[-1]['content'] = f"""{last_user_message} 351 | 352 | Remember to format your reponse as a call to one of the following tools: 353 | {json.dumps(self.tools, indent=4)} 354 | Your tool call should have the following JSON format: 355 | following JSON format: 356 | {{ 357 | "type": "tool_calls", 358 | "content": [ 359 | {{ 360 | "name": "name of the function to call", 361 | "id": "an unique id for this tool call", 362 | "arguments": {{ 363 | "argument1": value1, 364 | "argument2": value2, 365 | ... 366 | }} 367 | }}, 368 | ... 
369 | ] 370 | }} 371 | """ 372 | # print(message_dicts) 373 | 374 | return { 375 | "messages": message_dicts, 376 | **self._default_params, 377 | **kwargs, 378 | } 379 | 380 | def _create_message_dicts( 381 | self, messages: List[BaseMessage], stop: Optional[List[str]] 382 | ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: 383 | params = self._client_params 384 | if stop is not None: 385 | if "stop" in params: 386 | raise ValueError("`stop` found in both the input and default params.") 387 | params["stop"] = stop 388 | message_dicts = [_convert_message_to_dict(m) for m in messages] 389 | if self.tools: 390 | has_system_prompt = False 391 | for msg in message_dicts: 392 | if msg['role'] == 'system': 393 | system_prompt = msg['content'] 394 | msg['content'] = bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice) 395 | has_system_prompt = True 396 | break 397 | if not has_system_prompt: 398 | system_prompt = "You are a helpful assistant" 399 | message_dicts = [{ 400 | 'role': 'system', 401 | 'content': bind_tools_to_system_prompt(system_prompt, self.tools, self.tool_choice), 402 | }] + message_dicts 403 | 404 | return message_dicts, params 405 | 406 | def _create_chat_result(self, response: Union[dict, BaseModel], generation_info: Optional[Dict] = None) -> ChatResult: 407 | generations = [] 408 | if not isinstance(response, dict): 409 | response = response.dict() 410 | for res in response["choices"]: 411 | # print(res) 412 | if self.tools: 413 | # attempt to parse the tool calls 414 | full_message = res["message"]["content"] 415 | scratchpad, parsed_message = parse_llm_output(full_message) 416 | if parsed_message['type'] == 'text_message': 417 | # res["message"]['full_message'] = full_message 418 | res["message"]["content"] = parsed_message["content"] 419 | else: 420 | assert parsed_message['type'] == 'tool_calls' 421 | tool_calls = [] 422 | for tool_call in parsed_message['content']: 423 | if 'id' not in tool_call: 424 | tool_call['id'] = 'call_' + str(uuid.uuid4()) 425 | tool_calls.append({ 426 | 'id': tool_call['id'], 427 | 'type': 'function', 428 | 'function': { 429 | 'name': tool_call['name'], 430 | 'arguments': tool_call['arguments'] 431 | } 432 | }) 433 | res["message"]["tool_calls"] = tool_calls 434 | res["message"]['content'] = None 435 | res['finish_reason'] = 'tool_calls' 436 | res["message"]['full_message'] = scratchpad + json.dumps(parsed_message, indent=4) 437 | message = _convert_dict_to_message(res["message"]) 438 | generation_info = dict(finish_reason=res.get("finish_reason")) 439 | if "logprobs" in res: 440 | generation_info["logprobs"] = res["logprobs"] 441 | gen = ChatGeneration( 442 | message=message, 443 | generation_info=generation_info, 444 | ) 445 | generations.append(gen) 446 | # print(message) 447 | token_usage = response.get("usage", {}) 448 | llm_output = { 449 | "token_usage": token_usage, 450 | "model_name": self.model_name, 451 | "system_fingerprint": response.get("system_fingerprint", ""), 452 | } 453 | return ChatResult(generations=generations, llm_output=llm_output) -------------------------------------------------------------------------------- /baseline_agents/self_refine_agent.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import logging 4 | import traceback 5 | import multiprocessing 6 | import re 7 | import sys 8 | import numpy as np 9 | 10 | from typing import Optional, Tuple, Dict, Any, List, Literal 11 | from typing_extensions import TypedDict, Annotated 
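# Module overview (descriptive note): this file implements the self-refine
# baseline as two LLM-backed agents plus a LangGraph state machine.
#   CodingAgent     - drafts a falsification test as structured output
#                     (CodeOutputSpec: prefix / imports / code) and executes it
#                     in a subprocess under a time limit.
#   FeedbackAgent   - reviews the specification, code, and captured output and
#                     returns a FeedbackOutputSpec (is_valid / p_value / feedback).
#   SelfRefineAgent - loops generate -> execute -> evaluate until the feedback is
#                     valid or max_iterations is reached, then compares the parsed
#                     p-value against p_value_threshold.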
12 | from pydantic import BaseModel, Field 13 | 14 | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder 15 | from langchain_core.messages.base import get_msg_title_repr 16 | from langchain_core.utils.interactive_env import is_interactive_env 17 | from langgraph.graph import StateGraph, START, END 18 | 19 | import openai 20 | from popper.llm.custom_model import CustomChatModel 21 | from langchain_anthropic import ChatAnthropic 22 | from langchain_ollama import ChatOllama 23 | from langchain_openai import ChatOpenAI 24 | 25 | logging.getLogger("httpx").setLevel(logging.WARNING) 26 | 27 | 28 | CODING_AGENT_SYSTEM_PROMPT = """You are an expert statistician. You are tasked with rigorously validating whether a scientific hypothesis H is true by implementing a rigorous statistical test. 29 | 30 | You should write code to implement a falsification test for the given hypothesis. 31 | The test should be relevant to the main hypothesis and aim to falsify it. 32 | The test should use the available data described below, applying data processing, extraction, and statistical analysis to produce a p-value measuring the falsification of the main hypothesis. 33 | The test should be extremely rigorous. The p-value should be theoretically grounded. 34 | The code should be clear, concise, and efficient. Show a progress bar when necessary. It will have a time limit, so please be efficient. For example, if possible, you can set the number of permutations to be small (e.g. <1000). 35 | The code should be self-contained and should not require additional modifications from the user. 36 | 37 | Create code from the user request. Ensure any code you provide can be executed with all required imports and variables defined. 38 | Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. 39 | Invoke the CodeOutputSpec tool to structure the output correctly. 40 | NEVER PRODUCE ANY PLACEHOLDER IN ANY FUNCTION. A PLACEHOLDER IS WORSE THAN FAILURE TO PRODUCE CODE. 41 | PLACEHOLDERS include made-up genes, names, ids, functions, p-values, or any other placeholder values. 42 | The output should be a single p-value. If there are multiple p-values produced by the test, you should aggregate them in a meaningful and rigorous way. 43 | When printing p-values, please use scientific notation (e.g. 3.50e-03) instead of the raw number. 44 | ------------------------------------------------------- 45 | 46 | Here is the user requested falsification test specification:""" 47 | 48 | FEEDBACK_AGENT_SYSTEM_PROMPT = """You are evaluateGPT. Given a test specification, an implementation of the test, and the results from executing the implemented test, your task is to evaluate whether the test implementation and output are valid and to provide detailed feedback on any identified issues and suggestions for improving the test. 49 | 50 | To evaluate the validity of the test, you should 51 | 1. Check if the test implementation strictly follows the test specification. 52 | 2. Make sure the output shows a valid p-value without any errors. The p-value should have a reasonable value: it cannot be smaller than or equal to 0 or larger than 1. 53 | 3. Double-check that all data inputs used in the experiment are accessed from the provided data sources, and there are no fake/made-up data entries. 54 | 4. Examine carefully through the experiment implementation; make sure it uses a rigorous statistical test and there are no bugs or logical issues in the code. 55 | 5.
Carefully examine any other potential problems of the experiment, such as unhandled edge-cases, invalid sample sizes, or other subtle issues that might lead to a misleading result. 56 | 57 | If you find any problems with the test implementation or experiment output, please provide a detailed feedback on how to fix and refine the test. If the test is valid, please output the final p-value formatted in scientific notation.""" 58 | 59 | def get_llm(model = 'claude-3-5-sonnet-20240620', temperature=0.7, **kwargs): 60 | source = "custom" 61 | if model[:7] == 'claude-': 62 | source = 'Anthropic' 63 | elif model[:4] == 'gpt-': 64 | source = 'OpenAI' 65 | elif model.startswith('llama'): 66 | source = "Llama" 67 | # if source not in ['OpenAI', 'Anthropic']: 68 | # raise ValueError('Invalid source') 69 | if source == 'OpenAI': 70 | return ChatOpenAI(model = model, temperature = temperature, **kwargs) 71 | elif source == 'Anthropic': 72 | return ChatAnthropic(model = model, 73 | temperature = temperature, 74 | max_tokens = 4096, 75 | **kwargs) 76 | elif source == 'Llama': 77 | llm = CustomChatModel(model = model, model_type=source, temperature = temperature) 78 | llm.client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY").chat.completions 79 | return llm 80 | else: 81 | # listen to a different port 82 | llm = CustomChatModel(model = model, model_type=source, temperature = temperature) 83 | llm.client = openai.Client(base_url="http://127.0.0.1:40000/v1", api_key="EMPTY").chat.completions 84 | return llm 85 | 86 | 87 | class FeedbackOutputSpec(BaseModel): 88 | """Output specification for the experiment evaluation & feedback""" 89 | 90 | is_valid: str = Field( 91 | description="The validity of the experiment implementation and output. Use 'Yes' if the experiment is valid, or 'No' if any issues are detected." 92 | ) 93 | p_value: Optional[str] = Field(description="The p-value from the experiment, formatted in scientific notation. Include this field only if the experiment is valid.") 94 | feedback: str = Field( 95 | description="Detailed feedback to address identified issues and suggestions for improving the experiment." 96 | ) 97 | 98 | class CodeOutputSpec(BaseModel): 99 | """Code output""" 100 | 101 | prefix: str = Field(description="Description of the problem and approach") 102 | imports: str = Field(description="Code block import statements") 103 | code: str = Field(description="Code block not including import statements") 104 | 105 | 106 | class SelfRefineState(TypedDict): 107 | iteration: int 108 | messages: List[Tuple[str, str]] 109 | code_impl: Optional[CodeOutputSpec] 110 | code_out: Optional[str] 111 | feedback_valid: bool 112 | p_value: Optional[float] 113 | feedback_text: Optional[str] 114 | done: bool 115 | 116 | 117 | class CodingAgent: 118 | """ 119 | - Takes test_specification & data as input. 120 | - Produces code using CODING_AGENT_SYSTEM_PROMPT + user specification. 121 | - Executes the code and captures the output. 122 | """ 123 | 124 | def __init__(self, llm="claude-3-5-sonnet-20241022", time_limit: int = 3): 125 | """ 126 | Args: 127 | llm: name of the language model 128 | time_limit: Time limit (in minutes) to run code. 129 | """ 130 | self.llm = get_llm(llm) 131 | self.time_limit = time_limit 132 | 133 | # We will build a ChatPromptTemplate for code generation 134 | # Then parse the result with CodeOutputSpec. 
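        # The chain below pipes CODING_AGENT_SYSTEM_PROMPT plus the running
        # message history into the LLM; `with_structured_output(CodeOutputSpec)`
        # then asks the model for the three structured fields instead of
        # free-form text. Illustrative shape of one result (values are made up):
        #   CodeOutputSpec(
        #       prefix="Permutation test for the stated hypothesis ...",
        #       imports="import numpy as np\nfrom scipy import stats",
        #       code="...\nprint(f'{p_value:.2e}')",
        #   )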
135 | self.code_generation_prompt = ChatPromptTemplate.from_messages( 136 | [ 137 | ( 138 | "system", 139 | CODING_AGENT_SYSTEM_PROMPT, 140 | ), 141 | ("placeholder", "{messages}"), 142 | ] 143 | ) 144 | 145 | self.structured_llm = self.llm.with_structured_output(CodeOutputSpec) 146 | 147 | self._exec_globals = {} 148 | self._exec_globals.update(__builtins__) 149 | 150 | def _set_globals(self, table_dict=None): 151 | self._exec_globals = {} 152 | self._exec_globals.update(__builtins__) 153 | 154 | if table_dict: 155 | self._exec_globals.update(table_dict) 156 | 157 | def generate_code(self, messages) -> CodeOutputSpec: 158 | """Generate code from the LLM given a test specification.""" 159 | # Invoke the pipeline 160 | code_gen_chain = self.code_generation_prompt | self.structured_llm 161 | result: CodeOutputSpec = code_gen_chain.invoke({"messages": messages}) 162 | 163 | return result 164 | 165 | def execute_code(self, code_implementation: CodeOutputSpec) -> str: 166 | """ 167 | Safely executes the provided code in a subprocess (with a time limit). 168 | Returns stdout/stderr as a string. 169 | """ 170 | # Combine imports + code 171 | full_code = code_implementation.imports + "\n\n" + code_implementation.code 172 | 173 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S') 174 | logging.getLogger().setLevel(logging.INFO) 175 | 176 | # Worker function to run code in a separate process 177 | def run_code(queue): 178 | output_capture = io.StringIO() 179 | try: 180 | with contextlib.redirect_stdout(output_capture), contextlib.redirect_stderr(output_capture): 181 | logging.getLogger().handlers[0].stream = output_capture 182 | exec_globals = self._exec_globals 183 | exec(full_code, exec_globals) 184 | except Exception: 185 | error_message = traceback.format_exc() 186 | queue.put(error_message) 187 | else: 188 | queue.put(output_capture.getvalue()) 189 | 190 | # Use multiprocessing to enforce time limit 191 | queue = multiprocessing.Queue() 192 | proc = multiprocessing.Process(target=run_code, args=(queue,)) 193 | proc.start() 194 | proc.join(timeout=self.time_limit * 60) 195 | 196 | if proc.is_alive(): 197 | # Timed out 198 | proc.terminate() 199 | proc.join() 200 | return "TimeoutError: Code execution took too long." 201 | 202 | # Retrieve captured output 203 | if not queue.empty(): 204 | code_output = queue.get() 205 | else: 206 | code_output = "No output was captured." 207 | return code_output 208 | 209 | def run( 210 | self, 211 | messages: Any 212 | ) -> Tuple[CodeOutputSpec, str]: 213 | """ 214 | High-level method: 215 | 1) Generate code from LLM 216 | 2) Execute the code 217 | 3) Return code + output 218 | """ 219 | code_impl = self.generate_code(messages) 220 | print("---------Code Implementation-----------") 221 | print(code_impl) 222 | code_out = self.execute_code(code_impl) 223 | print("----------Code Output-----------") 224 | print(code_out) 225 | return code_impl, code_out 226 | 227 | class FeedbackAgent: 228 | """ 229 | - Takes in the test specification, data, code implementation, and execution result. 230 | - Evaluates using FEEDBACK_AGENT_SYSTEM_PROMPT. 231 | - Produces a FeedbackOutputSpec indicating whether the test is valid, a final p-value if valid, and suggestions. 
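    Illustrative call (a sketch; the argument values below are placeholders and a
    configured API key for the chosen model is assumed):

        agent = FeedbackAgent(llm="claude-3-5-sonnet-20241022")
        feedback = agent.run(
            test_specification="H: gene expression differs between groups",
            data="df_expression: columns [gene, group, value]",
            code_implementation=code_impl,  # CodeOutputSpec from CodingAgent
            code_output=code_out,           # captured stdout/stderr string
        )
        # feedback.is_valid is "Yes"/"No"; feedback.p_value is a scientific-notation
        # string such as "3.50e-03"; feedback.feedback holds the textual critique.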
232 | """ 233 | 234 | def __init__(self, llm="claude-3-5-sonnet-20241022"): 235 | self.llm = get_llm(llm) 236 | 237 | # Build a ChatPromptTemplate for feedback 238 | self.feedback_prompt = ChatPromptTemplate.from_messages( 239 | [ 240 | ("system", FEEDBACK_AGENT_SYSTEM_PROMPT), 241 | ("user", "{evaluation_input}"), 242 | ] 243 | ) 244 | self.structured_llm = self.llm.with_structured_output(FeedbackOutputSpec) 245 | 246 | def run( 247 | self, 248 | test_specification: str, 249 | data: str, 250 | code_implementation: CodeOutputSpec, 251 | code_output: str, 252 | ) -> FeedbackOutputSpec: 253 | """ 254 | 1) Concatenate specification + code + output into a single message 255 | 2) Ask LLM for feedback with FEEDBACK_AGENT_SYSTEM_PROMPT 256 | 3) Parse into FeedbackOutputSpec 257 | """ 258 | spec_text = ( 259 | f"Main Hypothesis: {test_specification}\n" 260 | ) 261 | 262 | code_text = ( 263 | f"--- Code Implementation ---\n" 264 | f"Prefix:\n{code_implementation.prefix}\n\n" 265 | f"Imports:\n{code_implementation.imports}\n\n" 266 | f"Code:\n{code_implementation.code}\n\n" 267 | ) 268 | 269 | output_text = f"--- Execution Result ---\n{code_output}\n\n" 270 | 271 | # Summarize data if needed (or directly pass it) 272 | # For brevity, we just pass a note that data is available: 273 | data_text = f"Data description:\n{data}\n\n" 274 | 275 | evaluation_input = spec_text + code_text + output_text + data_text 276 | 277 | # Now run the LLM 278 | final_chain = self.feedback_prompt | self.structured_llm 279 | feedback_result: FeedbackOutputSpec = final_chain.invoke( 280 | {"evaluation_input": evaluation_input} 281 | ) 282 | 283 | print("---------Feedback----------") 284 | print(feedback_result) 285 | return feedback_result 286 | 287 | class SelfRefineAgent: 288 | """ 289 | - Orchestrates iterative refinement: 290 | 1) Generate code with CodingAgent 291 | 2) Evaluate with FeedbackAgent 292 | 3) If invalid, refine or retry until success or max attempts 293 | 4) Once valid, parse p-value < 0.05 => {"output": True} else {"output": False}. 294 | """ 295 | 296 | def __init__( 297 | self, 298 | llm: str = "claude-3-5-sonnet-20241022", 299 | max_iterations: int = 10, 300 | p_value_threshold: float = 0.05, 301 | ): 302 | self.coding_agent = CodingAgent(llm) 303 | self.feedback_agent = FeedbackAgent(llm) 304 | self.max_iterations = max_iterations 305 | self.p_value_threshold = p_value_threshold 306 | 307 | self.graph = self._build_graph() 308 | 309 | def _build_graph(self): 310 | """ 311 | Create and compile the LangGraph for the iterative refinement. 312 | """ 313 | def generate_and_run_code(state: SelfRefineState): 314 | """ 315 | Node that calls coding_agent.run(...) to produce code & execution results. 316 | """ 317 | code_impl, code_out = self.coding_agent.run( 318 | state["messages"] 319 | ) 320 | state["code_impl"] = code_impl 321 | state["code_out"] = code_out 322 | state["messages"] += [ 323 | ( 324 | "assistant", 325 | f"{code_impl.prefix} \n Imports: {code_impl.imports} \n Code: {code_impl.code}", 326 | ) 327 | ] 328 | return state 329 | 330 | def evaluate_feedback(state: SelfRefineState): 331 | """ 332 | Node that calls feedback_agent.run(...) to see if code is valid. 333 | """ 334 | feedback = self.feedback_agent.run( 335 | self.test_specification, 336 | self.data, 337 | state["code_impl"], 338 | state["code_out"] 339 | ) 340 | # Suppose feedback has structure: 341 | # { 342 | # "is_valid": "Yes" or "No", 343 | # "p_value": "...", 344 | # "feedback": "..." 
345 | # } 346 | # We store these in state 347 | state["feedback_valid"] = (feedback.is_valid.lower() == "yes") 348 | if feedback.is_valid.lower() == "yes": 349 | # attempt to parse p-value 350 | try: 351 | pval = float(feedback.p_value) if feedback.p_value else None 352 | except (TypeError, ValueError): 353 | pval = None 354 | state["p_value"] = pval 355 | else: 356 | state["p_value"] = None 357 | state["feedback_text"] = feedback.feedback 358 | state["messages"] += [ 359 | ("user", feedback.feedback) 360 | ] 361 | return state 362 | 363 | def decide_if_done(state: SelfRefineState) -> Literal["end", "check_max_iteration"]: 364 | """ 365 | Decide if the feedback was valid. If valid, end the loop. If not, check whether another iteration is allowed. 366 | """ 367 | if state["feedback_valid"]: 368 | state["done"] = True 369 | return "end" 370 | else: 371 | return "check_max_iteration" 372 | 373 | def check_max_iteration(state: SelfRefineState) -> SelfRefineState: 374 | 375 | state["iteration"] += 1 376 | if state["iteration"] > self.max_iterations: 377 | state["done"] = True 378 | return state 379 | 380 | def maybe_retry(state: SelfRefineState) -> Literal["generate_and_run_code", "end"]: 381 | if state["done"]: 382 | return "end" 383 | return "generate_and_run_code" 384 | 385 | graph_builder = StateGraph(SelfRefineState) 386 | 387 | graph_builder.add_node("generate_and_run_code", generate_and_run_code) 388 | graph_builder.add_node("evaluate_feedback", evaluate_feedback) 389 | graph_builder.add_node("check_max_iteration", check_max_iteration) 390 | 391 | graph_builder.add_edge(START, "generate_and_run_code") 392 | graph_builder.add_edge("generate_and_run_code", "evaluate_feedback") 393 | graph_builder.add_conditional_edges( 394 | "evaluate_feedback", 395 | decide_if_done, 396 | { 397 | "end": END, 398 | "check_max_iteration": "check_max_iteration", 399 | } 400 | ) 401 | graph_builder.add_conditional_edges( 402 | "check_max_iteration", 403 | maybe_retry, 404 | { 405 | "generate_and_run_code": "generate_and_run_code", 406 | "end": END, 407 | } 408 | ) 409 | 410 | return graph_builder.compile() 411 | 412 | def generate( 413 | self, 414 | query: str, 415 | data_loader: Any, 416 | ) -> Dict[str, Any]: 417 | """ 418 | Returns: 419 | Dict with keys: 420 | - 'output': bool (True if p-value < threshold, False otherwise) 421 | - 'feedback': feedback text from the final iteration 422 | - 'p_value': (optional) p-value from the final iteration 423 | """ 424 | self.test_specification = query 425 | self.data = data_loader.data_desc 426 | self.coding_agent._set_globals(data_loader.table_dict) 427 | messages = [ 428 | ("user", "Here is the hypothesis to falsify: " + self.test_specification + "\n\n" + "And here are the available data relevant to the hypothesis:\n" + self.data + "\n\nEach of these dataframes has already been loaded into the global namespace. You may access each dataframe **directly as variables**.
Make sure to use the **EXACT** dataframe names as shown above.") 429 | ] 430 | 431 | state_dict = SelfRefineState( 432 | iteration=1, 433 | messages=messages, 434 | code_impl=None, 435 | code_out=None, 436 | feedback_valid=False, 437 | p_value=None, 438 | feedback_text=None, 439 | done=False 440 | ) 441 | # run the graph 442 | result = self.graph.invoke(state_dict) 443 | 444 | # interpret final results 445 | if result["p_value"] is not None and result["p_value"] < self.p_value_threshold: 446 | return { 447 | "output": True, 448 | "feedback": result["feedback_text"], 449 | "p_value": result["p_value"], 450 | } 451 | elif result["p_value"] is not None: 452 | return { 453 | "output": False, 454 | "feedback": result["feedback_text"], 455 | "p_value": result["p_value"], 456 | } 457 | else: 458 | # no valid p-value or max iterations exhausted 459 | return { 460 | "output": False, 461 | "feedback": "No valid p-value found or max iteration limit reached." 462 | } --------------------------------------------------------------------------------
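A minimal usage sketch for the self-refine baseline above (illustrative only: the `ToyLoader` stand-in, its column names, and the toy values are hypothetical, and a configured Anthropic or OpenAI API key is assumed; the benchmark scripts construct their own data loaders):

    import pandas as pd
    from baseline_agents.self_refine_agent import SelfRefineAgent

    # Hypothetical stand-in for a benchmark data loader: SelfRefineAgent.generate
    # only reads `data_desc` (a textual description of the tables) and
    # `table_dict` (dataframe name -> pandas DataFrame, injected into exec globals).
    class ToyLoader:
        data_desc = "df_expression: columns [gene, group, value]"
        table_dict = {"df_expression": pd.DataFrame({
            "gene": ["G1", "G1", "G2", "G2"],
            "group": ["case", "control", "case", "control"],
            "value": [5.1, 2.3, 4.8, 2.9],
        })}

    agent = SelfRefineAgent(llm="claude-3-5-sonnet-20241022", max_iterations=3)
    result = agent.generate(
        query="Gene expression differs between case and control groups.",
        data_loader=ToyLoader(),
    )
    print(result)  # {"output": bool, "feedback": str, "p_value": float when a valid p-value was found}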