├── src ├── expts │ ├── __init__.py │ ├── client.py │ ├── dsbench.py │ ├── server.py │ ├── training.py │ ├── common.py │ └── eval.py ├── fhda │ ├── __init__.py │ ├── templates │ │ ├── base │ │ │ ├── cell_id_anchor.j2 │ │ │ ├── celltags.j2 │ │ │ ├── mathjax.html.j2 │ │ │ ├── jupyter_widgets.html.j2 │ │ │ ├── display_priority.j2 │ │ │ └── null.j2 │ │ └── lab │ │ │ ├── conf.json │ │ │ ├── index.html.j2 │ │ │ ├── mermaidjs.html.j2 │ │ │ └── base.html.j2 │ ├── dev.yaml │ ├── kernel_requirements.txt │ ├── config.py │ ├── Dockerfile.pinned │ ├── dataset.py │ ├── data_analysis_env.py │ ├── models.py │ ├── prompts.py │ ├── tortoise.py │ └── storage.py ├── scripts │ ├── __init__.py │ ├── bixbench_evaluation │ │ ├── server.yaml │ │ ├── run.sh │ │ └── runner.yaml │ ├── expt_logging.py │ ├── configurable.py │ └── config.py └── __init__.py ├── tutorial ├── datasets │ └── GSE52778_All_Sample_FPKM_Matrix.txt.gz ├── platform_api.ipynb ├── example.ipynb ├── multi_agent_orchestration.ipynb └── consensus.ipynb ├── pyproject.toml ├── .github └── workflows │ └── tests.yml ├── .pre-commit-config.yaml ├── README.md ├── .gitignore └── LICENSE /src/expts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fhda/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fhda 2 | from . import expts 3 | from . 
import scripts 4 | 5 | __all__ = ["fhda", "expts", "scripts"] 6 | -------------------------------------------------------------------------------- /tutorial/datasets/GSE52778_All_Sample_FPKM_Matrix.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Future-House/data-analysis-crow/HEAD/tutorial/datasets/GSE52778_All_Sample_FPKM_Matrix.txt.gz -------------------------------------------------------------------------------- /src/fhda/templates/base/cell_id_anchor.j2: -------------------------------------------------------------------------------- 1 | {%- macro cell_id_anchor(cell) -%} 2 | {% if cell.id | length > 0 -%} 3 | id="{{ ('cell-id=' ~ cell.id) | escape_html -}}" 4 | {%- endif %} 5 | {%- endmacro %} 6 | -------------------------------------------------------------------------------- /src/fhda/dev.yaml: -------------------------------------------------------------------------------- 1 | job: 2 | cpu: 2 3 | memory: 4Gi 4 | timeout: 1200s 5 | env: 6 | CROW_AGENT: ldp.agent.SimpleAgent 7 | CROW_ENVIRONMENT: data_analysis.env.DataAnalysisEnv 8 | OPENAI_API_KEY: gcsm:crow-openai-api-key 9 | ANTHROPIC_API_KEY: gcsm:crow-anthropic-api-key 10 | -------------------------------------------------------------------------------- /src/fhda/templates/base/celltags.j2: -------------------------------------------------------------------------------- 1 | {%- macro celltags(cell) -%} 2 | {% if cell.metadata.tags | length > 0 -%} 3 | {% for tag in (cell.metadata.tags) -%} 4 | {{ (' celltag_' ~ tag) | escape_html -}} 5 | {%- endfor -%} 6 | {%- endif %} 7 | {%- endmacro %} 8 | -------------------------------------------------------------------------------- /src/fhda/templates/lab/conf.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_template": "base", 3 | "mimetypes": { 4 | "text/html": true 5 | }, 6 | "preprocessors": { 7 | "100-pygments": { 8 | "type": "nbconvert.preprocessors.CSSHTMLHeaderPreprocessor", 9 | "enabled": true 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/scripts/bixbench_evaluation/server.yaml: -------------------------------------------------------------------------------- 1 | expt: expts.server.CapsuleDatasetServer 2 | 3 | port: 8042 4 | 5 | dataset: 6 | prompt_template_key: v1.3.2 7 | local_repo_path: LOCAL_PATH_TO_DATASET 8 | capsule_mode: open 9 | avoid_images: true 10 | local_output_path: LOCAL_PATH_TO_OUTPUT_TRAJECTORIES 11 | -------------------------------------------------------------------------------- /src/scripts/bixbench_evaluation/run.sh: -------------------------------------------------------------------------------- 1 | echo "Starting dataset server" 2 | run_expt server.yaml & 3 | 4 | echo "Waiting for servers to start..." 5 | while ! curl -s localhost:8042 >/dev/null 2>&1; do 6 | sleep 5 7 | echo "Waiting for first server on port 8042..." 
8 | done 9 | echo "First server is running" 10 | echo "Starting runners" 11 | 12 | run_expt runner.yaml & 13 | -------------------------------------------------------------------------------- /src/fhda/kernel_requirements.txt: -------------------------------------------------------------------------------- 1 | anndata==0.11.1 2 | biopython==1.84 3 | ete3==3.1.3 4 | fcsparser==0.2.8 5 | cython==3.0.12 6 | gseapy==1.1.4 7 | keras==3.7.0 8 | jupyter==1.0.0 9 | matplotlib==3.10.0 10 | matplotlib-venn==1.1.1 11 | mygene==3.2.2 12 | nbconvert==7.16.4 13 | numpy==1.26.4 # Pinned lower for fcsparser <2 14 | optuna==4.1.0 15 | openpyxl==3.1.5 16 | pandas==2.2.3 17 | plotly==5.24.1 18 | rpy2==3.5.11 19 | scipy==1.14.1 20 | scanpy==1.10.4 21 | seaborn==0.13.2 22 | scikit-learn==1.6.0 23 | statsmodels==0.14.4 24 | umap-learn==0.5.7 25 | -------------------------------------------------------------------------------- /src/scripts/bixbench_evaluation/runner.yaml: -------------------------------------------------------------------------------- 1 | expt: expts.eval.NBEvalExpt 2 | 3 | output_repo: 4 | name: ludo/data-analysis/capsules/rollout1 5 | overwrite: True 6 | 7 | env: 8 | host: localhost 9 | port: 8042 10 | request_timeout: 600 11 | split: all 12 | 13 | evaluator: 14 | batch_size: 1 15 | num_eval_iterations: null 16 | max_rollout_steps: 25 17 | shuffle: False 18 | catch_agent_failures: True 19 | catch_env_failures: True 20 | clear_ctx_at_each_iter: False 21 | 22 | agent: 23 | agent_type: ReActAgent 24 | agent_kwargs: 25 | llm_model: 26 | model: gpt-4o #claude-3-5-sonnet-20241022 27 | parallel_tool_calls: False 28 | num_retries: 5 29 | temperature: 1.0 30 | 31 | hide_old_env_states: True -------------------------------------------------------------------------------- /src/fhda/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | USE_DOCKER = bool(os.getenv("USE_DOCKER", "false").lower() == "true") 5 | USE_R = bool(os.getenv("USE_R", "false").lower() == "true") 6 | NB_ENVIRONMENT_DOCKER_IMAGE = os.getenv( 7 | "NB_ENVIRONMENT_DOCKER_IMAGE", "futurehouse/bixbench:aviary-notebook-env" 8 | ) 9 | 10 | # Some R error messages can be 100,000 of characters 11 | NB_OUTPUT_LIMIT = 3000 # chars 12 | # Streams from a docker container. 
Don't set to `sys.stdout.fileno()` 13 | # because we want to differentiate from file I/O 14 | DOCKER_STREAM_TYPE_STDOUT = 1 15 | DOCKER_STREAM_TYPE_STDERR = 2 16 | 17 | STAGE = os.getenv("STAGE", "local") 18 | if STAGE == "local": 19 | DATA_STORAGE_PATH = Path("storage") 20 | else: 21 | DATA_STORAGE_PATH = Path("/storage") 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools>=64"] 4 | 5 | [project] 6 | authors = [ 7 | {email = "hello@futurehouse.org", name = "FutureHouse technical staff"} 8 | ] 9 | dependencies = [ 10 | "aiodocker==0.24.0", 11 | "anthropic==0.52.2", # this is necessary for tortoise, remove in favor of LMI when it works with search 12 | "fhaviary[server]==0.19.0", 13 | "ldp==0.26.0", 14 | "pandas==2.2.3", 15 | "numpy==2.2.3", 16 | "matplotlib==3.10.0", 17 | "aiofiles==24.1.0", 18 | "google-auth==2.38.0", 19 | "google-cloud-storage==3.0.0", 20 | "google-cloud-secret-manager==2.23.0", 21 | "futurehouse-client==0.3.19", 22 | "jupyter==1.1.1", 23 | "nbconvert==7.16.6", 24 | "notebook==7.3.2", 25 | "nbformat==5.10.4", 26 | "seaborn==0.13.2" 27 | ] 28 | description = "Data analysis crow" 29 | name = "fhda" 30 | requires-python = ">=3.12" 31 | version = "1.0.0" 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "black", 36 | "isort", 37 | "mypy", 38 | "pre-commit", 39 | "pytest", 40 | "pytest-asyncio", 41 | "pytest-cov", 42 | "ruff" 43 | ] 44 | 45 | [project.scripts] 46 | run_expt = 'scripts.configurable:_run_expt' 47 | 48 | [tool.setuptools] 49 | package-dir = {"" = "src"} 50 | 51 | [tool.setuptools.packages.find] 52 | where = ["src"] 53 | -------------------------------------------------------------------------------- /src/fhda/templates/base/mathjax.html.j2: -------------------------------------------------------------------------------- 1 | 2 | {%- macro mathjax(url="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML-full,Safe") -%} 3 | 4 | 5 | 6 | 37 | 38 | {%- endmacro %} 39 | -------------------------------------------------------------------------------- /src/fhda/templates/base/jupyter_widgets.html.j2: -------------------------------------------------------------------------------- 1 | {%- macro jupyter_widgets(widgets_cdn_url, html_manager_semver_range, widget_renderer_url='') -%} 2 | 3 | 35 | 36 | {%- endmacro %} 37 | -------------------------------------------------------------------------------- /src/scripts/expt_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from ldp.utils import configure_stdout_logs 5 | from llmclient import configure_llm_logs 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def configure_logs( 11 | log_file: str | os.PathLike | None = None, 12 | stdout_level: int | str | tuple[str, int | str] | None = logging.INFO, 13 | fmt: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s", 14 | ) -> None: 15 | """Configure logs. 16 | 17 | Args: 18 | log_file: Optional log file to add to all loggers. 19 | stdout_level: If int (default) or str, it's a log level for stdout. If two-tuple 20 | of str and int, it's a logger name and log level for that logger. Otherwise, 21 | if None, don't configure stdout logs. 22 | fmt: Logging format string. 
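
    Example (illustrative path and levels):
        >>> configure_logs(log_file="expt.log")  # INFO to stdout, DEBUG to the file
        >>> configure_logs(stdout_level=("ldp", logging.DEBUG))  # only adjust the "ldp" logger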
23 | """ 24 | configure_llm_logs() 25 | 26 | # Set some good default log levels to avoid too much verbosity 27 | logging.getLogger("dask").setLevel(logging.WARNING) 28 | logging.getLogger("vcr.cassette").setLevel(logging.WARNING) 29 | 30 | if stdout_level is not None: 31 | if isinstance(stdout_level, tuple): 32 | configure_stdout_logs(name=stdout_level[0], level=stdout_level[1], fmt=fmt) 33 | else: 34 | configure_stdout_logs(level=stdout_level, fmt=fmt) 35 | 36 | if log_file is not None: 37 | # Configure all loggers to write to a log file 38 | file_handler = logging.FileHandler(log_file) 39 | file_handler.setLevel(logging.DEBUG) 40 | file_handler.setFormatter(logging.Formatter(fmt)) 41 | logger.info(f"Logging to {log_file}.") 42 | 43 | # apply retroactively to root logger and all existing loggers 44 | for logger_name in ("root", *logging.root.manager.loggerDict.keys()): 45 | logging.getLogger(logger_name).addHandler(file_handler) 46 | -------------------------------------------------------------------------------- /src/fhda/templates/base/display_priority.j2: -------------------------------------------------------------------------------- 1 | {%- extends 'base/null.j2' -%} 2 | 3 | {#display data priority#} 4 | 5 | 6 | {%- block data_priority scoped -%} 7 | {%- for type in output.data | filter_data_type -%} 8 | {%- if type == 'application/pdf' -%} 9 | {%- block data_pdf -%} 10 | {%- endblock -%} 11 | {%- elif type == 'image/svg+xml' -%} 12 | {%- block data_svg -%} 13 | {%- endblock -%} 14 | {%- elif type == 'image/png' -%} 15 | {%- block data_png -%} 16 | {%- endblock -%} 17 | {%- elif type == 'text/html' -%} 18 | {%- block data_html -%} 19 | {%- endblock -%} 20 | {%- elif type == 'text/markdown' -%} 21 | {%- block data_markdown -%} 22 | {%- endblock -%} 23 | {%- elif type == 'image/jpeg' -%} 24 | {%- block data_jpg -%} 25 | {%- endblock -%} 26 | {%- elif type == 'text/plain' -%} 27 | {%- block data_text -%} 28 | {%- endblock -%} 29 | {%- elif type == 'text/latex' -%} 30 | {%- block data_latex -%} 31 | {%- endblock -%} 32 | {%- elif type == 'text/vnd.mermaid' -%} 33 | {%- block data_mermaid -%} 34 | {%- endblock -%} 35 | {%- elif type == 'application/javascript' -%} 36 | {%- block data_javascript -%} 37 | {%- endblock -%} 38 | {%- elif type == 'application/vnd.jupyter.widget-view+json' -%} 39 | {%- block data_widget_view -%} 40 | {%- endblock -%} 41 | {%- elif type == resources.output_mimetype -%} 42 | {%- block data_native -%} 43 | {%- endblock -%} 44 | {%- else -%} 45 | {%- block data_other -%} 46 | {%- endblock -%} 47 | {%- endif -%} 48 | {%- endfor -%} 49 | {%- endblock data_priority -%} 50 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.12"] 14 | 15 | steps: 16 | - name: Check out Git repository 17 | uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install setuptools>=66 wheel>=0.36 build 26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 27 | if [ -f pyproject.toml ]; then pip install -e .[dev]; fi 28 | 29 | - 
name: Run Lint 30 | run: | 31 | # Check for linting issues 32 | ruff check . 33 | # Check for formatting issues (will fail if code needs formatting) 34 | ruff format --check . 35 | 36 | test: 37 | runs-on: ubuntu-latest 38 | strategy: 39 | matrix: 40 | python-version: ["3.12"] 41 | 42 | steps: 43 | - name: Check out Git repository 44 | uses: actions/checkout@v4 45 | 46 | - name: Set up Python ${{ matrix.python-version }} 47 | uses: actions/setup-python@v5 48 | with: 49 | python-version: ${{ matrix.python-version }} 50 | - name: Install dependencies 51 | run: | 52 | python -m pip install --upgrade pip 53 | pip install setuptools>=66 wheel>=0.36 build 54 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 55 | if [ -f pyproject.toml ]; then pip install -e .[dev]; fi 56 | 57 | - name: Run Test 58 | run: | 59 | python -m pytest 60 | env: 61 | GITHUB_ACTIONS: true 62 | -------------------------------------------------------------------------------- /src/expts/client.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing 3 | from enum import StrEnum, auto 4 | 5 | from aviary.core import ( 6 | TaskDatasetClient, 7 | TaskEnvironmentClient, 8 | ) 9 | from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin 10 | 11 | 12 | class TaskDatasetSubsetClient(TaskDatasetClient, ComputeTrajectoryMetricsMixin): 13 | """Convenience class to subset a dataset using a single server.""" 14 | 15 | def __init__(self, client: TaskDatasetClient, task_idcs: list[int]) -> None: 16 | super().__init__( 17 | server_url=client.server_url, request_timeout=client.request_timeout 18 | ) 19 | self.idcs = task_idcs 20 | 21 | def __len__(self) -> int: 22 | return len(self.idcs) 23 | 24 | def get_new_env_by_idx(self, idx: int) -> TaskEnvironmentClient: 25 | return super().get_new_env_by_idx(self.idcs[idx]) 26 | 27 | 28 | class TaskDatasetSplit(StrEnum): 29 | TRAIN = auto() 30 | EVAL = auto() 31 | TEST = auto() 32 | ALL = auto() 33 | 34 | def get_random_split( 35 | self, dataset_client: TaskDatasetClient, seed: int = 0 36 | ) -> TaskDatasetClient: 37 | if self == TaskDatasetSplit.ALL: 38 | return dataset_client 39 | 40 | # Slightly hacky way to make a split for now 41 | # Split the dataset into a 80/10/10 split using a deterministic seed 42 | n_total = len(dataset_client) 43 | all_idcs = random.Random(seed).sample(range(n_total), n_total) 44 | 45 | match self: 46 | case TaskDatasetSplit.TRAIN: 47 | idcs = all_idcs[: int(0.8 * n_total)] 48 | case TaskDatasetSplit.EVAL: 49 | idcs = all_idcs[int(0.8 * n_total) : int(0.9 * n_total)] 50 | case TaskDatasetSplit.TEST: 51 | idcs = all_idcs[int(0.9 * n_total) :] 52 | 53 | case _: 54 | typing.assert_never(self) 55 | 56 | return TaskDatasetSubsetClient(dataset_client, idcs) 57 | -------------------------------------------------------------------------------- /src/scripts/configurable.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import importlib 3 | import sys 4 | from abc import ABC, abstractmethod 5 | from typing import Self 6 | 7 | from .config import ConfigModel, load_arg_dict, load_config 8 | 9 | 10 | class ConfigurableExpt(ConfigModel, ABC): 11 | """A base class for configurable experiments. 12 | 13 | Example usage: 14 | ```py 15 | expt = DummyExpt() 16 | await expt.run() # prints "Hello, world!" 17 | 18 | expt = DummyExpt.from_cli_args(argv=["--who", "friend"]) 19 | await expt.run() # prints "Hello, friend!" 
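
    # Values can also be loaded from a YAML config named on the CLI
    # (hypothetical dummy.yaml containing `who: friend`):
    expt = DummyExpt.from_cli_args(argv=["dummy.yaml"])
    await expt.run()  # prints "Hello, friend!"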
20 | ``` 21 | """ 22 | 23 | @classmethod 24 | def from_cli_args(cls, **kwargs) -> Self: 25 | return load_config(cls, **kwargs) 26 | 27 | @abstractmethod 28 | async def run(self) -> None: 29 | """The entry point for the executable.""" 30 | 31 | 32 | class DummyExpt(ConfigurableExpt): 33 | """For unit tests.""" 34 | 35 | who: str = "world" 36 | 37 | # Returning string for unit tests 38 | async def run(self) -> str: # type: ignore[override] 39 | print(f"Hello, {self.who}!") 40 | return self.who 41 | 42 | 43 | def _run_expt() -> None: 44 | """ 45 | Import and run a ConfigurableExpt. 46 | 47 | NOTE: this is not meant to be called from python code, instead it's exposed 48 | (in pyproject.toml) as `run_expt` command line entry point. 49 | """ 50 | 51 | argv = sys.argv[1:] 52 | first_arg: str | None = argv[0] if argv else None 53 | 54 | if not first_arg or first_arg in {"-h", "--help"}: 55 | print("Usage: run_expt [app_args...]") 56 | return 57 | 58 | # check if expt_name was specified 59 | if first_arg.startswith("--") or first_arg.endswith(".yaml"): 60 | # expt_name was not specified in CLI args. Try to infer from remaining args 61 | parsed_args = load_arg_dict(argv=argv) 62 | try: 63 | expt_name = parsed_args["expt"] 64 | except KeyError: 65 | # NOTE: not using `raise ValueError` to avoid lengthy traceback 66 | print( 67 | "Error: experiment was not specified in CLI args nor in configuration.", 68 | file=sys.stderr, 69 | ) 70 | sys.exit(1) 71 | 72 | else: 73 | expt_name = argv.pop(0) 74 | 75 | # Import the expt 76 | expt_module = importlib.import_module(name=".".join(expt_name.split(".")[:-1])) 77 | expt_class = getattr(expt_module, expt_name.split(".")[-1]) 78 | if not issubclass(expt_class, ConfigurableExpt): 79 | # NOTE: not using `raise TypeError` to avoid lengthy traceback 80 | print( 81 | f"Error: {expt_name} is not a subclass of ConfigurableExpt.", 82 | file=sys.stderr, 83 | ) 84 | sys.exit(1) 85 | 86 | # Skip 'expt' if it's in the args, since that was just used to infer expt_name 87 | expt = expt_class.from_cli_args(argv=argv, args_to_exclude=["expt"]) 88 | asyncio.run(expt.run()) 89 | -------------------------------------------------------------------------------- /src/fhda/Dockerfile.pinned: -------------------------------------------------------------------------------- 1 | # DANGER: Beware of changing this dockerfile, orchestrating the versioning in these R/python packages was very challenging 2 | FROM continuumio/miniconda3:24.9.2-0 3 | 4 | RUN mkdir /workspace && \ 5 | mkdir /envs 6 | WORKDIR /envs 7 | 8 | ENV DEBIAN_FRONTEND=noninteractive 9 | RUN apt-get update && \ 10 | apt-get install -yq --no-install-recommends \ 11 | wget \ 12 | gpg \ 13 | software-properties-common \ 14 | build-essential && \ 15 | rm -rf /var/lib/apt/lists/* 16 | 17 | RUN conda install mamba=2.0.5 -c conda-forge -y 18 | 19 | # Install R packages from conda-forge 20 | RUN mamba install -c conda-forge -y \ 21 | r-base=4.3.3 \ 22 | r-recommended=4.3 \ 23 | r-irkernel=1.3.2 \ 24 | r-factominer=2.11 \ 25 | r-rcolorbrewer=1.1_3 \ 26 | r-devtools=2.4.5 \ 27 | r-broom=1.0.7 \ 28 | r-data.table=1.15.4 \ 29 | r-enrichr=3.2 \ 30 | r-factoextra=1.0.7 \ 31 | r-ggnewscale=0.5.0 \ 32 | r-ggrepel=0.9.6 \ 33 | r-ggpubr=0.6.0 \ 34 | r-ggvenn=0.1.10 \ 35 | r-janitor=2.2.1 \ 36 | r-multcomp=1.4_26 \ 37 | r-matrix=1.6_5 \ 38 | r-pheatmap=1.0.12 \ 39 | r-tidyverse=2.0.0 \ 40 | r-readxl=1.4.3 \ 41 | r-reshape=0.8.9 \ 42 | r-rstatix=0.7.2 \ 43 | r-viridis=0.6.5 \ 44 | udocker=1.3.17 \ 45 | imbalanced-learn=0.13.0 \ 46 | 
ipykernel=6.29.5 \ 47 | sqlite=3.47.2 48 | 49 | RUN python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)" 50 | RUN R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")' 51 | 52 | # I separate these because not all packages need both channels, additionally, 53 | # creating multiple layers makes caching easier 54 | RUN mamba install -c conda-forge -c bioconda -y \ 55 | biokit=0.5.0 \ 56 | gseapy=1.1.4 \ 57 | blast=2.16.0 \ 58 | clipkit=2.3.0 \ 59 | fastqc=0.12.1 \ 60 | iqtree=2.3.6 \ 61 | mafft=7.526 \ 62 | metaeuk=7.bba0d80 \ 63 | mygene=3.2.2 \ 64 | perl=5.32.1 \ 65 | phykit=2.0.1 \ 66 | pydeseq2=0.4.12 \ 67 | spades=4.0.0 \ 68 | trim-galore=0.6.10 \ 69 | bioconductor-enhancedvolcano=1.20.0 \ 70 | bioconductor-deseq2=1.42.0 \ 71 | bioconductor-clusterprofiler=4.10.0 \ 72 | bioconductor-org.hs.eg.db=3.18.0 \ 73 | bioconductor-genomicranges=1.54.1 \ 74 | bioconductor-summarizedexperiment=1.32.0 \ 75 | bioconductor-apeglm=1.24.0 76 | 77 | 78 | COPY kernel_requirements.txt . 79 | 80 | # Install conda packages first 81 | RUN mamba install -c conda-forge --file kernel_requirements.txt -y 82 | 83 | # Install pip packages 84 | RUN pip install aiodocker ldp==0.26.0 fhaviary[server]==0.19.0 futurehouse-client==0.3.14 85 | 86 | # Certain tools are not easily installable via conda. A common practice for 87 | # bioinformaticians is to use udocker to run certain heavy duty omics processing 88 | # tools in an isolated environment 89 | # RUN udocker --allow-root install && \ 90 | # udocker --allow-root pull ezlabgva/busco:v5.8.0_cv1 91 | 92 | WORKDIR /workspace 93 | 94 | RUN mamba clean --all -f -y && \ 95 | conda clean --all -f -y && \ 96 | rm -rf /root/.cache/pip 97 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | default_language_version: 3 | python: python3 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | - id: check-added-large-files 9 | - id: check-byte-order-marker 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-shebang-scripts-are-executable 13 | - id: check-symlinks 14 | - id: check-toml 15 | - id: check-yaml 16 | - id: debug-statements 17 | - id: detect-private-key 18 | - id: end-of-file-fixer 19 | - id: mixed-line-ending 20 | - id: trailing-whitespace 21 | - repo: https://github.com/astral-sh/ruff-pre-commit 22 | rev: v0.9.1 23 | hooks: 24 | - id: ruff 25 | args: [--fix, --exit-non-zero-on-fix] 26 | - id: ruff-format 27 | - repo: https://github.com/rbubley/mirrors-prettier 28 | rev: v3.4.2 29 | hooks: 30 | - id: prettier 31 | - repo: https://github.com/jumanjihouse/pre-commit-hooks 32 | rev: 3.0.0 33 | hooks: 34 | - id: check-mailmap 35 | - repo: https://github.com/codespell-project/codespell 36 | rev: v2.3.0 37 | hooks: 38 | - id: codespell 39 | additional_dependencies: [".[toml]"] 40 | exclude_types: [jupyter] 41 | - repo: https://github.com/pappasam/toml-sort 42 | rev: v0.24.2 43 | hooks: 44 | - id: toml-sort-fix 45 | exclude: poetry.lock 46 | - repo: https://github.com/srstevenson/nb-clean 47 | rev: 4.0.1 48 | hooks: 49 | - id: nb-clean 50 | args: [--preserve-cell-outputs, --remove-empty-cells] 51 | - repo: https://github.com/henryiii/validate-pyproject-schema-store 52 | rev: 2025.01.10 53 | hooks: 54 | - id: validate-pyproject 55 | - repo: https://github.com/pre-commit/mirrors-mypy 56 | rev: v1.14.1 57 | 
hooks: 58 | - id: mypy 59 | additional_dependencies: 60 | - aiohttp 61 | - boto3-stubs[s3] 62 | - docstring_parser 63 | - fh-llm-client[deepseek]>=0.0.11 # Match aviary_internal pyproject.toml 64 | - fhaviary[server] >= 0.18.0 # Match aviary_internal pyproject.toml 65 | - gitpython 66 | - google-auth>=2.31 # Match aviary_internal pyproject.toml 67 | - google-cloud 68 | - google-cloud-run 69 | - google-cloud-tasks 70 | - google-cloud-secret-manager 71 | - google-cloud-storage 72 | - httpx<0.28 # Match aviary_internal pyproject.toml 73 | - jupyter-client 74 | - ldp>=0.22.0 # Match aviary_internal pyproject.toml 75 | - litellm>=1.40.9 # Match aviary_internal pyproject.toml 76 | - nbformat 77 | - numpy<2 # Match aviary_internal pyproject.toml 78 | - omegaconf 79 | - openai>=1 # Match aviary_internal pyproject.toml 80 | - pandas-stubs 81 | - pydantic~=2.0 # Match aviary_internal pyproject.toml 82 | - rich 83 | - SQLAlchemy[aiosqlite]~=2.0 # Match fhaviary pyproject.toml and dev-requirements.txt 84 | - tenacity 85 | - tiktoken 86 | - torch==2.5.1 # Match aviary_internal/nn/requirements.txt 87 | - types-aiofiles 88 | - types-Pillow 89 | - types-PyYAML 90 | - types-requests 91 | - types-tqdm 92 | - typing-extensions 93 | - wandb 94 | -------------------------------------------------------------------------------- /src/expts/dsbench.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | from aviary_internal import __version__, utils 7 | from pydantic import Field 8 | from tqdm import tqdm 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class GetKaggleInfo(utils.ConfigurableExpt): 14 | dataset_repo: utils.DataRepo = Field( 15 | default_factory=lambda: utils.DataRepo( 16 | name="baseline-envs/dsbench/data_modeling" 17 | ) 18 | ) 19 | 20 | async def run(self) -> None: 21 | try: 22 | from kaggle.api.kaggle_api_extended import KaggleApi 23 | except ImportError: 24 | raise ImportError( 25 | "Please `pip install kaggle` and set up authentication." 26 | ) from None 27 | 28 | api = KaggleApi() 29 | # Will raise if user is not authenticated 30 | api.authenticate() 31 | 32 | src_dir = Path(self.dataset_repo.local_path) 33 | competitions = sorted([d.name for d in (src_dir / "data_resplit").glob("*")]) 34 | kaggle_info: dict[str, dict[str, float | bool | list[float]]] = {} 35 | 36 | for comp in tqdm(competitions, desc="Querying Kaggle", ncols=0): 37 | # Bit ugly: to determine if 'best' is max or min, we get the GT result and compare 38 | # to the actual submissions. I can't find any documentation saying the leaderboard 39 | # is ordered. 
40 | 41 | try: 42 | target_result = float( 43 | (src_dir / "save_performance/GT" / comp / "result.txt").read_text() 44 | ) 45 | except FileNotFoundError: 46 | logger.error(f"Could not find GT result file for {comp} - skipping.") 47 | continue 48 | 49 | leaderboard = api.competition_leaderboard_view(comp) 50 | scores = [float(entry.score) for entry in leaderboard if entry.hasScore] 51 | if not scores: 52 | logger.error(f"No scores found for {comp} - skipping.") 53 | continue 54 | 55 | max_score, min_score = max(scores), min(scores) 56 | 57 | if min_score >= target_result: 58 | # smaller is better 59 | kaggle_info[comp] = { 60 | "best_score": min_score, 61 | "max_is_best": False, 62 | "scores": scores, 63 | } 64 | 65 | elif max_score <= target_result: 66 | # larger is better 67 | kaggle_info[comp] = { 68 | "best_score": max_score, 69 | "max_is_best": True, 70 | "scores": scores, 71 | } 72 | 73 | else: 74 | raise RuntimeError(f"Could not determine best score for {comp}.") 75 | 76 | with (src_dir / "kaggle_submissions.json").open("w") as f: 77 | json.dump( 78 | { 79 | "metadata": { 80 | "description": "Created by data_analysis.expts.dsbench.GetKaggleInfo.", 81 | "timestamp": datetime.now().isoformat(), 82 | "aviary_internal": __version__, 83 | }, 84 | "kaggle_info": kaggle_info, 85 | }, 86 | f, 87 | indent=2, 88 | ) 89 | 90 | self.dataset_repo.push(progress=True) 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Analysis Crow: A Jupyter Notebook Agent 2 | 3 | Data Analysis Crow is an AI agent framework designed to perform complex scientific data analysis tasks by iteratively working through Jupyter notebooks. This agent takes in datasets and prompts, then systematically explores, analyzes, and interprets the data to provide comprehensive answers and insights. 4 | 5 | The agent was used to produce the trajectories for the [BixBench benchmark](https://github.com/Future-House/bixbench). 6 | 7 | ## Key Features 8 | 9 | - Accepts datasets and natural language prompts 10 | - Iteratively builds Jupyter notebooks to answer research questions 11 | - Works with Python, R, and Bash code execution 12 | - Specializes in bioinformatics analysis but adaptable to various domains 13 | - Comes with a Docker image including most common bioinformatics packages 14 | 15 | ## Links 16 | 17 | - [Installation](#installation) 18 | - [Using the Agent](#using-the-agent) 19 | - [Advanced Usage](#advanced-usage) 20 | - [BixBench Benchmark](#bixbench-benchmark) 21 | 22 | ## Installation 23 | 24 | ```bash 25 | # Clone the repository 26 | git clone https://github.com/Future-House/data-analysis-crow.git 27 | cd data-analysis-crow 28 | 29 | # Install dependencies 30 | pip install -e . 31 | 32 | # OPTIONAL:pull the docker image with bioinformatics packages 33 | docker pull futurehouse/bixbench:aviary-notebook-env 34 | ``` 35 | 36 | ## Prerequisites 37 | 38 | ### API Keys 39 | 40 | We support all LLMs that are supported by [litellm](https://github.com/BerriAI/litellm). Create a `.env` file with the API keys for the LLMs you want to use. For example: 41 | 42 | ``` 43 | OPENAI_API_KEY = "your-openai-api-key" 44 | ANTHROPIC_API_KEY = "your-anthropic-api-key" 45 | ``` 46 | 47 | ## Using the Agent 48 | 49 | The agent works by taking a dataset and a prompt, then iteratively building a Jupyter notebook to answer the question. 
Visit the [tutorial](https://github.com/Future-House/data-analysis-crow/blob/main/tutorial/example.ipynb) for a simple step-by-step guide on how to use the agent. 50 | 51 | ## Advanced Usage 52 | For advanced evaluations, you can configure `server.yaml` and `runner.yaml` in the `src/scripts/bixbench_evaluation` directory and then run the evaluation script: 53 | ```bash 54 | bash src/scripts/bixbench_evaluation/run.sh 55 | ``` 56 | 57 | This will: 58 | 1. Load the specified dataset 59 | 2. Process the prompt to understand the research question 60 | 3. Generate a Jupyter notebook with progressive analysis steps 61 | 4. Provide a final answer based on the analysis 62 | 63 | Results are saved in the output directory specified in your configuration file. 64 | 65 | Note that the dataset and environment configuration must be updated appropriately. For an example, see [dataset.py](https://github.com/Future-House/data-analysis-crow/blob/main/src/fhda/dataset.py) which includes the capsule dataset configuration used for the BixBench benchmark. 66 | 67 | We also recommend visiting the BixBench repository where we share a full evaluation harness for the agent. 68 | 69 | ## Hosted Agent 70 | Coming soon! 71 | 72 | ## BixBench Benchmark 73 | 74 | Data Analysis Crow was used to produce the trajectories for the [BixBench benchmark](https://github.com/Future-House/bixbench), which evaluates AI agents on real-world bioinformatics tasks. 75 | 76 | BixBench tests AI agents' ability to: 77 | 78 | - Explore biological datasets 79 | - Perform long, multi-step computational analyses 80 | - Interpret nuanced results in the context of a research question 81 | 82 | You can find the BixBench dataset in [Hugging Face](https://huggingface.co/datasets/futurehouse/BixBench), the paper [here](https://arxiv.org/abs/2503.00096), and the blog post [here](https://www.futurehouse.org/research-announcements/bixbench). 83 | 84 | ### Running BixBench Evaluations 85 | 86 | To use this agent for BixBench evaluations, we recommend visiting the [BixBench repository](https://github.com/Future-House/bixbench) for more details. 
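87 | 
88 | If you prefer not to use `run.sh`, the two components it launches can be started directly with the `run_expt` entry point (declared in `pyproject.toml` and implemented in `src/scripts/configurable.py`), which reads the experiment class from the `expt:` key of each YAML config. A minimal sketch, assuming the placeholder dataset/output paths in `server.yaml` have been replaced and an output repo is set in `runner.yaml`:
89 | 
90 | ```bash
91 | cd src/scripts/bixbench_evaluation
92 | # Start the capsule dataset server (listens on port 8042 per server.yaml)
93 | run_expt server.yaml &
94 | # Once the server responds, run the evaluation defined in runner.yaml against it
95 | run_expt runner.yaml
96 | ```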
-------------------------------------------------------------------------------- /src/fhda/templates/lab/index.html.j2: -------------------------------------------------------------------------------- 1 | {%- extends 'base.html.j2' -%} 2 | {% from 'mathjax.html.j2' import mathjax %} 3 | {% from 'mermaidjs.html.j2' import mermaid_js %} 4 | {% from 'jupyter_widgets.html.j2' import jupyter_widgets %} 5 | 6 | {%- block header -%} 7 | 8 | 9 | 10 | {%- block html_head -%} 11 | 12 | 13 | {% set nb_title = nb.metadata.get('title', resources['metadata']['name']) | escape_html_keep_quotes %} 14 | {{nb_title}} 15 | 16 | {%- block html_head_js -%} 17 | {%- block html_head_js_requirejs -%} 18 | 19 | {%- endblock html_head_js_requirejs -%} 20 | {%- endblock html_head_js -%} 21 | 22 | {% block jupyter_widgets %} 23 | {%- if "widgets" in nb.metadata -%} 24 | {{ jupyter_widgets(resources.jupyter_widgets_base_url, resources.html_manager_semver_range, resources.widget_renderer_url) }} 25 | {%- endif -%} 26 | {% endblock jupyter_widgets %} 27 | 28 | {% block extra_css %} 29 | {% endblock extra_css %} 30 | 31 | {% for css in resources.inlining.css -%} 32 | 35 | {% endfor %} 36 | 37 | {% block notebook_css %} 38 | {{ resources.include_css("static/index.css") }} 39 | {% if resources.theme == 'dark' %} 40 | {{ resources.include_css("static/theme-dark.css") }} 41 | {% elif resources.theme == 'light' %} 42 | {{ resources.include_css("static/theme-light.css") }} 43 | {% else %} 44 | {{ resources.include_lab_theme(resources.theme) }} 45 | {% endif %} 46 | 104 | 105 | {% endblock notebook_css %} 106 | 107 | {%- block html_head_js_mathjax -%} 108 | {{ mathjax(resources.mathjax_url) }} 109 | {%- endblock html_head_js_mathjax -%} 110 | 111 | {%- block html_head_js_mermaidjs -%} 112 | {{ mermaid_js(resources.mermaid_js_url) }} 113 | {%- endblock html_head_js_mermaidjs -%} 114 | 115 | {%- block html_head_css -%} 116 | {%- endblock html_head_css -%} 117 | 118 | {%- endblock html_head -%} 119 | 120 | {%- endblock header -%} 121 | 122 | {%- block body_header -%} 123 | {% if resources.theme == 'dark' %} 124 | 125 | {% else %} 126 | 127 | {% endif %} 128 |
129 | {%- endblock body_header -%} 130 | 131 | {% block body_footer %} 132 |
133 | 134 | {% endblock body_footer %} 135 | 136 | {% block footer %} 137 | {% block footer_js %} 138 | {% endblock footer_js %} 139 | {{ super() }} 140 | 141 | {% endblock footer %} 142 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # User-specific 3 | local 4 | 5 | # IntelliJ 6 | out/ 7 | 8 | # Local History for Visual Studio Code 9 | .history/ 10 | 11 | # Built Visual Studio Code Extensions 12 | *.vsix 13 | # General 14 | .DS_Store 15 | .AppleDouble 16 | .LSOverride 17 | 18 | # Files that might appear in the root of a volume 19 | .DocumentRevisions-V100 20 | .fseventsd 21 | .Spotlight-V100 22 | .TemporaryItems 23 | .Trashes 24 | .VolumeIcon.icns 25 | .com.apple.timemachine.donotpresent 26 | 27 | # Directories potentially created on remote AFP share 28 | .AppleDB 29 | .AppleDesktop 30 | Network Trash Folder 31 | Temporary Items 32 | .apdisk 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | share/python-wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | src/fhda/storage/ 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | *.py,cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | cover/ 85 | 86 | # Translations 87 | *.mo 88 | *.pot 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | *.ipynb 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | tutorial/tmp_results_dir 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | # PyCharm 143 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 144 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 145 | # and can be added to the global gitignore or merged into this file. For a more nuclear 146 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
147 | .idea/ 148 | # Local .terraform directories 149 | **/.terraform/* 150 | 151 | # .tfstate files 152 | *.tfstate 153 | *.tfstate.* 154 | 155 | # Crash log files 156 | crash.log 157 | crash.*.log 158 | 159 | # Exclude all .tfvars files, which are likely to contain sensitive data, such as 160 | # password, private keys, and other secrets. These should not be part of version 161 | # control as they are data points which are potentially sensitive and subject 162 | # to change depending on the environment. 163 | *.tfvars 164 | *.tfvars.json 165 | 166 | # Ignore override files as they are usually used to override resources locally and so 167 | # are not checked in 168 | override.tf 169 | override.tf.json 170 | *_override.tf 171 | *_override.tf.json 172 | 173 | # Include override files you do wish to add to version control using negated pattern 174 | # !example_override.tf 175 | 176 | # Include tfplan files to ignore the plan output of command: terraform plan -out=tfplan 177 | # example: *tfplan* 178 | 179 | # Ignore CLI configuration files 180 | .terraformrc 181 | terraform.rc 182 | 183 | # SLURM artifacts 184 | slurm_outputs/ 185 | 186 | # SWE-agent auto-creates these files 187 | keys.cfg 188 | 189 | # Version files made by setuptools_scm 190 | **/version.py 191 | 192 | # WandB cache files (e.g. generated by pytest) 193 | wandb/ 194 | 195 | # VSCode repo settings 196 | .vscode/ 197 | -------------------------------------------------------------------------------- /src/expts/server.py: -------------------------------------------------------------------------------- 1 | """Utilities to run TaskDatasetServers on various notebook task datasets.""" 2 | 3 | import json 4 | import logging 5 | import shutil 6 | from abc import ABC, abstractmethod 7 | from pathlib import Path 8 | from typing import Generic, TypeVar 9 | 10 | from aviary.core import TaskDataset, TaskDatasetServer 11 | from fhda.storage import DataRepo 12 | from fhda.utils import collect_notebook_stats 13 | from fhda.data_analysis_env import DataAnalysisEnv 14 | from fhda.dataset import CapsuleDataset, CapsuleDatasetConfig 15 | from scripts.configurable import ConfigurableExpt 16 | from pydantic import BaseModel, Field 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class SaveWorkspaceRequest(BaseModel): 23 | env_id: str 24 | traj_id: str 25 | workspace_repo: DataRepo 26 | exception: bool 27 | cost: float 28 | time: float 29 | 30 | 31 | class NBTaskDatasetServer(TaskDatasetServer[DataAnalysisEnv]): 32 | def _setup_routes(self) -> None: 33 | super()._setup_routes() 34 | 35 | @self.app.post("/save_workspace") 36 | async def save_workspace(req: SaveWorkspaceRequest): 37 | async with self.lock: 38 | env = self._get_env(req.env_id) 39 | 40 | problem_id = env.problem_id 41 | this_workspace_repo = DataRepo( 42 | name=f"{req.workspace_repo.name}/{problem_id.replace('/', '-')}-{req.traj_id}" 43 | ) 44 | this_workspace_repo.mkdir() 45 | out_dir = Path(this_workspace_repo.local_path) 46 | logger.info(f"Saving workspace to {this_workspace_repo.name}") 47 | 48 | # # Copy the full output directory 49 | for file in Path(env.state.work_dir).glob("**/*"): 50 | if file.suffix in {".ipynb", ".json"}: 51 | dest = out_dir / file.relative_to(env.state.work_dir) 52 | dest.parent.mkdir(parents=True, exist_ok=True) 53 | shutil.copy2(file, dest) 54 | res = { 55 | "problem_id": problem_id, 56 | "traj_id": req.traj_id, 57 | "reward": env.state.total_reward, 58 | "agent_answer": env.state.answer, 59 | "ideal_answer": env.answer, 60 | "problem": 
env.problem, 61 | "mcq_options": [q.options for q in env.mcqs] if env.mcqs else [], 62 | "mcq_question": [q.question for q in env.mcqs] if env.mcqs else [], 63 | "question_rewards": env.question_rewards, 64 | "cost": req.cost, 65 | "exception": req.exception, 66 | "notebook_stats": collect_notebook_stats(env.state.nb), 67 | "time": req.time, 68 | "actions": env.state.actions, 69 | "run_id": req.workspace_repo.name, 70 | "metadata": env.metadata, 71 | "insufficient_options": { 72 | q.question_id: q.unsure_answer_letter for q in (env.mcqs or []) 73 | }, 74 | } 75 | with (out_dir / "metadata.json").open("w") as f: 76 | json.dump( 77 | res, 78 | f, 79 | indent=4, 80 | ) 81 | 82 | # Push just this specific workspace, not the whole workspace repo 83 | this_workspace_repo.push(progress=True) 84 | # # Delete the workspace directory after pushing 85 | shutil.rmtree(out_dir) 86 | 87 | 88 | TDataset = TypeVar("TDataset", bound=TaskDataset) 89 | 90 | 91 | class DatasetServer(ConfigurableExpt, ABC, Generic[TDataset]): 92 | port: int 93 | 94 | @abstractmethod 95 | def make_dataset(self) -> TDataset: 96 | pass 97 | 98 | async def run(self) -> None: 99 | dataset = self.make_dataset() 100 | logger.info(f"Starting {dataset.__class__.__name__} server on port {self.port}") 101 | server = NBTaskDatasetServer(dataset, port=self.port) 102 | await server.astart() 103 | 104 | 105 | class CapsuleDatasetServer(DatasetServer[CapsuleDataset]): 106 | dataset: CapsuleDatasetConfig = Field(default_factory=CapsuleDatasetConfig) 107 | 108 | def make_dataset(self) -> CapsuleDataset: 109 | return CapsuleDataset(config=self.dataset) 110 | -------------------------------------------------------------------------------- /src/expts/training.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from collections.abc import Mapping, Sequence 4 | 5 | from aviary.core import ( 6 | TaskDatasetClient, 7 | ) 8 | from aviary_internal import utils 9 | from aviary_internal.agent import DQNAgentVariant 10 | from aviary_internal.agent.dqn_agent import LLMSamplingMode 11 | from aviary_internal.alg.optimizer.dqn import DQNOptimizer 12 | from aviary_internal.nn.sft_optimizer import LocalLLMSFTOptimizer 13 | from aviary_internal.serialization import disable_serialization_backend 14 | from cloning.expts.local_sft import CloningOnlineLocalTrainingExpt 15 | from gsm8k.expts.dqn.online import GSM8kDQNOnlineTrainingExpt 16 | from ldp.alg.callbacks import Callback 17 | from ldp.alg.runners import OnlineTrainerConfig 18 | from ldp.data_structures import Trajectory 19 | 20 | from .client import TaskDatasetSplit 21 | from .common import SaveWorkspaceCallback, prev_choice_rep_fn 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class EnvServerConfig(utils.ConfigModel): 27 | host: str 28 | port: int 29 | request_timeout: float | None = 300.0 30 | 31 | async def make_datasets(self) -> dict[str, TaskDatasetClient]: 32 | base_dataset = await TaskDatasetClient.create( 33 | server_url=f"http://{self.host}:{self.port}", 34 | request_timeout=self.request_timeout, 35 | ) 36 | return { 37 | "train_dataset": TaskDatasetSplit.TRAIN.get_random_split(base_dataset), 38 | "eval_dataset": TaskDatasetSplit.EVAL.get_random_split(base_dataset), 39 | } 40 | 41 | 42 | class NBDQNOnlineTrainingExpt(GSM8kDQNOnlineTrainingExpt): 43 | env: EnvServerConfig 44 | 45 | async def make_datasets(self) -> dict[str, TaskDatasetClient]: 46 | return await self.env.make_datasets() 47 | 48 | def make_callbacks( 
49 | self, 50 | agent: DQNAgentVariant, 51 | optimizer: DQNOptimizer, 52 | datasets: Mapping[str, TaskDatasetClient], 53 | ) -> list[Callback]: 54 | callbacks = super().make_callbacks(agent, optimizer, datasets) 55 | callbacks.append( 56 | SaveWorkspaceCallback( 57 | dataset_client=datasets["train_dataset"], 58 | workspace_repo=utils.DataRepo( 59 | name=f"{self.output_repo.name}-workspaces" 60 | ), 61 | ) 62 | ) 63 | return callbacks 64 | 65 | def make_agent(self, **kwargs) -> DQNAgentVariant: 66 | if self.agent.llm_sampling_mode == LLMSamplingMode.SEQUENTIAL: 67 | self.agent.llm_kwargs["prev_choice_rep_fn"] = prev_choice_rep_fn 68 | return super().make_agent(**kwargs) 69 | 70 | 71 | class NBOnlineTrainingConfig(OnlineTrainerConfig): 72 | save_all_checkpoints: bool = True 73 | num_val_trajs: int 74 | num_train_trajs: int | None = None 75 | 76 | 77 | class NBOnlineLocalTrainingExpt(CloningOnlineLocalTrainingExpt): 78 | env: EnvServerConfig 79 | trainer: NBOnlineTrainingConfig 80 | 81 | async def _get_demonstration_examples( 82 | self, opt: LocalLLMSFTOptimizer 83 | ) -> tuple[list[dict], list[dict]]: 84 | backend = await self.make_backend() 85 | trajectories = await backend.get_trajectories() 86 | 87 | random.Random(self.data_seed).shuffle(trajectories) 88 | val_trajs = self._filter_trajectories( 89 | trajectories[: self.trainer.num_val_trajs], opt 90 | ) 91 | train_trajs = self._filter_trajectories( 92 | trajectories[self.trainer.num_val_trajs :][: self.trainer.num_train_trajs], 93 | opt, 94 | ) 95 | logger.info( 96 | f"Loaded {len(train_trajs)} ({len(val_trajs)}) train (val) trajectories." 97 | ) 98 | 99 | # Disable the backend so we don't accidentally overwrite input data 100 | disable_serialization_backend() 101 | 102 | # convert to examples 103 | train_examples = self._trajs_to_examples(train_trajs, opt) 104 | val_examples = self._trajs_to_examples(val_trajs, opt) 105 | return train_examples, val_examples 106 | 107 | def _filter_trajectories( 108 | self, trajectories: Sequence[Trajectory], opt: LocalLLMSFTOptimizer 109 | ): 110 | return [t for t in trajectories if opt.trajectory_passes(t)] 111 | -------------------------------------------------------------------------------- /src/expts/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from collections.abc import Sequence 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import numpy as np 9 | from aviary.core import ( 10 | Environment, 11 | Message, 12 | Messages, 13 | TaskDatasetClient, 14 | TaskEnvironmentClient, 15 | ToolRequestMessage, 16 | ) 17 | 18 | # from aviary_internal import utils 19 | # from aviary_internal.graph.multiple_completion_op import ( 20 | # SequentialMultipleCompletionLLMCallOp, 21 | # ) 22 | from ldp.agent import Agent 23 | from ldp.alg import Callback 24 | from ldp.data_structures import Trajectory, Transition 25 | from llmclient.cost_tracker import GLOBAL_COST_TRACKER 26 | from fhda.storage import DataRepo 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class VerboseCallback(Callback): 32 | """Callback to visualize notebook state before each transition.""" 33 | 34 | async def before_transition( 35 | self, 36 | traj_id: str, 37 | agent: Agent, 38 | env: Environment, 39 | agent_state: Any, 40 | obs: list[Message], 41 | ) -> None: 42 | for msg in obs: 43 | if msg.content: 44 | logger.info("VerboseCallback:\n%s", msg.content) 45 | 46 | 47 | class SaveWorkspaceCallback(Callback): 48 | def 
__init__(self, dataset_client: TaskDatasetClient, workspace_repo: DataRepo): 49 | self.dataset_client = dataset_client 50 | self.workspace_repo = workspace_repo 51 | 52 | async def before_transition( 53 | self, 54 | traj_id: str, 55 | agent: Agent, 56 | env: Environment, 57 | agent_state, 58 | obs: list[Message], 59 | ) -> None: 60 | self.start = time.time() 61 | 62 | async def after_transition( 63 | self, 64 | traj_id: str, 65 | agent: Agent, 66 | env: TaskEnvironmentClient, # type: ignore[override] 67 | transition: Transition, 68 | ) -> None: 69 | if not any((transition.done, transition.truncated)): 70 | # only save if the trajectory is over 71 | return 72 | 73 | # TODO: figure out how to support overwrite flag 74 | async with self.dataset_client.get_http_client() as client: 75 | response = await client.post( 76 | "/save_workspace", 77 | json={ 78 | "env_id": env.state.env_id, 79 | "traj_id": traj_id, 80 | "workspace_repo": self.workspace_repo.model_dump(), 81 | "exception": transition.failed, 82 | "cost": GLOBAL_COST_TRACKER.lifetime_cost_usd, 83 | "time": time.time() - self.start, 84 | }, 85 | ) 86 | if not response.is_success: 87 | logger.error(f"Failed to save workspace: {response.content!r}") 88 | 89 | 90 | class LoggingCallback(Callback): 91 | def __init__(self, output_repo: DataRepo): 92 | self.output_repo = output_repo 93 | self.rewards: list[float] = [] 94 | 95 | async def after_eval_step(self, trajectories: Sequence[Trajectory]) -> None: 96 | this_batch_rewards = [ 97 | sum(step.reward for step in traj.steps) for traj in trajectories 98 | ] 99 | self.rewards += this_batch_rewards 100 | self.reward_mean, self.reward_stde = self._compute_summary_stats(self.rewards) 101 | # NOTE: assumes that positive reward implies success 102 | self.acc_mean, self.acc_stde = self._compute_summary_stats( 103 | [r > 0 for r in self.rewards] 104 | ) 105 | 106 | print(flush=True) 107 | logger.info( 108 | f"Accuracy={self.acc_mean:.2f}±{self.acc_stde:.2f}; " 109 | f"Rewards={self.reward_mean:.2f}±{self.reward_stde:.2f}" 110 | ) 111 | 112 | async def after_eval_loop(self) -> None: 113 | results = { 114 | "reward_mean": self.reward_mean, 115 | "reward_stde": self.reward_stde, 116 | "acc_mean": self.acc_mean, 117 | "acc_stde": self.acc_stde, 118 | } 119 | 120 | with open(Path(self.output_repo.local_path) / "results.json", "w") as f: 121 | json.dump(results, f, indent=4) 122 | logger.info(f"These are the results: {results}") 123 | with open(Path(self.output_repo.local_path) / "rewards.json", "w") as f: 124 | json.dump(self.rewards, f) 125 | 126 | def _compute_summary_stats(self, metrics: list) -> tuple[float, float]: 127 | return np.mean(metrics), np.std(metrics) / np.sqrt(len(metrics) + 1) 128 | 129 | 130 | def prev_choice_rep_fn(output_messages: Messages) -> str: 131 | rep = "" 132 | for i, msg in enumerate(output_messages): 133 | assert isinstance(msg, ToolRequestMessage) 134 | assert len(msg.tool_calls) == 1 135 | tc = msg.tool_calls[0] 136 | 137 | match tc.function.name: 138 | case "submit_answer": 139 | rep += f"Option {i + 1}: Submitting solution." 140 | 141 | case "list_workdir": 142 | rep += f"Option {i + 1}: Listing workdir contents." 
143 | 144 | case "edit_cell": 145 | idx = tc.function.arguments.get("idx", None) 146 | if idx is None: 147 | rep += f"Option {i + 1}: Adding cell:\n```\n" 148 | else: 149 | rep += f"Option {i + 1}: Editing cell {idx}:\n```\n" 150 | rep += tc.function.arguments["contents"] + "\n```\n" 151 | 152 | case _: 153 | # Don't throw error for now, since there may be a case I haven't considered 154 | # But eventually this should be an exception. 155 | logger.error(f"Unexpected tool call: {tc.function.name}") 156 | 157 | rep += "\n" 158 | 159 | return rep 160 | -------------------------------------------------------------------------------- /src/fhda/templates/lab/mermaidjs.html.j2: -------------------------------------------------------------------------------- 1 | {%- macro mermaid_js( 2 | url="https://cdnjs.cloudflare.com/ajax/libs/mermaid/10.7.0/mermaid.esm.min.mjs" 3 | ) -%} 4 | 138 | 187 | 188 | {%- endmacro %} 189 | -------------------------------------------------------------------------------- /tutorial/platform_api.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Installation\n", 8 | "\n", 9 | "From the root of the repository, run:\n", 10 | "\n", 11 | "```bash\n", 12 | "pip install -e .\n", 13 | "```" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import time\n", 24 | "\n", 25 | "from futurehouse_client import FutureHouseClient, JobNames\n", 26 | "from futurehouse_client.models import TaskRequest, RuntimeConfig\n", 27 | "from futurehouse_client.models.app import AuthType\n", 28 | "import fhda.prompts as prompts" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# Instantiate the FutureHouse client with your API key\n", 38 | "FH_API_KEY = \"\" # Add your API key here\n", 39 | "UPLOAD_ID = (\n", 40 | " \"finch_tutorial\" # This is the folder name of the dataset you uploaded to GCS\n", 41 | ")\n", 42 | "\n", 43 | "client = FutureHouseClient(\n", 44 | " auth_type=AuthType.API_KEY,\n", 45 | " api_key=FH_API_KEY,\n", 46 | ")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# Load your dataset – note you only have to do this once\n", 56 | "# File path can be an absolute path or a relative path to either a directory or a file containing the dataset\n", 57 | "client.upload_file(\n", 58 | " JobNames.FINCH, file_path=\"datasets/brain_size_data.csv\", upload_id=UPLOAD_ID\n", 59 | ")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Check what files were uploaded to your gcs folder\n", 69 | "client.list_files(JobNames.FINCH, upload_id=UPLOAD_ID)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# Define your task\n", 79 | "# Here is where you can update the prompt. 
As shown below, by default we use CoT prompting,\n", 80 | "# but it is not necessary and we encourage users to experiment with different prompting strategies.\n", 81 | "LANGUAGE = \"PYTHON\" # Choose between \"R\" and \"PYTHON\"\n", 82 | "MAX_STEPS = 30 # You can change this to impose a limit on the number of steps the agent can take\n", 83 | "query = \"Make a short notebook with visualizations exploring the dataset.\"\n", 84 | "\n", 85 | "task = (\n", 86 | " f\"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=LANGUAGE)}\\n\"\n", 87 | " f\"{prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=LANGUAGE)}\"\n", 88 | " f\"Here is the research question to address:\\n\"\n", 89 | " f\"\\n\"\n", 90 | " f\"{query}\\n\"\n", 91 | " f\"\\n\"\n", 92 | ")\n", 93 | "\n", 94 | "# This is extra R prompting to avoid long R output blocks – also feel free to discard this\n", 95 | "if LANGUAGE == \"R\":\n", 96 | " task += f\"\\n{prompts.R_SPECIFIC_GUIDELINES}\"" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# This is how to create a task – you shouldn't need to change anything here\n", 106 | "task_data = TaskRequest(\n", 107 | " name=JobNames.FINCH,\n", 108 | " query=task,\n", 109 | " runtime_config=RuntimeConfig(\n", 110 | " max_steps=MAX_STEPS,\n", 111 | " upload_id=UPLOAD_ID,\n", 112 | " environment_config={\n", 113 | " \"default_cot_prompt\": False,\n", 114 | " \"language\": LANGUAGE,\n", 115 | " },\n", 116 | " ),\n", 117 | ")\n", 118 | "trajectory_id = client.create_task(task_data)\n", 119 | "print(\n", 120 | " f\"Task running on platform, you can view progress live at:https://platform.futurehouse.org/trajectories/{trajectory_id}\"\n", 121 | ")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Jobs take on average 3-10 minutes to complete\n", 131 | "status = \"in progress\"\n", 132 | "while status in [\"in progress\", \"queued\"]:\n", 133 | " time.sleep(15)\n", 134 | " status = client.get_task(trajectory_id).status\n", 135 | "\n", 136 | "if status == \"failed\":\n", 137 | " raise Exception(\"Task failed\")\n", 138 | "\n", 139 | "job_result = client.get_task(trajectory_id, verbose=True)\n", 140 | "answer = job_result.environment_frame[\"state\"][\"state\"][\"answer\"]\n", 141 | "print(f\"The agent's answer to your research question is: \\n{answer}\")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# In addition to viewing the notebook and reasoning trace via the platform,\n", 151 | "# you can also list the files in the trajectory directory and download any files you need\n", 152 | "print(client.list_files(JobNames.FINCH, trajectory_id=trajectory_id))\n", 153 | "\n", 154 | "destination_path = \"output/notebook.ipynb\"\n", 155 | "file_path = \"notebook.ipynb\"\n", 156 | "client.download_file(\n", 157 | " JobNames.FINCH,\n", 158 | " trajectory_id=trajectory_id,\n", 159 | " file_path=file_path,\n", 160 | " destination_path=destination_path,\n", 161 | ")\n", 162 | "print(f\"Notebook saved to {os.path.abspath(destination_path)}\")" 163 | ] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": ".venv", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | 
"file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 2 186 | } 187 | -------------------------------------------------------------------------------- /src/fhda/templates/base/null.j2: -------------------------------------------------------------------------------- 1 | {# 2 | 3 | DO NOT USE THIS AS A BASE, 4 | IF YOU ARE COPY AND PASTING THIS FILE 5 | YOU ARE PROBABLY DOING THINGS INCORRECTLY. 6 | 7 | Null template, does nothing except defining a basic structure 8 | To layout the different blocks of a notebook. 9 | 10 | Subtemplates can override blocks to define their custom representation. 11 | 12 | If one of the block you do overwrite is not a leaf block, consider 13 | calling super. 14 | 15 | {%- block nonLeafBlock -%} 16 | #add stuff at beginning 17 | {{ super() }} 18 | #add stuff at end 19 | {%- endblock nonLeafBlock -%} 20 | 21 | consider calling super even if it is a leaf block, we might insert more blocks later. 22 | 23 | #} 24 | {%- block header -%} 25 | {%- endblock header -%} 26 | {%- block body -%} 27 | {%- block body_header -%} 28 | {%- endblock body_header -%} 29 | {%- block body_loop -%} 30 | {%- for cell in nb.cells -%} 31 | {%- block any_cell scoped -%} 32 | {%- if cell.cell_type == 'code'-%} 33 | {%- if resources.global_content_filter.include_code -%} 34 | {%- block codecell scoped -%} 35 | {%- if resources.global_content_filter.include_input and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 36 | {%- block input_group -%} 37 | {%- if resources.global_content_filter.include_input_prompt -%} 38 | {%- block in_prompt -%}{%- endblock in_prompt -%} 39 | {%- endif -%} 40 | {%- block input -%}{%- endblock input -%} 41 | {%- endblock input_group -%} 42 | {%- endif -%} 43 | {%- if cell.outputs and resources.global_content_filter.include_output -%} 44 | {%- block output_group -%} 45 | {%- if resources.global_content_filter.include_output_prompt -%} 46 | {%- block output_prompt -%}{%- endblock output_prompt -%} 47 | {%- endif -%} 48 | {%- block outputs scoped -%} 49 | {%- for output in cell.outputs -%} 50 | {%- block output scoped -%} 51 | {%- if output.output_type == 'execute_result' -%} 52 | {%- block execute_result scoped -%}{%- endblock execute_result -%} 53 | {%- elif output.output_type == 'stream' -%} 54 | {%- block stream scoped -%} 55 | {%- if output.name == 'stdout' -%} 56 | {%- block stream_stdout scoped -%} 57 | {%- endblock stream_stdout -%} 58 | {%- elif output.name == 'stderr' -%} 59 | {%- block stream_stderr scoped -%} 60 | {%- endblock stream_stderr -%} 61 | {%- elif output.name == 'stdin' -%} 62 | {%- block stream_stdin scoped -%} 63 | {%- endblock stream_stdin -%} 64 | {%- endif -%} 65 | {%- endblock stream -%} 66 | {%- elif output.output_type == 'display_data' -%} 67 | {%- block display_data scoped -%} 68 | {%- block data_priority scoped -%} 69 | {%- endblock data_priority -%} 70 | {%- endblock display_data -%} 71 | {%- elif output.output_type == 'error' -%} 72 | {%- block error scoped -%} 73 | {%- for line in output.traceback -%} 74 | {%- block traceback_line scoped -%}{%- endblock traceback_line -%} 75 | {%- endfor -%} 76 | {%- endblock error -%} 77 | {%- endif -%} 78 | {%- endblock output -%} 79 | {%- endfor -%} 80 | {%- endblock outputs -%} 81 | {%- endblock output_group -%} 82 | {%- endif -%} 83 | {%- endblock codecell -%} 84 | {%- endif -%} 85 | {%- elif cell.cell_type in 
['markdown'] -%} 86 | {%- if resources.global_content_filter.include_markdown and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 87 | {%- block markdowncell scoped-%} {%- endblock markdowncell -%} 88 | {%- endif -%} 89 | {%- elif cell.cell_type in ['raw'] -%} 90 | {%- if resources.global_content_filter.include_raw and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 91 | {%- block rawcell scoped -%} 92 | {%- if cell.metadata.get('raw_mimetype', '').lower() in resources.get('raw_mimetypes', ['']) -%} 93 | {{ cell.source }} 94 | {%- endif -%} 95 | {%- endblock rawcell -%} 96 | {%- endif -%} 97 | {%- else -%} 98 | {%- if resources.global_content_filter.include_unknown and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 99 | {%- block unknowncell scoped-%} 100 | {%- endblock unknowncell -%} 101 | {%- endif -%} 102 | {%- endif -%} 103 | {%- endblock any_cell -%} 104 | {%- endfor -%} 105 | {%- endblock body_loop -%} 106 | {%- block body_footer -%} 107 | {%- endblock body_footer -%} 108 | {%- endblock body -%} 109 | 110 | {%- block footer -%} 111 | {%- endblock footer -%} 112 | -------------------------------------------------------------------------------- /src/fhda/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | from pathlib import Path 4 | from tempfile import mkdtemp 5 | 6 | from pydantic import Field 7 | 8 | from aviary.core import EvalAnswerMode, TaskDataset 9 | from .storage import DataRepo 10 | from .data_analysis_env import DataAnalysisEnv 11 | from .utils import NBLanguage, load_mcq 12 | from . import prompts 13 | from .models import ConfigModel 14 | from .notebook_env import NBEnvironment 15 | import logging 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class CapsuleDatasetConfig(ConfigModel): 22 | repo: DataRepo = Field( 23 | default_factory=lambda: DataRepo(name="baseline-envs/data-analysis/v3.1"), 24 | description="The hosted repo to use for the dataset.", 25 | ) 26 | 27 | local_repo_path: str | None = Field( 28 | default=None, 29 | description="If provided, will source the data from this local path instead of the hosted repo.", 30 | ) 31 | 32 | local_output_path: str | None = Field( 33 | default=None, 34 | description="If provided, will save the output to this local path instead of the hosted repo.", 35 | ) 36 | 37 | capsule_mode: str | None = Field( 38 | default="mcq", 39 | description="Determines whether the agent is to answer MCQs, open questions or whether a hypothesis is supported by the data", 40 | ) 41 | 42 | eval_mode: EvalAnswerMode = Field( 43 | default=EvalAnswerMode.LLM, 44 | description="If exact, the target will be 'answer' in the metadata json (i.e. T/F) " 45 | "If llm, the target will be 'result'. Contains/score not supported", 46 | ) 47 | 48 | avoid_images: bool = Field( 49 | default=False, 50 | description="If True, the agent will be prompted to avoid using images in its notebook.", 51 | ) 52 | 53 | preload_notebook: bool = Field( 54 | default=False, 55 | description=( 56 | "If False, the agent will have to start from a virgin notebook. " 57 | "If True, the agent environment will be preloaded with a notebook " 58 | "containing a portion of the capsule problem already completed " 59 | "eg package & data loading." 
60 | ), 61 | ) 62 | 63 | prompt_template_key: str = Field( 64 | default="v1.3.1", 65 | description="The key of the prompt template from the CAPSULE_PROMPT_TEMPLATES dict to use for the problem.", 66 | ) 67 | 68 | 69 | class CapsuleDataset(TaskDataset[DataAnalysisEnv]): 70 | """A dataset of tasks derived from data analysis capsules.""" 71 | 72 | def __init__(self, config: CapsuleDatasetConfig): 73 | # Load dataset from local path or hosted repo 74 | if config.local_repo_path: 75 | repo_path = config.local_repo_path 76 | else: 77 | config.repo.pull(progress=True) 78 | repo_path = config.repo.local_path 79 | self.capsules = list(Path(repo_path).rglob("CapsuleFolder*")) 80 | 81 | # Load prompt template 82 | self.prompt = prompts.CAPSULE_PROMPT_TEMPLATES[config.prompt_template_key] 83 | self.config = config 84 | 85 | def get_new_env_by_idx(self, idx: int) -> DataAnalysisEnv: 86 | capsule_path = self.capsules[idx] 87 | metadata = json.load((capsule_path / "metadata.json").open()) 88 | 89 | notebook_name = NBEnvironment.NOTEBOOK_NAME 90 | # Define local capsule directory 91 | if self.config.local_output_path: 92 | problem_dir = Path(self.config.local_output_path) / capsule_path.name 93 | else: 94 | problem_dir = Path(mkdtemp()) 95 | problem_dir.mkdir(parents=True, exist_ok=True) 96 | 97 | # Copy capsule contents to local directory 98 | for item in capsule_path.iterdir(): 99 | if self.config.preload_notebook and str(item).endswith("_stripped.ipynb"): 100 | shutil.copy(item, problem_dir) 101 | elif str(item).endswith((".ipynb", "metadata.json", "checksum")): 102 | continue 103 | elif item.is_dir(): 104 | shutil.copytree(item, problem_dir / item.name) 105 | else: 106 | shutil.copy(item, problem_dir) 107 | 108 | nb_path = problem_dir / notebook_name 109 | 110 | # Define system prompt and problem 111 | if self.config.capsule_mode == "hypothesis": 112 | system_prompt = prompts.CAPSULE_SYSTEM_PROMPT_HYPOTHESIS 113 | problem = self.prompt.replace("{{hypothesis}}", metadata["hypothesis"]) 114 | answer = metadata["answer"] 115 | processed_questions = None 116 | elif self.config.capsule_mode == "mcq": 117 | raw_mcqs = metadata["notebook_questions"]["questions"] 118 | processed_questions = [ 119 | load_mcq(i, open_question=False, question_id=i["id"]) for i in raw_mcqs 120 | ] 121 | system_prompt = prompts.CAPSULE_SYSTEM_PROMPT_MCQ 122 | problem = self.prompt.format( 123 | questions="\n-------\n".join( 124 | [i.question_prompt for i in processed_questions] 125 | ) 126 | ) 127 | answer = {i.question_id: i.ideal_answer for i in processed_questions} 128 | elif self.config.capsule_mode == "open": 129 | system_prompt = prompts.CAPSULE_SYSTEM_PROMPT_OPEN 130 | raw_open_questions = metadata["notebook_questions"]["questions"] 131 | processed_questions = [ 132 | load_mcq(i, open_question=True, question_id=i["id"]) 133 | for i in raw_open_questions 134 | ] 135 | problem = self.prompt.format( 136 | questions="\n-------\n".join( 137 | [i.question_prompt for i in processed_questions] 138 | ) 139 | ) 140 | answer = {i.question_id: i.ideal_answer for i in processed_questions} 141 | else: 142 | raise ValueError(f"Invalid capsule mode: {self.config.capsule_mode}") 143 | 144 | if self.config.avoid_images: 145 | problem += prompts.AVOID_IMAGES 146 | 147 | # Temporarily hard code language to python, but can also use R 148 | language = NBLanguage.PYTHON 149 | return DataAnalysisEnv( 150 | problem_id=capsule_path.name, 151 | problem=problem, 152 | eval_mode=self.config.eval_mode, 153 | nb_path=nb_path, 154 | work_dir=problem_dir, 
155 | language=language, 156 | system_prompt=system_prompt, 157 | metadata=metadata, 158 | answer=answer, 159 | mcqs=processed_questions, 160 | ) 161 | 162 | def __len__(self) -> int: 163 | return len(self.capsules) 164 | -------------------------------------------------------------------------------- /src/fhda/data_analysis_env.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | import shutil 4 | from typing import Any, cast 5 | import time 6 | from aviary.core import ( 7 | EvalAnswerMode, 8 | Frame, 9 | Message, 10 | Messages, 11 | Tool, 12 | ) 13 | 14 | from .notebook_env import NBEnvironment 15 | from .utils import NBLanguage, nb_to_html 16 | from . import prompts 17 | from . import config as cfg 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | CORRECT_MSG = "Correct answer!" 22 | INCORRECT_MSG = "Incorrect answer." 23 | 24 | 25 | class DataAnalysisEnv(NBEnvironment): 26 | def __init__( 27 | self, 28 | *, 29 | problem_id: str, 30 | problem: str, 31 | answer: str | int | float | None = None, # noqa: PYI041 32 | system_prompt: str | None = None, 33 | correct_reward: float = 1.0, 34 | eval_mode: EvalAnswerMode, 35 | metadata: dict[str, Any] | None = None, # used for NBEvalExpt 36 | **kwargs, 37 | ): 38 | super().__init__(**kwargs) 39 | 40 | self.problem_id = problem_id 41 | self.problem = problem 42 | self.answer = answer 43 | self.eval_mode = eval_mode 44 | self.correct_reward = correct_reward 45 | self.system_prompt = system_prompt 46 | self.metadata = metadata 47 | self.question_rewards: dict[str, int] = {} 48 | 49 | async def reset(self) -> tuple[Messages, list[Tool]]: 50 | # Discard base class's init_obs and make our own with the problem statement 51 | _, tools = await super().reset() 52 | messages = [ 53 | Message(content=self.problem), 54 | self.get_env_state_msg(), 55 | ] 56 | if self.system_prompt: 57 | messages.append(Message(role="system", content=self.system_prompt)) 58 | init_obs = cast( 59 | Messages, 60 | messages, 61 | ) 62 | 63 | return init_obs, tools 64 | 65 | async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> str: # type: ignore[override] 66 | """Submit an answer to the problem. 67 | 68 | Note that this tool may only be called once and ends the episode. 69 | 70 | Args: 71 | answer: The answer to the problem 72 | """ 73 | self.state.answer = answer 74 | self.state.done = True 75 | logger.info("Submitting answer and closing environment") 76 | await self.close() 77 | logger.info("Answer: %s", answer) 78 | 79 | return f"Submitted answer: {answer}" 80 | 81 | @classmethod 82 | def from_task( 83 | cls, task: str, gcs_artifact_path: str | None = None 84 | ) -> "DataAnalysisEnv": 85 | """ 86 | Perform data analysis on a user query. 87 | 88 | Args: 89 | task: The user query structured as | 90 | 91 | eg "CaspuleFolder-a7812fg | How many genes are differentially expressed between the two conditions?" 
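        A minimal illustrative call (the folder name is hypothetical and is
        assumed to already exist under the configured data storage path):
        `DataAnalysisEnv.from_task("my-rnaseq-data | Which genes are upregulated?")`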
92 | """ 93 | logger.info("User task: %s", task) 94 | logger.info("GCS artifact path: %s", gcs_artifact_path) 95 | 96 | if ( 97 | gcs_artifact_path 98 | ): # The files are already in the GCS bucket in a job-specific directory 99 | trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path 100 | nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME 101 | query = task 102 | task_hash = gcs_artifact_path 103 | else: 104 | # Extract data path and query from task 105 | data_path, query = task.split("|") 106 | # Hash the task to get a unique identifier 107 | task_hash = hashlib.sha256(task.encode()).hexdigest() 108 | # Create temporary directory in GCP mounted storage volume 109 | trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}" 110 | trajectory_path.mkdir(parents=True, exist_ok=True) 111 | nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME 112 | # Copy task data to trajectory path 113 | for item in (cfg.DATA_STORAGE_PATH / data_path).iterdir(): 114 | if item.is_file(): 115 | shutil.copy2(item, trajectory_path) 116 | elif item.is_dir(): 117 | shutil.copytree( 118 | item, trajectory_path / item.name, dirs_exist_ok=True 119 | ) 120 | 121 | # Augment incoming task with CoT instructions 122 | augmented_task = f"""\ 123 | Here is the user query to address: 124 | 125 | 126 | {query} 127 | 128 | 129 | {prompts.CHAIN_OF_THOUGHT_AGNOSTIC} 130 | {prompts.GENERAL_NOTEBOOK_GUIDELINES}""" 131 | 132 | language = NBLanguage.PYTHON # In future, this should be a hyperparameter 133 | if language == NBLanguage.R: 134 | augmented_task += f"\n{prompts.R_OUTPUT_RECOMMENDATION_PROMPT}" 135 | 136 | # Log all parameters being passed to constructor 137 | logger.info( 138 | "Creating DataAnalysisEnv with parameters: " 139 | "problem_id=data-analysis-task-%s, " 140 | "problem=%s, " 141 | "eval_mode=%s, " 142 | "nb_path=%s, " 143 | "work_dir=%s, " 144 | "language=%s, " 145 | "system_prompt=%s, " 146 | "use_tmp_work_dir=%s, " 147 | "gcs_artifact_path=%s", 148 | task_hash, 149 | augmented_task, 150 | EvalAnswerMode.LLM, 151 | nb_path, 152 | trajectory_path, 153 | language, 154 | prompts.CAPSULE_SYSTEM_PROMPT_QUERY, 155 | False, 156 | gcs_artifact_path, 157 | ) 158 | if trajectory_path.exists(): 159 | logger.info( 160 | "Files in directory: %s", [f.name for f in trajectory_path.iterdir()] 161 | ) 162 | 163 | return cls( 164 | problem_id=f"data-analysis-task-{task_hash}", 165 | problem=augmented_task, 166 | eval_mode=EvalAnswerMode.LLM, 167 | nb_path=nb_path, 168 | work_dir=trajectory_path, 169 | language=language, 170 | system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_QUERY, 171 | use_tmp_work_dir=False, 172 | ) 173 | 174 | def export_frame(self) -> Frame: 175 | return Frame( 176 | state={ 177 | "last_action": self.state.actions[-1], 178 | "answer": self.state.answer, 179 | "done": self.state.done, 180 | "total_reward": self.state.total_reward, 181 | "nb_state": self.state.nb, 182 | "nb_state_html": nb_to_html(self.state.nb), 183 | }, 184 | info={ 185 | "eval_mode": self.eval_mode, 186 | "language": self.state.language, 187 | "problem": self.problem, 188 | "problem_id": self.problem_id, 189 | }, 190 | ) 191 | -------------------------------------------------------------------------------- /src/scripts/config.py: -------------------------------------------------------------------------------- 1 | """Module for handling yaml config/CLI args and translating them into pydantic configs.""" 2 | 3 | import contextlib 4 | import inspect 5 | import logging 6 | import os 7 | import shutil 8 | import sys 9 | 
import textwrap 10 | from argparse import ArgumentParser 11 | from collections.abc import Iterable 12 | from pathlib import Path 13 | from typing import Any, TypeVar 14 | 15 | import yaml 16 | from pydantic import BaseModel, ConfigDict 17 | from pydantic_core import PydanticUndefined 18 | 19 | from .logging import configure_logs 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class ConfigModel(BaseModel): 25 | model_config = ConfigDict( 26 | extra="forbid", arbitrary_types_allowed=True, populate_by_name=True 27 | ) 28 | 29 | 30 | TConfig = TypeVar("TConfig", bound=BaseModel) 31 | 32 | 33 | def load_arg_dict(argv: list[str]) -> dict[str, Any]: 34 | """Loads arguments from command line and yaml files into a dictionary. 35 | 36 | For example, if the command line args are `--foo.bar 1 --foo.baz 2`, the resulting 37 | dictionary is {'foo': {'bar': 1, 'baz': 2}}. YAML files are directly parsed as dictionaries. 38 | """ 39 | parser = ArgumentParser(add_help=False) 40 | parser.add_argument("config_files", nargs="*", type=str) 41 | 42 | if not any(a.endswith(".yaml") for a in argv): 43 | # add a dummy arg to avoid argparse error 44 | argv = ["INVALID.yaml", *argv] 45 | args, remaining_args = parser.parse_known_args(argv) 46 | 47 | config_acc: dict[str, Any] = {} 48 | for cfg in args.config_files: 49 | if cfg == "INVALID.yaml": 50 | continue 51 | with open(cfg) as fcfg: 52 | config = yaml.load(fcfg, Loader=yaml.Loader) # noqa: S506 53 | _recursive_update(config_acc, config) 54 | 55 | _parse_cli_args(remaining_args, config_acc) 56 | 57 | return config_acc 58 | 59 | 60 | def load_config( 61 | config_cls: type[TConfig], 62 | verbose: bool = True, 63 | argv: list[str] | None = None, 64 | args_to_exclude: Iterable[str] | None = None, 65 | ) -> TConfig: 66 | """Utility function for handling config and command line args supplied via command line. 67 | 68 | Args: 69 | config_cls: Config class object 70 | verbose: Boolean indicating extent of logging info 71 | argv: List of command line args. If not specified (default), will use sys.argv. 72 | args_to_exclude: Arguments to skip when constructing the config object. 73 | 74 | Returns: 75 | Config object synthesizing CLI args and supplied yaml. 
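    Example (illustrative; `MyConfig` is a hypothetical ConfigModel subclass):
        # Invoking e.g. `python run.py base.yaml --foo.bar 1` merges base.yaml with
        # the dotted CLI override before validating the result into MyConfig.
        config = load_config(MyConfig)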
76 | """ 77 | if argv is None: 78 | argv = sys.argv[1:] 79 | 80 | if "-h" in argv or "--help" in argv: 81 | print(get_config_help_string(config_cls)) 82 | sys.exit(0) 83 | 84 | config_acc = load_arg_dict(argv) 85 | if args_to_exclude: 86 | for arg in args_to_exclude: 87 | config_acc.pop(arg, None) 88 | 89 | config = config_cls(**config_acc) 90 | 91 | if verbose: 92 | logger.info("\n%s", yaml.dump({config_cls.__name__: config.model_dump()})) 93 | 94 | return config 95 | 96 | 97 | def _parse_cli_args(remaining_args: list[str], config_acc: dict): 98 | while remaining_args: 99 | arg = remaining_args.pop(0) 100 | if not arg.startswith("--"): 101 | raise ValueError(f"Invalid argument {arg}") 102 | 103 | arg = arg[2:] 104 | try: 105 | value = remaining_args[0] 106 | if value.startswith("--"): 107 | # moved on to next arg 108 | value = "True" 109 | else: 110 | # consumed value - remove from args 111 | remaining_args.pop(0) 112 | except IndexError: 113 | # end of args, assume it was a flag 114 | value = "True" 115 | value = _resolve_value(value) 116 | 117 | arg_hierarchy = arg.split(".") 118 | update_dict: dict[str, Any] = {} 119 | current_dict = update_dict 120 | for arg in arg_hierarchy[:-1]: 121 | current_dict[arg] = {} 122 | current_dict = current_dict[arg] 123 | current_dict[arg_hierarchy[-1]] = value 124 | _recursive_update(config_acc, update_dict) 125 | 126 | 127 | def dump_config(config: BaseModel, path: os.PathLike | str) -> None: 128 | """Dump the input Pydantic config to a YAML file.""" 129 | path = Path(path) 130 | if path.is_dir(): 131 | path /= "config.yaml" 132 | with path.open("w") as f: 133 | yaml.dump(config.model_dump(), f) 134 | 135 | 136 | def get_config_help_string(config_cls: type[BaseModel], indent: int = 0) -> str: 137 | s = ( 138 | textwrap.indent(f"{config_cls.__name__}:", " " * indent) + "\n" 139 | if indent == 0 140 | else "" 141 | ) 142 | 143 | indent += 1 144 | for key, value in config_cls.model_fields.items(): 145 | annot: Any = value.annotation 146 | # Removing the description printing for now, since it's just too verbose. 147 | # TODO: see if we can format it in a more readable way. 
148 | # desc = f" # {value.description}" if value.description else "" 149 | desc = "" 150 | 151 | if inspect.isclass(annot): 152 | if issubclass(annot, BaseModel): 153 | s += textwrap.indent(f"{key}:{desc}", " " * indent) + "\n" 154 | s += get_config_help_string(annot, indent) 155 | continue 156 | 157 | annot = annot.__name__ 158 | 159 | if value.is_required(): 160 | s += textwrap.indent(f"{key}: {annot}{desc}", " " * indent) + "\n" 161 | else: 162 | default = ( 163 | value.default_factory 164 | if value.default is PydanticUndefined 165 | else value.default 166 | ) 167 | s += ( 168 | textwrap.indent(f"{key}: {annot} = {default!r}{desc}", " " * indent) 169 | + "\n" 170 | ) 171 | 172 | return s 173 | 174 | 175 | DEFAULT_OUTPUT_LOG_NAME = "output.log" 176 | 177 | 178 | def set_up_output_dir( 179 | directory_path: str | os.PathLike, 180 | config: BaseModel | None = None, 181 | log_name: str | None = DEFAULT_OUTPUT_LOG_NAME, 182 | is_main_process: bool = True, 183 | remove_existing: bool = False, 184 | ) -> Path: 185 | if remove_existing and is_main_process: 186 | shutil.rmtree(directory_path, ignore_errors=True) 187 | directory_path = Path(directory_path) 188 | directory_path.mkdir(parents=True, exist_ok=True) 189 | 190 | if log_name: 191 | configure_logs(log_file=directory_path / log_name) 192 | 193 | if config is not None and is_main_process: 194 | dump_config(config, directory_path) 195 | 196 | return directory_path 197 | 198 | 199 | def _resolve_value(value: str) -> Any: 200 | if value.lower() == "true": 201 | return True 202 | if value.lower() == "false": 203 | return False 204 | 205 | with contextlib.suppress(ValueError): 206 | return int(value) 207 | with contextlib.suppress(ValueError): 208 | return float(value) 209 | 210 | if value == "None": 211 | return None 212 | 213 | return value 214 | 215 | 216 | def configure_yaml_multiline() -> None: 217 | # copied from SWE-agent 218 | def multiline_representer(dumper, data): 219 | """Configures yaml for dumping multiline strings. 220 | 221 | Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data. 
222 | """ 223 | if data.count("\n") > 0: # check for multiline string 224 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") 225 | return dumper.represent_scalar("tag:yaml.org,2002:str", data) 226 | 227 | yaml.add_representer(str, multiline_representer) 228 | 229 | 230 | def _recursive_update(d: dict, u: dict) -> dict: 231 | for k, v in u.items(): 232 | if isinstance(v, dict): 233 | d[k] = _recursive_update(d.get(k, {}), v) 234 | else: 235 | d[k] = v 236 | return d 237 | 238 | 239 | CONFIGURATION_ENABLE = {"1", "true", "yes", "on"} 240 | CONFIGURATION_DISABLE = {"0", "false", "no", "off"} 241 | -------------------------------------------------------------------------------- /src/expts/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import shutil 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import Self, cast 7 | 8 | import litellm 9 | from aviary.core import EvalAnswerMode, TaskDatasetClient 10 | from scripts.config import ConfigModel, set_up_output_dir 11 | from scripts.configurable import ConfigurableExpt 12 | from fhda.utils import NBLanguage 13 | from fhda.storage import DataRepo 14 | from ldp.agent import Agent, AgentConfig 15 | from ldp.alg import Evaluator, EvaluatorConfig, TrajectoryFileCallback 16 | from ldp.alg.callbacks import Callback 17 | from ldp.alg.rollout import RolloutManager 18 | from ldp.data_structures import Transition 19 | from llmclient.cost_tracker import enable_cost_tracking 20 | from pydantic import Field, model_validator 21 | 22 | from fhda.data_analysis_env import DataAnalysisEnv 23 | 24 | from .client import TaskDatasetSplit 25 | from .common import ( 26 | LoggingCallback, 27 | SaveWorkspaceCallback, 28 | VerboseCallback, 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class EnvServerConfig(ConfigModel): 35 | split: TaskDatasetSplit 36 | host: str = "localhost" 37 | port: int 38 | request_timeout: float | None = 300.0 39 | 40 | 41 | class NBEvalExpt(ConfigurableExpt): 42 | output_repo: DataRepo 43 | comment: str = "" 44 | overwrite: bool = False 45 | 46 | env: EnvServerConfig 47 | 48 | agent: AgentConfig 49 | evaluator: EvaluatorConfig = Field( 50 | default_factory=lambda: EvaluatorConfig(num_eval_iterations=25) 51 | ) 52 | 53 | async def make_dataset(self) -> TaskDatasetClient: 54 | base_dataset = await TaskDatasetClient.create( 55 | server_url=f"http://{self.env.host}:{self.env.port}", 56 | request_timeout=self.env.request_timeout, 57 | ) 58 | return self.env.split.get_random_split(base_dataset) 59 | 60 | @model_validator(mode="after") 61 | def post_init(self) -> Self: 62 | if self.overwrite: 63 | shutil.rmtree(self.output_repo.local_path, ignore_errors=True) 64 | self.output_repo.mkdir() 65 | return self 66 | 67 | async def run(self) -> None: 68 | set_up_output_dir(self.output_repo.local_path, config=self) 69 | dataset = await self.make_dataset() 70 | agent = self.agent.construct_agent() 71 | callbacks: list[Callback] = [ 72 | TrajectoryFileCallback(self.output_repo.local_path), 73 | LoggingCallback(self.output_repo), 74 | SaveWorkspaceCallback( 75 | dataset_client=dataset, 76 | workspace_repo=DataRepo(name=f"{self.output_repo.name}-workspaces"), 77 | ), 78 | ] 79 | if self.evaluator.batch_size == 1: 80 | callbacks.append(VerboseCallback()) 81 | litellm.drop_params = True 82 | enable_cost_tracking(enabled=True) 83 | evaluator = Evaluator( 84 | config=self.evaluator, 85 | agent=agent, 86 | 
dataset=dataset, 87 | callbacks=callbacks, 88 | ) 89 | await evaluator.run() 90 | 91 | self.output_repo.push(progress=True) 92 | 93 | 94 | class AdHocExptCallback(Callback): 95 | def __init__(self, output_dir: Path): 96 | self.output_dir = output_dir 97 | 98 | async def after_transition( 99 | self, 100 | traj_id: str, 101 | agent: Agent, 102 | env: DataAnalysisEnv, # type: ignore[override] 103 | transition: Transition, 104 | ) -> None: 105 | if transition.done or transition.truncated or transition.failed: 106 | target_dir = self.output_dir / env.problem_id 107 | if target_dir.exists(): 108 | shutil.rmtree(target_dir) 109 | shutil.copytree(env.state.work_dir, target_dir) 110 | 111 | if transition.action: 112 | action = transition.action.value 113 | submitted_answers = [ 114 | tc.function.arguments["answer"] 115 | for tc in action.tool_calls 116 | if tc.function.name == "submit_answer" 117 | ] 118 | with (self.output_dir / (env.problem_id + "-answer.json")).open( 119 | "w" 120 | ) as f: 121 | json.dump(submitted_answers, f, indent=2) 122 | 123 | 124 | class AdHocExpt(ConfigurableExpt): 125 | problem: str = Field(description="Problem to solve.") 126 | problem_id: str = Field( 127 | default_factory=lambda: f"analysis-{datetime.now().strftime('%Y%m%d-%H%M%S')}", 128 | description="Arbitrary problem ID - outputs will be stored with this name. " 129 | "Auto-assigned with timestamp if not provided.", 130 | ) 131 | 132 | input_dir: str = Field(description="Directory containing input data.") 133 | input_repo: DataRepo | None = Field( 134 | default=None, 135 | description="If provided, will set `input_dir` to `input_repo.local_path`.", 136 | ) 137 | 138 | output_dir: str | None = Field( 139 | default=None, 140 | description="Directory to save output notebooks. " 141 | "If not provided, will use `input_dir`.", 142 | ) 143 | output_repo: DataRepo | None = Field( 144 | default=None, 145 | description="If provided, will set `output_dir` to `output_repo.local_path`.", 146 | ) 147 | 148 | agent: AgentConfig 149 | max_rollout_steps: int | None = None 150 | verbose_callback: bool = True 151 | copy_workspace_callback: bool = True 152 | language: str = "python" 153 | 154 | async def run(self) -> None: 155 | output_path = Path(cast(str, self.output_dir)) 156 | agent = self.agent.construct_agent() 157 | 158 | # Sanity check to prevent misconfiguration for now - may revisit 159 | if not getattr(agent, "hide_old_env_states", True): 160 | raise RuntimeError( 161 | "It is strongly recommended that hide_old_env_states=True " 162 | "if the agent provides this option." 
163 | ) 164 | 165 | callbacks: list[Callback] = [] 166 | if self.verbose_callback: 167 | callbacks.append(VerboseCallback()) 168 | if self.copy_workspace_callback: 169 | callbacks.append(AdHocExptCallback(output_path)) 170 | 171 | rollout = RolloutManager(agent=agent, callbacks=callbacks) 172 | 173 | language = NBLanguage.PYTHON if self.language == "python" else NBLanguage.R 174 | 175 | input_path = Path(self.input_dir) 176 | env = DataAnalysisEnv( 177 | problem_id=self.problem_id, 178 | problem=self.problem, 179 | # doesn't really matter, since there's no answer 180 | eval_mode=EvalAnswerMode.EXACT, 181 | # use_tmp_work_dir=True by default, so self.data_dir will be copied 182 | nb_path=(input_path / "analysis.ipynb"), 183 | work_dir=input_path, 184 | language=language, 185 | ) 186 | 187 | await rollout.sample_trajectories( 188 | environments=[env], max_steps=self.max_rollout_steps 189 | ) 190 | 191 | await env.close() 192 | 193 | if self.output_repo is not None: 194 | self.output_repo.push(progress=True) 195 | 196 | @model_validator(mode="before") 197 | @classmethod 198 | def set_dirs(cls, data): 199 | if isinstance(data, dict): 200 | for pfx in ("input", "output"): 201 | if f"{pfx}_repo" in data: 202 | assert f"{pfx}_dir" not in data, ( 203 | f"Cannot provide both {pfx}_dir and {pfx}_repo" 204 | ) 205 | data[f"{pfx}_repo"] = DataRepo(**data[f"{pfx}_repo"]) 206 | data[f"{pfx}_dir"] = data[f"{pfx}_repo"].local_path 207 | return data 208 | 209 | @model_validator(mode="after") 210 | def post_init(self) -> Self: 211 | if self.input_repo is not None: 212 | self.input_repo.pull(progress=True) 213 | 214 | if self.output_repo is not None: 215 | self.output_repo.mkdir() 216 | 217 | if self.output_dir is None: 218 | self.output_dir = self.input_dir 219 | 220 | return self 221 | -------------------------------------------------------------------------------- /tutorial/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "import hashlib\n", 12 | "import shutil\n", 13 | "from pathlib import Path\n", 14 | "import time\n", 15 | "import logging\n", 16 | "\n", 17 | "\n", 18 | "from ldp.agent import AgentConfig\n", 19 | "from ldp.alg.rollout import RolloutManager\n", 20 | "from ldp.data_structures import Trajectory, Transition\n", 21 | "\n", 22 | "from fhda.data_analysis_env import DataAnalysisEnv\n", 23 | "from fhda.notebook_env import NBEnvironment\n", 24 | "from fhda.utils import NBLanguage\n", 25 | "from fhda import prompts\n", 26 | "import fhda.config as cfg" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "def setup_data_analysis_env(\n", 36 | " query: str, dataset: Path, language: NBLanguage = NBLanguage.PYTHON\n", 37 | "):\n", 38 | " # Hash the task to get a unique identifier\n", 39 | " task_hash = hashlib.sha256(query.encode()).hexdigest()\n", 40 | " trajectory_path = (\n", 41 | " Path(os.path.abspath(\"tmp_results_dir\")) / f\"{task_hash}-{time.time()}\"\n", 42 | " )\n", 43 | " trajectory_path.mkdir(parents=True, exist_ok=True)\n", 44 | " nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME\n", 45 | " # Copy task data to trajectory path\n", 46 | " if dataset.is_dir():\n", 47 | " for item in dataset.iterdir():\n", 48 | " if item.is_file():\n", 49 | " shutil.copy2(item, trajectory_path)\n", 50 
| " elif item.is_dir():\n", 51 | " shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True)\n", 52 | " else:\n", 53 | " shutil.copy2(dataset, trajectory_path)\n", 54 | " # Augment incoming task with CoT instructions\n", 55 | " augmented_task = f\"\"\"\\\n", 56 | " Here is the user query to address:\n", 57 | "\n", 58 | "\n", 59 | " \n", 60 | " {query}\n", 61 | " \n", 62 | "\n", 63 | " {prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language.name)}\n", 64 | " {prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language.name)}\"\"\"\n", 65 | "\n", 66 | " if language == NBLanguage.R:\n", 67 | " augmented_task += f\"\\n{prompts.R_SPECIFIC_GUIDELINES}\"\n", 68 | "\n", 69 | " dae = DataAnalysisEnv(\n", 70 | " problem_id=f\"data-analysis-task-{task_hash}\",\n", 71 | " problem=augmented_task,\n", 72 | " eval_mode=None,\n", 73 | " nb_path=nb_path,\n", 74 | " work_dir=trajectory_path,\n", 75 | " language=language,\n", 76 | " system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_QUERY,\n", 77 | " use_tmp_work_dir=False,\n", 78 | " run_notebook_on_edit=True if cfg.USE_DOCKER else False,\n", 79 | " )\n", 80 | " return dae" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# ENVIRONMENT CONFIGURATION\n", 90 | "\n", 91 | "# Set your API keys\n", 92 | "os.environ[\"ANTHROPIC_API_KEY\"] = \"\"\n", 93 | "# os.environ[\"OPENAI_API_KEY\"] = \"\"\n", 94 | "# If using docker, be sure to pull the image from docker hub first\n", 95 | "# docker pull futurehouse/bixbench:aviary-notebook-env\n", 96 | "# This image includes many bioinformatics and data science packages\n", 97 | "cfg.USE_DOCKER = False\n", 98 | "# This can be R or PYTHON in Docker or with a local kernel if you have R installed\n", 99 | "LANGUAGE = NBLanguage.PYTHON\n", 100 | "MAX_STEPS = 3\n", 101 | "MODEL_NAME = \"claude-3-7-sonnet-latest\"" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# AVIARY ROLLOUT\n", 111 | "# This folder only contains a single csv file on animal brain size and body mass from here:\n", 112 | "# https://animaltraits.org/\n", 113 | "# However, it could contain many files including nested folders\n", 114 | "\n", 115 | "logger = logging.getLogger(__name__)\n", 116 | "logger.info(\"Setting up data analysis environment\")\n", 117 | "\n", 118 | "dataset = Path(\"datasets/brain_size_data.csv\")\n", 119 | "query = \"Analyze the dataset and give me an in depth analysis using pretty plots. I am particularly interested in crows.\"\n", 120 | "environment = setup_data_analysis_env(query, dataset, LANGUAGE)\n", 121 | "\n", 122 | "agent = AgentConfig(\n", 123 | " agent_type=\"ReActAgent\",\n", 124 | " agent_kwargs={\n", 125 | " \"llm_model\": {\n", 126 | " \"parallel_tool_calls\": False,\n", 127 | " \"num_retries\": 3,\n", 128 | " \"temperature\": 1.0,\n", 129 | " \"name\": MODEL_NAME,\n", 130 | " },\n", 131 | " \"hide_old_env_states\": True,\n", 132 | " },\n", 133 | ")\n", 134 | "\n", 135 | "agent = agent.construct_agent()\n", 136 | "rollout = RolloutManager(agent=agent)\n", 137 | "\n", 138 | "# You can see the notebook updating live in the tmp_results_dir folder\n", 139 | "result = await rollout.sample_trajectories(\n", 140 | " environments=[environment], max_steps=MAX_STEPS\n", 141 | ")\n", 142 | "\n", 143 | "print(\"Trajectory completed! 
Final notebook available at: \\n\", environment.nb_path)\n", 144 | "print(f\"Final agent answer:\\n{environment.state.answer}\")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# INSPECT THE RESULT\n", 154 | "trajectory = result[0]\n", 155 | "# You can inspect each step in the trajectory and see what the agent's reasoning was,\n", 156 | "# what tool it called, and what the observation was\n", 157 | "for c, step in enumerate(trajectory.steps):\n", 158 | " print(f\"Timestep {c}\")\n", 159 | " print(f\"Done: {step.done}\")\n", 160 | " print(\"Agent Reasoning:\")\n", 161 | " for message in step.agent_state.messages:\n", 162 | " if message.content:\n", 163 | " print(f\"Message: {message.content[:200]} [Truncated]\")\n", 164 | " # print(f\"Observation: {step.observation[:200]} [Truncated]\")\n", 165 | " print(f\"Action: {step.action.value}\")\n", 166 | " print(\"---\")" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# VANILLA ROLLOUT - this is a simple version of the what the rollout Manager does\n", 176 | "dataset_folder = Path(\"dataset\")\n", 177 | "query = \"Analyze the dataset and give me an in depth analysis using pretty plots. I am particularly interested in crows.\"\n", 178 | "environment = setup_data_analysis_env(query, dataset_folder)\n", 179 | "\n", 180 | "obs, tools = await environment.reset()\n", 181 | "agent_state = await agent.init_state(tools)\n", 182 | "trajectory = Trajectory()\n", 183 | "max_steps = 10\n", 184 | "for timestep in range(max_steps):\n", 185 | " action, next_agent_state, value = await agent.get_asv(agent_state, obs)\n", 186 | " next_obs, reward, done, trunc = await environment.step(action.value)\n", 187 | " # Create the transition object\n", 188 | " transition = Transition(\n", 189 | " timestep=timestep,\n", 190 | " agent_state=agent_state,\n", 191 | " next_agent_state=next_agent_state,\n", 192 | " observation=obs,\n", 193 | " next_observation=next_obs,\n", 194 | " action=action,\n", 195 | " reward=reward,\n", 196 | " done=done,\n", 197 | " truncated=trunc,\n", 198 | " value=value,\n", 199 | " )\n", 200 | " # Update steps by creating a new list with the additional transition\n", 201 | " trajectory.steps = [*trajectory.steps, transition]\n", 202 | " if done or trunc:\n", 203 | " break\n", 204 | "\n", 205 | " agent_state = next_agent_state\n", 206 | " obs = next_obs" 207 | ] 208 | } 209 | ], 210 | "metadata": { 211 | "kernelspec": { 212 | "display_name": ".venv", 213 | "language": "python", 214 | "name": "python3" 215 | }, 216 | "language_info": { 217 | "codemirror_mode": { 218 | "name": "ipython", 219 | "version": 3 220 | }, 221 | "file_extension": ".py", 222 | "mimetype": "text/x-python", 223 | "name": "python", 224 | "nbconvert_exporter": "python", 225 | "pygments_lexer": "ipython3" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 2 230 | } 231 | -------------------------------------------------------------------------------- /src/fhda/models.py: -------------------------------------------------------------------------------- 1 | """Module for handling yaml config/CLI args and translating them into pydantic configs.""" 2 | 3 | import contextlib 4 | import inspect 5 | import logging 6 | import os 7 | import shutil 8 | import sys 9 | import textwrap 10 | 11 | from argparse import ArgumentParser 12 | from collections.abc import Iterable 13 | from 
pathlib import Path 14 | from typing import Any, TypeVar 15 | 16 | import yaml 17 | from pydantic import BaseModel, ConfigDict 18 | from pydantic_core import PydanticUndefined 19 | from ldp.utils import configure_stdout_logs 20 | from llmclient import configure_llm_logs 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def configure_logs( 27 | log_file: str | os.PathLike | None = None, 28 | stdout_level: int | str | tuple[str, int | str] | None = logging.INFO, 29 | fmt: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s", 30 | ) -> None: 31 | """Configure logs. 32 | 33 | Args: 34 | log_file: Optional log file to add to all loggers. 35 | stdout_level: If int (default) or str, it's a log level for stdout. If two-tuple 36 | of str and int, it's a logger name and log level for that logger. Otherwise, 37 | if None, don't configure stdout logs. 38 | fmt: Logging format string. 39 | """ 40 | configure_llm_logs() 41 | 42 | # Set some good default log levels to avoid too much verbosity 43 | logging.getLogger("dask").setLevel(logging.WARNING) 44 | logging.getLogger("vcr.cassette").setLevel(logging.WARNING) 45 | 46 | if stdout_level is not None: 47 | if isinstance(stdout_level, tuple): 48 | configure_stdout_logs(name=stdout_level[0], level=stdout_level[1], fmt=fmt) 49 | else: 50 | configure_stdout_logs(level=stdout_level, fmt=fmt) 51 | 52 | if log_file is not None: 53 | # Configure all loggers to write to a log file 54 | file_handler = logging.FileHandler(log_file) 55 | file_handler.setLevel(logging.DEBUG) 56 | file_handler.setFormatter(logging.Formatter(fmt)) 57 | logger.info(f"Logging to {log_file}.") 58 | 59 | # apply retroactively to root logger and all existing loggers 60 | for logger_name in ("root", *logging.root.manager.loggerDict.keys()): 61 | logging.getLogger(logger_name).addHandler(file_handler) 62 | 63 | 64 | class ConfigModel(BaseModel): 65 | model_config = ConfigDict( 66 | extra="forbid", arbitrary_types_allowed=True, populate_by_name=True 67 | ) 68 | 69 | 70 | TConfig = TypeVar("TConfig", bound=BaseModel) 71 | 72 | 73 | def load_arg_dict(argv: list[str]) -> dict[str, Any]: 74 | """Loads arguments from command line and yaml files into a dictionary. 75 | 76 | For example, if the command line args are `--foo.bar 1 --foo.baz 2`, the resulting 77 | dictionary is {'foo': {'bar': 1, 'baz': 2}}. YAML files are directly parsed as dictionaries. 78 | """ 79 | parser = ArgumentParser(add_help=False) 80 | parser.add_argument("config_files", nargs="*", type=str) 81 | 82 | if not any(a.endswith(".yaml") for a in argv): 83 | # add a dummy arg to avoid argparse error 84 | argv = ["INVALID.yaml", *argv] 85 | args, remaining_args = parser.parse_known_args(argv) 86 | 87 | config_acc: dict[str, Any] = {} 88 | for cfg in args.config_files: 89 | if cfg == "INVALID.yaml": 90 | continue 91 | with open(cfg) as fcfg: 92 | config = yaml.load(fcfg, Loader=yaml.Loader) # noqa: S506 93 | _recursive_update(config_acc, config) 94 | 95 | _parse_cli_args(remaining_args, config_acc) 96 | 97 | return config_acc 98 | 99 | 100 | def load_config( 101 | config_cls: type[TConfig], 102 | verbose: bool = True, 103 | argv: list[str] | None = None, 104 | args_to_exclude: Iterable[str] | None = None, 105 | ) -> TConfig: 106 | """Utility function for handling config and command line args supplied via command line. 107 | 108 | Args: 109 | config_cls: Config class object 110 | verbose: Boolean indicating extent of logging info 111 | argv: List of command line args. 
If not specified (default), will use sys.argv. 112 | args_to_exclude: Arguments to skip when constructing the config object. 113 | 114 | Returns: 115 | Config object synthesizing CLI args and supplied yaml. 116 | """ 117 | if argv is None: 118 | argv = sys.argv[1:] 119 | 120 | if "-h" in argv or "--help" in argv: 121 | print(get_config_help_string(config_cls)) 122 | sys.exit(0) 123 | 124 | config_acc = load_arg_dict(argv) 125 | if args_to_exclude: 126 | for arg in args_to_exclude: 127 | config_acc.pop(arg, None) 128 | 129 | config = config_cls(**config_acc) 130 | 131 | if verbose: 132 | logger.info("\n%s", yaml.dump({config_cls.__name__: config.model_dump()})) 133 | 134 | return config 135 | 136 | 137 | def _parse_cli_args(remaining_args: list[str], config_acc: dict): 138 | while remaining_args: 139 | arg = remaining_args.pop(0) 140 | if not arg.startswith("--"): 141 | raise ValueError(f"Invalid argument {arg}") 142 | 143 | arg = arg[2:] 144 | try: 145 | value = remaining_args[0] 146 | if value.startswith("--"): 147 | # moved on to next arg 148 | value = "True" 149 | else: 150 | # consumed value - remove from args 151 | remaining_args.pop(0) 152 | except IndexError: 153 | # end of args, assume it was a flag 154 | value = "True" 155 | value = _resolve_value(value) 156 | 157 | arg_hierarchy = arg.split(".") 158 | update_dict: dict[str, Any] = {} 159 | current_dict = update_dict 160 | for arg in arg_hierarchy[:-1]: 161 | current_dict[arg] = {} 162 | current_dict = current_dict[arg] 163 | current_dict[arg_hierarchy[-1]] = value 164 | _recursive_update(config_acc, update_dict) 165 | 166 | 167 | def dump_config(config: BaseModel, path: os.PathLike | str) -> None: 168 | """Dump the input Pydantic config to a YAML file.""" 169 | path = Path(path) 170 | if path.is_dir(): 171 | path /= "config.yaml" 172 | with path.open("w") as f: 173 | yaml.dump(config.model_dump(), f) 174 | 175 | 176 | def get_config_help_string(config_cls: type[BaseModel], indent: int = 0) -> str: 177 | s = ( 178 | textwrap.indent(f"{config_cls.__name__}:", " " * indent) + "\n" 179 | if indent == 0 180 | else "" 181 | ) 182 | 183 | indent += 1 184 | for key, value in config_cls.model_fields.items(): 185 | annot: Any = value.annotation 186 | # Removing the description printing for now, since it's just too verbose. 187 | # TODO: see if we can format it in a more readable way. 
188 | # desc = f" # {value.description}" if value.description else "" 189 | desc = "" 190 | 191 | if inspect.isclass(annot): 192 | if issubclass(annot, BaseModel): 193 | s += textwrap.indent(f"{key}:{desc}", " " * indent) + "\n" 194 | s += get_config_help_string(annot, indent) 195 | continue 196 | 197 | annot = annot.__name__ 198 | 199 | if value.is_required(): 200 | s += textwrap.indent(f"{key}: {annot}{desc}", " " * indent) + "\n" 201 | else: 202 | default = ( 203 | value.default_factory 204 | if value.default is PydanticUndefined 205 | else value.default 206 | ) 207 | s += ( 208 | textwrap.indent(f"{key}: {annot} = {default!r}{desc}", " " * indent) 209 | + "\n" 210 | ) 211 | 212 | return s 213 | 214 | 215 | DEFAULT_OUTPUT_LOG_NAME = "output.log" 216 | 217 | 218 | def set_up_output_dir( 219 | directory_path: str | os.PathLike, 220 | config: BaseModel | None = None, 221 | log_name: str | None = DEFAULT_OUTPUT_LOG_NAME, 222 | is_main_process: bool = True, 223 | remove_existing: bool = False, 224 | ) -> Path: 225 | if remove_existing and is_main_process: 226 | shutil.rmtree(directory_path, ignore_errors=True) 227 | directory_path = Path(directory_path) 228 | directory_path.mkdir(parents=True, exist_ok=True) 229 | 230 | if log_name: 231 | configure_logs(log_file=directory_path / log_name) 232 | 233 | if config is not None and is_main_process: 234 | dump_config(config, directory_path) 235 | 236 | return directory_path 237 | 238 | 239 | def _resolve_value(value: str) -> Any: 240 | if value.lower() == "true": 241 | return True 242 | if value.lower() == "false": 243 | return False 244 | 245 | with contextlib.suppress(ValueError): 246 | return int(value) 247 | with contextlib.suppress(ValueError): 248 | return float(value) 249 | 250 | if value == "None": 251 | return None 252 | 253 | return value 254 | 255 | 256 | def configure_yaml_multiline() -> None: 257 | # copied from SWE-agent 258 | def multiline_representer(dumper, data): 259 | """Configures yaml for dumping multiline strings. 260 | 261 | Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data. 262 | """ 263 | if data.count("\n") > 0: # check for multiline string 264 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") 265 | return dumper.represent_scalar("tag:yaml.org,2002:str", data) 266 | 267 | yaml.add_representer(str, multiline_representer) 268 | 269 | 270 | def _recursive_update(d: dict, u: dict) -> dict: 271 | for k, v in u.items(): 272 | if isinstance(v, dict): 273 | d[k] = _recursive_update(d.get(k, {}), v) 274 | else: 275 | d[k] = v 276 | return d 277 | 278 | 279 | CONFIGURATION_ENABLE = {"1", "true", "yes", "on"} 280 | CONFIGURATION_DISABLE = {"0", "false", "no", "off"} 281 | -------------------------------------------------------------------------------- /tutorial/multi_agent_orchestration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "In this advanced tutorial, we show how to orchestrate a multi-agent, multi-step workflow. The workflow does the following:\n", 8 | "1. Runs RNAseq DEA across 10 parallel Finch runs\n", 9 | "2. Run a single meta-analysis (consensus) Finch run using all the outputs of step 1.\n", 10 | "3. Run 10 parallel Crow runs on the top 10 differentially expressed genes from step 2.\n", 11 | "4. Use Finch to create a volcano plot incorporating results from step 2 and step 3." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from fhda.tortoise import Tortoise, Step\n", 21 | "from futurehouse_client.models import RuntimeConfig\n", 22 | "from futurehouse_client import JobNames\n", 23 | "import pandas as pd\n", 24 | "import json" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# Define our parameters\n", 34 | "TREATMENT = \"dexamethasone\"\n", 35 | "MECHANISM = \"airway smooth muscle cells\"\n", 36 | "CONTEXT = \"asthma\"\n", 37 | "N_TOP_GENES = 10\n", 38 | "PARALLEL_DEA = 5\n", 39 | "FH_API_KEY = \"\" # Add your API key here\n", 40 | "\n", 41 | "# Define the prompts\n", 42 | "DEA_PROMPT = \"\"\"\n", 43 | "Determine the effect of {treatment} on {mechanism} in {context}. \n", 44 | "\n", 45 | "Perform differential expression analysis and pathway analysis on relevant comparison groups. Map all gene IDs to gene symbols using annotation package such as 'org.Hs.eg.db'.\n", 46 | "\n", 47 | "Generate volcano plots and heatmap of differentially expressed genes, and dot plots for enriched pathways, use gene symbols for labels where relevant.\n", 48 | "\n", 49 | "Output a single csv file named \"dea_results.csv\" with the results for all tested genes of the most relevant contrast, report both gene ID and gene symbol.\n", 50 | "\n", 51 | "If there is an error, keep trying, do not give up until you reach the end of the analysis. When mapping gene ID to gene symbol, consider all possible forms of gene IDs, keep trying until the gene symbols are obtained.\n", 52 | "\"\"\"\n", 53 | "\n", 54 | "CONSENSUS_PROMPT = f\"\"\"\n", 55 | "Combine these differential expression analysis results by calculating the mode of log2FC and adjusted p values. Output the results in a file named 'consensus_results.csv', include the columns gene_symbol, log2FC and adjusted P values. In a separate file named 'top{N_TOP_GENES}_genes.csv', output the gene symbols of the consensus most significant genes with the column name \"gene_symbol\". \n", 56 | "\n", 57 | "Create a stacked bar plot showing gene regulation consistency across all analyses. Plot regulation direction (up vs down) on x-axis and percentage of genes in each category on y-axis. Color-code by significance category: all analyses, >50% of analyses and <50% of analyses. Include percentages within each segment and a clear legend. Exclude genes that are non-significant across all analyses.\n", 58 | "\"\"\"\n", 59 | "\n", 60 | "PQA_PROMPT = \"\"\"\n", 61 | "What are the possible mechanisms for {gene} in the effect of {treatment} on {mechanism} in {context}?\n", 62 | "From 1 to 5, with 1 being no evidence of association at all and 5 being strong association with supporting evidence, how strong is the evidence supporting this mechanism?\n", 63 | "Give a concise summary for the evidence in up to 10 words, and a short summary of mechanisms in up to 20 words. Do not include references or links.\n", 64 | "Please share this information in json format in the form of: `\"gene_symbol\": , \"association_evidence_score\":[1...5], \"evidence_summary\": , \"mechanism_summary\": `.\n", 65 | "Share nothing else but the JSON output.\n", 66 | "\"\"\"\n", 67 | "\n", 68 | "VOLCANO_PROMPT = f\"\"\"\n", 69 | "Make an interactive volcano plot. 
Colour-code by significance categories: top up-regulated genes, up-regulated genes, top down-regulated genes, down-regulated genes, and non-significant genes. Genes considered as top have extra annotation available in 'pqa_results.csv'.\n", 70 | "\n", 71 | "Include hover information according to the categories, for the top genes, on hover, show gene symbol, log2FC, adjusted p value, mechanism, evidence and evidence score. For up and down regulated genes that are not in top {N_TOP_GENES}, show gene symbol, log2FC and adjusted p value. For non-significant genes, do not include hover information.\n", 72 | "\n", 73 | "For the annotations, remove all text in the brackets in the summary columns, and remove the fullstop at the end. For annotations with 6 words or more in a line, use text-wrap. Don't include text on the plot itself. Include a legend explaining the color-codes.\n", 74 | "\n", 75 | "PLEASE USE TEXT WRAP FOR THE HOVER INFORMATION!\n", 76 | "\"\"\"\n", 77 | "\n", 78 | "# Initialize Tortoise\n", 79 | "tortoise = Tortoise(api_key=FH_API_KEY)\n", 80 | "\n", 81 | "OUTPUT_DIR = \"output\"" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# Step 1: Differential Expression Analysis (DEA)\n", 91 | "dea_step = Step(\n", 92 | " name=JobNames.FINCH,\n", 93 | " prompt_template=DEA_PROMPT,\n", 94 | " prompt_args={\"treatment\": TREATMENT, \"mechanism\": MECHANISM, \"context\": CONTEXT},\n", 95 | " input_files={\n", 96 | " \"datasets/GSE52778_All_Sample_FPKM_Matrix.txt.gz\": \"GSE52778_series_matrix.txt.gz\"\n", 97 | " },\n", 98 | " output_files={\"dea_results.csv\": \"dea_results/dea_results.csv\"},\n", 99 | " n_replicate_tasks=PARALLEL_DEA,\n", 100 | " runtime_config=RuntimeConfig(\n", 101 | " max_steps=30,\n", 102 | " environment_config={\"language\": \"R\", \"default_cot_prompt\": True},\n", 103 | " timeout=15 * 60,\n", 104 | " ),\n", 105 | ")\n", 106 | "tortoise.add_step(dea_step)\n", 107 | "\n", 108 | "# Step 2: Consensus Analysis\n", 109 | "consensus_step = Step(\n", 110 | " name=JobNames.FINCH,\n", 111 | " prompt_template=CONSENSUS_PROMPT,\n", 112 | " input_files={f\"{OUTPUT_DIR}/{dea_step.step_id}/dea_results\": \"dea_results/\"},\n", 113 | " output_files={\n", 114 | " \"consensus_results.csv\": \"consensus_results.csv\",\n", 115 | " f\"top{N_TOP_GENES}_genes.csv\": f\"top{N_TOP_GENES}_genes.csv\",\n", 116 | " },\n", 117 | " runtime_config=RuntimeConfig(\n", 118 | " max_steps=30,\n", 119 | " environment_config={\"language\": \"R\", \"default_cot_prompt\": True},\n", 120 | " timeout=15 * 60,\n", 121 | " ),\n", 122 | ")\n", 123 | "tortoise.add_step(consensus_step)\n", 124 | "\n", 125 | "\n", 126 | "# Step 3: Literature Search with PaperQA\n", 127 | "def pqa_post_process(results, output_dir):\n", 128 | " \"\"\"Process the results from multiple PQA tasks\"\"\"\n", 129 | "\n", 130 | " answer_list = []\n", 131 | " for task_response in results.get(\"task_responses\", []):\n", 132 | " try:\n", 133 | " answer = json.loads(task_response.answer)\n", 134 | " if isinstance(answer, list):\n", 135 | " answer = answer[0]\n", 136 | " answer_list.append(answer)\n", 137 | " except Exception as e:\n", 138 | " print(f\"Error parsing answer for task {task_response.task_id}: {e}\")\n", 139 | "\n", 140 | " # Create DataFrame and save\n", 141 | " pqa_df = pd.DataFrame(answer_list)\n", 142 | " pqa_df.to_csv(f\"{output_dir}/pqa_results.csv\", index=False)\n", 143 | " return pqa_df\n", 144 | "\n", 145 | "\n", 146 | "# Define 
a function to create multiple PQA prompts for genes\n", 147 | "def pqa_prompt_generator():\n", 148 | " \"\"\"Generate PQA prompts for each top gene\"\"\"\n", 149 | " top_genes_df = pd.read_csv(\n", 150 | " f\"{OUTPUT_DIR}/{consensus_step.step_id}/top{N_TOP_GENES}_genes.csv\"\n", 151 | " )\n", 152 | " gene_symbols = top_genes_df[\"gene_symbol\"].tolist()\n", 153 | " prompt_pairs = []\n", 154 | " for gene in gene_symbols:\n", 155 | " prompt_pairs.append(\n", 156 | " (\n", 157 | " PQA_PROMPT,\n", 158 | " {\n", 159 | " \"gene\": gene,\n", 160 | " \"treatment\": TREATMENT,\n", 161 | " \"mechanism\": MECHANISM,\n", 162 | " \"context\": CONTEXT,\n", 163 | " },\n", 164 | " )\n", 165 | " )\n", 166 | " return prompt_pairs\n", 167 | "\n", 168 | "\n", 169 | "# Read top genes and create PQA steps\n", 170 | "pqa_step = Step(\n", 171 | " name=JobNames.CROW,\n", 172 | " prompt_template=PQA_PROMPT,\n", 173 | " prompt_generator=pqa_prompt_generator,\n", 174 | " n_replicate_tasks=N_TOP_GENES, # Will process all top genes in parallel\n", 175 | " post_process=pqa_post_process,\n", 176 | ")\n", 177 | "tortoise.add_step(pqa_step)\n", 178 | "\n", 179 | "# Step 4: Visualization with Volcano Plot\n", 180 | "volcano_step = Step(\n", 181 | " name=JobNames.FINCH,\n", 182 | " prompt_template=VOLCANO_PROMPT,\n", 183 | " input_files={\n", 184 | " f\"{OUTPUT_DIR}/{consensus_step.step_id}/consensus_results.csv\": \"consensus_results.csv\",\n", 185 | " f\"{OUTPUT_DIR}/{pqa_step.step_id}/pqa_results.csv\": \"pqa_results.csv\",\n", 186 | " },\n", 187 | " runtime_config=RuntimeConfig(\n", 188 | " max_steps=30,\n", 189 | " environment_config={\"language\": \"PYTHON\", \"default_cot_prompt\": True},\n", 190 | " timeout=15 * 60,\n", 191 | " ),\n", 192 | ")\n", 193 | "tortoise.add_step(volcano_step)\n", 194 | "\n", 195 | "# Run the pipeline\n", 196 | "results = await tortoise.run_pipeline(OUTPUT_DIR)\n", 197 | "print(\"Pipeline execution completed\")\n", 198 | "print(\n", 199 | " f\"View the final volcano plot at: https://platform.futurehouse.org/trajectories/{tortoise.results[volcano_step.step_id]['task_ids'][0]}\"\n", 200 | ")" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": ".venv", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 2 224 | } 225 | -------------------------------------------------------------------------------- /src/fhda/templates/lab/base.html.j2: -------------------------------------------------------------------------------- 1 | {%- extends 'display_priority.j2' -%} 2 | {% from 'celltags.j2' import celltags %} 3 | {% from 'cell_id_anchor.j2' import cell_id_anchor %} 4 | 5 | {% block codecell %} 6 | {%- if not cell.outputs -%} 7 | {%- set no_output_class="jp-mod-noOutputs" -%} 8 | {%- endif -%} 9 | {%- if not resources.global_content_filter.include_input -%} 10 | {%- set no_input_class="jp-mod-noInput" -%} 11 | {%- endif -%} 12 | 15 | {%- endblock codecell %} 16 | 17 | {% block input_group -%} 18 | 25 | {% endblock input_group %} 26 | 27 | {% block input %} 28 | 33 | {%- endblock input %} 34 | 35 | {% block output_group %} 36 | 41 | {% endblock output_group %} 42 | 43 | {% block outputs %} 44 | 47 | {% endblock outputs %} 48 | 49 | {% 
block in_prompt -%} 50 | 57 | {%- endblock in_prompt %} 58 | 59 | {% block empty_in_prompt -%} 60 | 62 | {%- endblock empty_in_prompt %} 63 | 64 | {# 65 | output_prompt doesn't do anything in HTML, 66 | because there is a prompt div in each output area (see output block) 67 | #} 68 | {% block output_prompt %} 69 | {% endblock output_prompt %} 70 | 71 | {% block output_area_prompt %} 72 | 81 | {% endblock output_area_prompt %} 82 | 83 | {% block output %} 84 | {%- if output.output_type == 'execute_result' -%} 85 |