├── .github └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── pyproject.toml ├── src ├── __init__.py ├── expts │ ├── __init__.py │ ├── client.py │ ├── common.py │ ├── dsbench.py │ ├── eval.py │ ├── server.py │ └── training.py ├── fhda │ ├── Dockerfile.pinned │ ├── __init__.py │ ├── config.py │ ├── data_analysis_env.py │ ├── dataset.py │ ├── dev.yaml │ ├── kernel_requirements.txt │ ├── models.py │ ├── notebook_env.py │ ├── prompts.py │ ├── storage.py │ ├── templates │ │ ├── base │ │ │ ├── cell_id_anchor.j2 │ │ │ ├── celltags.j2 │ │ │ ├── display_priority.j2 │ │ │ ├── jupyter_widgets.html.j2 │ │ │ ├── mathjax.html.j2 │ │ │ └── null.j2 │ │ └── lab │ │ │ ├── base.html.j2 │ │ │ ├── conf.json │ │ │ ├── index.html.j2 │ │ │ ├── mermaidjs.html.j2 │ │ │ └── static │ │ │ ├── index.css │ │ │ ├── theme-dark.css │ │ │ └── theme-light.css │ ├── tortoise.py │ └── utils.py └── scripts │ ├── __init__.py │ ├── bixbench_evaluation │ ├── run.sh │ ├── runner.yaml │ └── server.yaml │ ├── config.py │ ├── configurable.py │ └── expt_logging.py ├── tests └── test_nb_env.py ├── tutorial ├── consensus.ipynb ├── datasets │ ├── GSE52778_All_Sample_FPKM_Matrix.txt.gz │ └── brain_size_data.csv ├── example.ipynb ├── multi_agent_orchestration.ipynb ├── platform_api.ipynb └── tmp_results_dir │ ├── bf222a115d3970be6e12430b1cd57eb67d2f36b1950ed5765a247a1f071e7569-1741036231.876748 │ ├── brain_size_data.csv │ ├── claude-3-7.html │ ├── notebook.ipynb │ └── notebook.md │ └── bf222a115d3970be6e12430b1cd57eb67d2f36b1950ed5765a247a1f071e7569-1741036365.66959 │ ├── brain_size_data.csv │ ├── claude-3-7.html │ ├── notebook.ipynb │ └── notebook.md └── uv.lock /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.12"] 14 | 15 | steps: 16 | - name: Check out Git repository 17 | uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install setuptools>=66 wheel>=0.36 build 26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 27 | if [ -f pyproject.toml ]; then pip install -e .[dev]; fi 28 | 29 | - name: Run Lint 30 | run: | 31 | # Check for linting issues 32 | ruff check . 33 | # Check for formatting issues (will fail if code needs formatting) 34 | ruff format --check . 
35 | 36 | test: 37 | runs-on: ubuntu-latest 38 | strategy: 39 | matrix: 40 | python-version: ["3.12"] 41 | 42 | steps: 43 | - name: Check out Git repository 44 | uses: actions/checkout@v4 45 | 46 | - name: Set up Python ${{ matrix.python-version }} 47 | uses: actions/setup-python@v5 48 | with: 49 | python-version: ${{ matrix.python-version }} 50 | - name: Install dependencies 51 | run: | 52 | python -m pip install --upgrade pip 53 | pip install setuptools>=66 wheel>=0.36 build 54 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 55 | if [ -f pyproject.toml ]; then pip install -e .[dev]; fi 56 | 57 | - name: Run Test 58 | run: | 59 | python -m pytest 60 | env: 61 | GITHUB_ACTIONS: true 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # User-specific 3 | local 4 | 5 | # IntelliJ 6 | out/ 7 | 8 | # Local History for Visual Studio Code 9 | .history/ 10 | 11 | # Built Visual Studio Code Extensions 12 | *.vsix 13 | # General 14 | .DS_Store 15 | .AppleDouble 16 | .LSOverride 17 | 18 | # Files that might appear in the root of a volume 19 | .DocumentRevisions-V100 20 | .fseventsd 21 | .Spotlight-V100 22 | .TemporaryItems 23 | .Trashes 24 | .VolumeIcon.icns 25 | .com.apple.timemachine.donotpresent 26 | 27 | # Directories potentially created on remote AFP share 28 | .AppleDB 29 | .AppleDesktop 30 | Network Trash Folder 31 | Temporary Items 32 | .apdisk 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | share/python-wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | src/fhda/storage/ 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | *.py,cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | cover/ 85 | 86 | # Translations 87 | *.mo 88 | *.pot 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | *.ipynb 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | tutorial/tmp_results_dir 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | # PyCharm 143 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 144 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 145 | # and can be added to the global gitignore or merged into this file. For a more nuclear 146 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 147 | .idea/ 148 | # Local .terraform directories 149 | **/.terraform/* 150 | 151 | # .tfstate files 152 | *.tfstate 153 | *.tfstate.* 154 | 155 | # Crash log files 156 | crash.log 157 | crash.*.log 158 | 159 | # Exclude all .tfvars files, which are likely to contain sensitive data, such as 160 | # password, private keys, and other secrets. These should not be part of version 161 | # control as they are data points which are potentially sensitive and subject 162 | # to change depending on the environment. 163 | *.tfvars 164 | *.tfvars.json 165 | 166 | # Ignore override files as they are usually used to override resources locally and so 167 | # are not checked in 168 | override.tf 169 | override.tf.json 170 | *_override.tf 171 | *_override.tf.json 172 | 173 | # Include override files you do wish to add to version control using negated pattern 174 | # !example_override.tf 175 | 176 | # Include tfplan files to ignore the plan output of command: terraform plan -out=tfplan 177 | # example: *tfplan* 178 | 179 | # Ignore CLI configuration files 180 | .terraformrc 181 | terraform.rc 182 | 183 | # SLURM artifacts 184 | slurm_outputs/ 185 | 186 | # SWE-agent auto-creates these files 187 | keys.cfg 188 | 189 | # Version files made by setuptools_scm 190 | **/version.py 191 | 192 | # WandB cache files (e.g. 
generated by pytest) 193 | wandb/ 194 | 195 | # VSCode repo settings 196 | .vscode/ 197 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | default_language_version: 3 | python: python3 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | - id: check-added-large-files 9 | - id: check-byte-order-marker 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-shebang-scripts-are-executable 13 | - id: check-symlinks 14 | - id: check-toml 15 | - id: check-yaml 16 | - id: debug-statements 17 | - id: detect-private-key 18 | - id: end-of-file-fixer 19 | - id: mixed-line-ending 20 | - id: trailing-whitespace 21 | - repo: https://github.com/astral-sh/ruff-pre-commit 22 | rev: v0.9.1 23 | hooks: 24 | - id: ruff 25 | args: [--fix, --exit-non-zero-on-fix] 26 | - id: ruff-format 27 | - repo: https://github.com/rbubley/mirrors-prettier 28 | rev: v3.4.2 29 | hooks: 30 | - id: prettier 31 | - repo: https://github.com/jumanjihouse/pre-commit-hooks 32 | rev: 3.0.0 33 | hooks: 34 | - id: check-mailmap 35 | - repo: https://github.com/codespell-project/codespell 36 | rev: v2.3.0 37 | hooks: 38 | - id: codespell 39 | additional_dependencies: [".[toml]"] 40 | exclude_types: [jupyter] 41 | - repo: https://github.com/pappasam/toml-sort 42 | rev: v0.24.2 43 | hooks: 44 | - id: toml-sort-fix 45 | exclude: poetry.lock 46 | - repo: https://github.com/srstevenson/nb-clean 47 | rev: 4.0.1 48 | hooks: 49 | - id: nb-clean 50 | args: [--preserve-cell-outputs, --remove-empty-cells] 51 | - repo: https://github.com/henryiii/validate-pyproject-schema-store 52 | rev: 2025.01.10 53 | hooks: 54 | - id: validate-pyproject 55 | - repo: https://github.com/pre-commit/mirrors-mypy 56 | rev: v1.14.1 57 | hooks: 58 | - id: mypy 59 | additional_dependencies: 60 | - aiohttp 61 | - boto3-stubs[s3] 62 | - docstring_parser 63 | - fh-llm-client[deepseek]>=0.0.11 # Match aviary_internal pyproject.toml 64 | - fhaviary[server] >= 0.18.0 # Match aviary_internal pyproject.toml 65 | - gitpython 66 | - google-auth>=2.31 # Match aviary_internal pyproject.toml 67 | - google-cloud 68 | - google-cloud-run 69 | - google-cloud-tasks 70 | - google-cloud-secret-manager 71 | - google-cloud-storage 72 | - httpx<0.28 # Match aviary_internal pyproject.toml 73 | - jupyter-client 74 | - ldp>=0.22.0 # Match aviary_internal pyproject.toml 75 | - litellm>=1.40.9 # Match aviary_internal pyproject.toml 76 | - nbformat 77 | - numpy<2 # Match aviary_internal pyproject.toml 78 | - omegaconf 79 | - openai>=1 # Match aviary_internal pyproject.toml 80 | - pandas-stubs 81 | - pydantic~=2.0 # Match aviary_internal pyproject.toml 82 | - rich 83 | - SQLAlchemy[aiosqlite]~=2.0 # Match fhaviary pyproject.toml and dev-requirements.txt 84 | - tenacity 85 | - tiktoken 86 | - torch==2.5.1 # Match aviary_internal/nn/requirements.txt 87 | - types-aiofiles 88 | - types-Pillow 89 | - types-PyYAML 90 | - types-requests 91 | - types-tqdm 92 | - typing-extensions 93 | - wandb 94 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 FutureHouse 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Analysis Crow: A Jupyter Notebook Agent 2 | 3 | Data Analysis Crow is an AI agent framework designed to perform complex scientific data analysis tasks by iteratively working through Jupyter notebooks. This agent takes in datasets and prompts, then systematically explores, analyzes, and interprets the data to provide comprehensive answers and insights. 4 | 5 | The agent was used to produce the trajectories for the [BixBench benchmark](https://github.com/Future-House/bixbench). 6 | 7 | ## Key Features 8 | 9 | - Accepts datasets and natural language prompts 10 | - Iteratively builds Jupyter notebooks to answer research questions 11 | - Works with Python, R, and Bash code execution 12 | - Specializes in bioinformatics analysis but adaptable to various domains 13 | - Comes with a Docker image including most common bioinformatics packages 14 | 15 | ## Links 16 | 17 | - [Installation](#installation) 18 | - [Using the Agent](#using-the-agent) 19 | - [Advanced Usage](#advanced-usage) 20 | - [BixBench Benchmark](#bixbench-benchmark) 21 | 22 | ## Installation 23 | 24 | ```bash 25 | # Clone the repository 26 | git clone https://github.com/Future-House/data-analysis-crow.git 27 | cd data-analysis-crow 28 | 29 | # Install dependencies 30 | pip install -e . 31 | 32 | # OPTIONAL:pull the docker image with bioinformatics packages 33 | docker pull futurehouse/bixbench:aviary-notebook-env 34 | ``` 35 | 36 | ## Prerequisites 37 | 38 | ### API Keys 39 | 40 | We support all LLMs that are supported by [litellm](https://github.com/BerriAI/litellm). Create a `.env` file with the API keys for the LLMs you want to use. For example: 41 | 42 | ``` 43 | OPENAI_API_KEY = "your-openai-api-key" 44 | ANTHROPIC_API_KEY = "your-anthropic-api-key" 45 | ``` 46 | 47 | ## Using the Agent 48 | 49 | The agent works by taking a dataset and a prompt, then iteratively building a Jupyter notebook to answer the question. Visit the [tutorial](https://github.com/Future-House/data-analysis-crow/blob/main/tutorial/example.ipynb) for a simple step-by-step guide on how to use the agent. 
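If you prefer to drive the agent from a script rather than the tutorial notebook, the environment and rollout classes can be used directly. The sketch below is a minimal, illustrative adaptation of the `AdHocExpt` flow in `src/expts/eval.py`; the data directory, the example question, and the choice of `ldp`'s `SimpleAgent` are assumptions you should adapt to your own setup.

```python
# Minimal sketch of driving the agent programmatically (adapted from
# src/expts/eval.py::AdHocExpt). Paths, the question, and the agent choice
# are illustrative placeholders -- adjust them to your own setup.
import asyncio
from pathlib import Path

from aviary.core import EvalAnswerMode
from ldp.agent import SimpleAgent
from ldp.alg.rollout import RolloutManager

from fhda.data_analysis_env import DataAnalysisEnv
from fhda.utils import NBLanguage


async def main() -> None:
    work_dir = Path("tutorial/datasets")  # directory containing the input data
    env = DataAnalysisEnv(
        problem_id="adhoc-brain-size",
        problem="Is brain size correlated with body weight in brain_size_data.csv?",
        eval_mode=EvalAnswerMode.EXACT,  # no ground-truth answer, so grading is unused
        nb_path=work_dir / "analysis.ipynb",
        work_dir=work_dir,
        language=NBLanguage.PYTHON,
    )
    rollout = RolloutManager(agent=SimpleAgent())
    await rollout.sample_trajectories(environments=[env], max_steps=30)
    await env.close()  # shuts down the notebook kernel


if __name__ == "__main__":
    asyncio.run(main())
```

By default the environment copies `work_dir` into a temporary workspace, so your input files are left untouched.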
50 | 51 | ## Advanced Usage 52 | For advanced evaluations, you can configure `server.yaml` and `runner.yaml` in the `src/scripts/bixbench_evaluation` directory and then run the evaluation script: 53 | ```bash 54 | bash src/scripts/bixbench_evaluation/run.sh 55 | ``` 56 | 57 | This will: 58 | 1. Load the specified dataset 59 | 2. Process the prompt to understand the research question 60 | 3. Generate a Jupyter notebook with progressive analysis steps 61 | 4. Provide a final answer based on the analysis 62 | 63 | Results are saved in the output directory specified in your configuration file.
64 | 65 | Note that the dataset and environment configuration must be updated to match your setup. For an example, see [dataset.py](https://github.com/Future-House/data-analysis-crow/blob/main/src/fhda/dataset.py), which includes the capsule dataset configuration used for the BixBench benchmark. 66 | 67 | We also recommend visiting the BixBench repository, where we share a full evaluation harness for the agent. 68 | 69 | ## Hosted Agent 70 | Coming soon! 71 | 72 | ## BixBench Benchmark 73 | 74 | Data Analysis Crow was used to produce the trajectories for the [BixBench benchmark](https://github.com/Future-House/bixbench), which evaluates AI agents on real-world bioinformatics tasks.
75 | 76 | BixBench tests AI agents' ability to: 77 | 78 | - Explore biological datasets 79 | - Perform long, multi-step computational analyses 80 | - Interpret nuanced results in the context of a research question 81 | 82 | You can find the BixBench dataset on [Hugging Face](https://huggingface.co/datasets/futurehouse/BixBench), the paper [here](https://arxiv.org/abs/2503.00096), and the blog post [here](https://www.futurehouse.org/research-announcements/bixbench). 83 | 84 | ### Running BixBench Evaluations 85 | 86 | To use this agent for BixBench evaluations, we recommend visiting the [BixBench repository](https://github.com/Future-House/bixbench) for more details.
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools>=64"] 4 | 5 | [project] 6 | authors = [ 7 | {email = "hello@futurehouse.org", name = "FutureHouse technical staff"} 8 | ] 9 | dependencies = [ 10 | "aiodocker==0.24.0", 11 | "fhaviary[server]==0.19.0", 12 | "ldp==0.26.0", 13 | "pandas==2.2.3", 14 | "numpy==2.2.3", 15 | "matplotlib==3.10.0", 16 | "aiofiles==24.1.0", 17 | "google-auth==2.38.0", 18 | "google-cloud-storage==3.0.0", 19 | "google-cloud-secret-manager==2.23.0", 20 | "futurehouse-client==0.3.18", 21 | "jupyter==1.1.1", 22 | "nbconvert==7.16.6", 23 | "notebook==7.3.2", 24 | "nbformat==5.10.4" 25 | ] 26 | description = "Data analysis crow" 27 | name = "fhda" 28 | requires-python = ">=3.12" 29 | version = "1.0.0" 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | "black", 34 | "isort", 35 | "mypy", 36 | "pre-commit", 37 | "pytest", 38 | "pytest-asyncio", 39 | "pytest-cov", 40 | "ruff" 41 | ] 42 | 43 | [project.scripts] 44 | run_expt = 'scripts.configurable:_run_expt' 45 | 46 | [tool.setuptools] 47 | package-dir = {"" = "src"} 48 | 49 | [tool.setuptools.packages.find] 50 | where = ["src"] 51 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fhda 2 | from . import expts 3 | from .
import scripts 4 | 5 | __all__ = ["fhda", "expts", "scripts"] 6 | -------------------------------------------------------------------------------- /src/expts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Future-House/data-analysis-crow/095df6e47b1cfea4361927872f02a8c40204f864/src/expts/__init__.py -------------------------------------------------------------------------------- /src/expts/client.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing 3 | from enum import StrEnum, auto 4 | 5 | from aviary.core import ( 6 | TaskDatasetClient, 7 | TaskEnvironmentClient, 8 | ) 9 | from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin 10 | 11 | 12 | class TaskDatasetSubsetClient(TaskDatasetClient, ComputeTrajectoryMetricsMixin): 13 | """Convenience class to subset a dataset using a single server.""" 14 | 15 | def __init__(self, client: TaskDatasetClient, task_idcs: list[int]) -> None: 16 | super().__init__( 17 | server_url=client.server_url, request_timeout=client.request_timeout 18 | ) 19 | self.idcs = task_idcs 20 | 21 | def __len__(self) -> int: 22 | return len(self.idcs) 23 | 24 | def get_new_env_by_idx(self, idx: int) -> TaskEnvironmentClient: 25 | return super().get_new_env_by_idx(self.idcs[idx]) 26 | 27 | 28 | class TaskDatasetSplit(StrEnum): 29 | TRAIN = auto() 30 | EVAL = auto() 31 | TEST = auto() 32 | ALL = auto() 33 | 34 | def get_random_split( 35 | self, dataset_client: TaskDatasetClient, seed: int = 0 36 | ) -> TaskDatasetClient: 37 | if self == TaskDatasetSplit.ALL: 38 | return dataset_client 39 | 40 | # Slightly hacky way to make a split for now 41 | # Split the dataset into a 80/10/10 split using a deterministic seed 42 | n_total = len(dataset_client) 43 | all_idcs = random.Random(seed).sample(range(n_total), n_total) 44 | 45 | match self: 46 | case TaskDatasetSplit.TRAIN: 47 | idcs = all_idcs[: int(0.8 * n_total)] 48 | case TaskDatasetSplit.EVAL: 49 | idcs = all_idcs[int(0.8 * n_total) : int(0.9 * n_total)] 50 | case TaskDatasetSplit.TEST: 51 | idcs = all_idcs[int(0.9 * n_total) :] 52 | 53 | case _: 54 | typing.assert_never(self) 55 | 56 | return TaskDatasetSubsetClient(dataset_client, idcs) 57 | -------------------------------------------------------------------------------- /src/expts/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from collections.abc import Sequence 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import numpy as np 9 | from aviary.core import ( 10 | Environment, 11 | Message, 12 | Messages, 13 | TaskDatasetClient, 14 | TaskEnvironmentClient, 15 | ToolRequestMessage, 16 | ) 17 | 18 | # from aviary_internal import utils 19 | # from aviary_internal.graph.multiple_completion_op import ( 20 | # SequentialMultipleCompletionLLMCallOp, 21 | # ) 22 | from ldp.agent import Agent 23 | from ldp.alg import Callback 24 | from ldp.data_structures import Trajectory, Transition 25 | from llmclient.cost_tracker import GLOBAL_COST_TRACKER 26 | from fhda.storage import DataRepo 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class VerboseCallback(Callback): 32 | """Callback to visualize notebook state before each transition.""" 33 | 34 | async def before_transition( 35 | self, 36 | traj_id: str, 37 | agent: Agent, 38 | env: Environment, 39 | agent_state: Any, 40 | obs: list[Message], 41 | ) -> None: 
42 | for msg in obs: 43 | if msg.content: 44 | logger.info("VerboseCallback:\n%s", msg.content) 45 | 46 | 47 | class SaveWorkspaceCallback(Callback): 48 | def __init__(self, dataset_client: TaskDatasetClient, workspace_repo: DataRepo): 49 | self.dataset_client = dataset_client 50 | self.workspace_repo = workspace_repo 51 | 52 | async def before_transition( 53 | self, 54 | traj_id: str, 55 | agent: Agent, 56 | env: Environment, 57 | agent_state, 58 | obs: list[Message], 59 | ) -> None: 60 | self.start = time.time() 61 | 62 | async def after_transition( 63 | self, 64 | traj_id: str, 65 | agent: Agent, 66 | env: TaskEnvironmentClient, # type: ignore[override] 67 | transition: Transition, 68 | ) -> None: 69 | if not any((transition.done, transition.truncated)): 70 | # only save if the trajectory is over 71 | return 72 | 73 | # TODO: figure out how to support overwrite flag 74 | async with self.dataset_client.get_http_client() as client: 75 | response = await client.post( 76 | "/save_workspace", 77 | json={ 78 | "env_id": env.state.env_id, 79 | "traj_id": traj_id, 80 | "workspace_repo": self.workspace_repo.model_dump(), 81 | "exception": transition.failed, 82 | "cost": GLOBAL_COST_TRACKER.lifetime_cost_usd, 83 | "time": time.time() - self.start, 84 | }, 85 | ) 86 | if not response.is_success: 87 | logger.error(f"Failed to save workspace: {response.content!r}") 88 | 89 | 90 | class LoggingCallback(Callback): 91 | def __init__(self, output_repo: DataRepo): 92 | self.output_repo = output_repo 93 | self.rewards: list[float] = [] 94 | 95 | async def after_eval_step(self, trajectories: Sequence[Trajectory]) -> None: 96 | this_batch_rewards = [ 97 | sum(step.reward for step in traj.steps) for traj in trajectories 98 | ] 99 | self.rewards += this_batch_rewards 100 | self.reward_mean, self.reward_stde = self._compute_summary_stats(self.rewards) 101 | # NOTE: assumes that positive reward implies success 102 | self.acc_mean, self.acc_stde = self._compute_summary_stats( 103 | [r > 0 for r in self.rewards] 104 | ) 105 | 106 | print(flush=True) 107 | logger.info( 108 | f"Accuracy={self.acc_mean:.2f}±{self.acc_stde:.2f}; " 109 | f"Rewards={self.reward_mean:.2f}±{self.reward_stde:.2f}" 110 | ) 111 | 112 | async def after_eval_loop(self) -> None: 113 | results = { 114 | "reward_mean": self.reward_mean, 115 | "reward_stde": self.reward_stde, 116 | "acc_mean": self.acc_mean, 117 | "acc_stde": self.acc_stde, 118 | } 119 | 120 | with open(Path(self.output_repo.local_path) / "results.json", "w") as f: 121 | json.dump(results, f, indent=4) 122 | logger.info(f"These are the results: {results}") 123 | with open(Path(self.output_repo.local_path) / "rewards.json", "w") as f: 124 | json.dump(self.rewards, f) 125 | 126 | def _compute_summary_stats(self, metrics: list) -> tuple[float, float]: 127 | return np.mean(metrics), np.std(metrics) / np.sqrt(len(metrics) + 1) 128 | 129 | 130 | def prev_choice_rep_fn(output_messages: Messages) -> str: 131 | rep = "" 132 | for i, msg in enumerate(output_messages): 133 | assert isinstance(msg, ToolRequestMessage) 134 | assert len(msg.tool_calls) == 1 135 | tc = msg.tool_calls[0] 136 | 137 | match tc.function.name: 138 | case "submit_answer": 139 | rep += f"Option {i + 1}: Submitting solution." 140 | 141 | case "list_workdir": 142 | rep += f"Option {i + 1}: Listing workdir contents." 
143 | 144 | case "edit_cell": 145 | idx = tc.function.arguments.get("idx", None) 146 | if idx is None: 147 | rep += f"Option {i + 1}: Adding cell:\n```\n" 148 | else: 149 | rep += f"Option {i + 1}: Editing cell {idx}:\n```\n" 150 | rep += tc.function.arguments["contents"] + "\n```\n" 151 | 152 | case _: 153 | # Don't throw error for now, since there may be a case I haven't considered 154 | # But eventually this should be an exception. 155 | logger.error(f"Unexpected tool call: {tc.function.name}") 156 | 157 | rep += "\n" 158 | 159 | return rep 160 | -------------------------------------------------------------------------------- /src/expts/dsbench.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | from aviary_internal import __version__, utils 7 | from pydantic import Field 8 | from tqdm import tqdm 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class GetKaggleInfo(utils.ConfigurableExpt): 14 | dataset_repo: utils.DataRepo = Field( 15 | default_factory=lambda: utils.DataRepo( 16 | name="baseline-envs/dsbench/data_modeling" 17 | ) 18 | ) 19 | 20 | async def run(self) -> None: 21 | try: 22 | from kaggle.api.kaggle_api_extended import KaggleApi 23 | except ImportError: 24 | raise ImportError( 25 | "Please `pip install kaggle` and set up authentication." 26 | ) from None 27 | 28 | api = KaggleApi() 29 | # Will raise if user is not authenticated 30 | api.authenticate() 31 | 32 | src_dir = Path(self.dataset_repo.local_path) 33 | competitions = sorted([d.name for d in (src_dir / "data_resplit").glob("*")]) 34 | kaggle_info: dict[str, dict[str, float | bool | list[float]]] = {} 35 | 36 | for comp in tqdm(competitions, desc="Querying Kaggle", ncols=0): 37 | # Bit ugly: to determine if 'best' is max or min, we get the GT result and compare 38 | # to the actual submissions. I can't find any documentation saying the leaderboard 39 | # is ordered. 
40 | 41 | try: 42 | target_result = float( 43 | (src_dir / "save_performance/GT" / comp / "result.txt").read_text() 44 | ) 45 | except FileNotFoundError: 46 | logger.error(f"Could not find GT result file for {comp} - skipping.") 47 | continue 48 | 49 | leaderboard = api.competition_leaderboard_view(comp) 50 | scores = [float(entry.score) for entry in leaderboard if entry.hasScore] 51 | if not scores: 52 | logger.error(f"No scores found for {comp} - skipping.") 53 | continue 54 | 55 | max_score, min_score = max(scores), min(scores) 56 | 57 | if min_score >= target_result: 58 | # smaller is better 59 | kaggle_info[comp] = { 60 | "best_score": min_score, 61 | "max_is_best": False, 62 | "scores": scores, 63 | } 64 | 65 | elif max_score <= target_result: 66 | # larger is better 67 | kaggle_info[comp] = { 68 | "best_score": max_score, 69 | "max_is_best": True, 70 | "scores": scores, 71 | } 72 | 73 | else: 74 | raise RuntimeError(f"Could not determine best score for {comp}.") 75 | 76 | with (src_dir / "kaggle_submissions.json").open("w") as f: 77 | json.dump( 78 | { 79 | "metadata": { 80 | "description": "Created by data_analysis.expts.dsbench.GetKaggleInfo.", 81 | "timestamp": datetime.now().isoformat(), 82 | "aviary_internal": __version__, 83 | }, 84 | "kaggle_info": kaggle_info, 85 | }, 86 | f, 87 | indent=2, 88 | ) 89 | 90 | self.dataset_repo.push(progress=True) 91 | -------------------------------------------------------------------------------- /src/expts/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import shutil 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import Self, cast 7 | 8 | import litellm 9 | from aviary.core import EvalAnswerMode, TaskDatasetClient 10 | from scripts.config import ConfigModel, set_up_output_dir 11 | from scripts.configurable import ConfigurableExpt 12 | from fhda.utils import NBLanguage 13 | from fhda.storage import DataRepo 14 | from ldp.agent import Agent, AgentConfig 15 | from ldp.alg import Evaluator, EvaluatorConfig, TrajectoryFileCallback 16 | from ldp.alg.callbacks import Callback 17 | from ldp.alg.rollout import RolloutManager 18 | from ldp.data_structures import Transition 19 | from llmclient.cost_tracker import enable_cost_tracking 20 | from pydantic import Field, model_validator 21 | 22 | from fhda.data_analysis_env import DataAnalysisEnv 23 | 24 | from .client import TaskDatasetSplit 25 | from .common import ( 26 | LoggingCallback, 27 | SaveWorkspaceCallback, 28 | VerboseCallback, 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class EnvServerConfig(ConfigModel): 35 | split: TaskDatasetSplit 36 | host: str = "localhost" 37 | port: int 38 | request_timeout: float | None = 300.0 39 | 40 | 41 | class NBEvalExpt(ConfigurableExpt): 42 | output_repo: DataRepo 43 | comment: str = "" 44 | overwrite: bool = False 45 | 46 | env: EnvServerConfig 47 | 48 | agent: AgentConfig 49 | evaluator: EvaluatorConfig = Field( 50 | default_factory=lambda: EvaluatorConfig(num_eval_iterations=25) 51 | ) 52 | 53 | async def make_dataset(self) -> TaskDatasetClient: 54 | base_dataset = await TaskDatasetClient.create( 55 | server_url=f"http://{self.env.host}:{self.env.port}", 56 | request_timeout=self.env.request_timeout, 57 | ) 58 | return self.env.split.get_random_split(base_dataset) 59 | 60 | @model_validator(mode="after") 61 | def post_init(self) -> Self: 62 | if self.overwrite: 63 | shutil.rmtree(self.output_repo.local_path, 
ignore_errors=True) 64 | self.output_repo.mkdir() 65 | return self 66 | 67 | async def run(self) -> None: 68 | set_up_output_dir(self.output_repo.local_path, config=self) 69 | dataset = await self.make_dataset() 70 | agent = self.agent.construct_agent() 71 | callbacks: list[Callback] = [ 72 | TrajectoryFileCallback(self.output_repo.local_path), 73 | LoggingCallback(self.output_repo), 74 | SaveWorkspaceCallback( 75 | dataset_client=dataset, 76 | workspace_repo=DataRepo(name=f"{self.output_repo.name}-workspaces"), 77 | ), 78 | ] 79 | if self.evaluator.batch_size == 1: 80 | callbacks.append(VerboseCallback()) 81 | litellm.drop_params = True 82 | enable_cost_tracking(enabled=True) 83 | evaluator = Evaluator( 84 | config=self.evaluator, 85 | agent=agent, 86 | dataset=dataset, 87 | callbacks=callbacks, 88 | ) 89 | await evaluator.run() 90 | 91 | self.output_repo.push(progress=True) 92 | 93 | 94 | class AdHocExptCallback(Callback): 95 | def __init__(self, output_dir: Path): 96 | self.output_dir = output_dir 97 | 98 | async def after_transition( 99 | self, 100 | traj_id: str, 101 | agent: Agent, 102 | env: DataAnalysisEnv, # type: ignore[override] 103 | transition: Transition, 104 | ) -> None: 105 | if transition.done or transition.truncated or transition.failed: 106 | target_dir = self.output_dir / env.problem_id 107 | if target_dir.exists(): 108 | shutil.rmtree(target_dir) 109 | shutil.copytree(env.state.work_dir, target_dir) 110 | 111 | if transition.action: 112 | action = transition.action.value 113 | submitted_answers = [ 114 | tc.function.arguments["answer"] 115 | for tc in action.tool_calls 116 | if tc.function.name == "submit_answer" 117 | ] 118 | with (self.output_dir / (env.problem_id + "-answer.json")).open( 119 | "w" 120 | ) as f: 121 | json.dump(submitted_answers, f, indent=2) 122 | 123 | 124 | class AdHocExpt(ConfigurableExpt): 125 | problem: str = Field(description="Problem to solve.") 126 | problem_id: str = Field( 127 | default_factory=lambda: f"analysis-{datetime.now().strftime('%Y%m%d-%H%M%S')}", 128 | description="Arbitrary problem ID - outputs will be stored with this name. " 129 | "Auto-assigned with timestamp if not provided.", 130 | ) 131 | 132 | input_dir: str = Field(description="Directory containing input data.") 133 | input_repo: DataRepo | None = Field( 134 | default=None, 135 | description="If provided, will set `input_dir` to `input_repo.local_path`.", 136 | ) 137 | 138 | output_dir: str | None = Field( 139 | default=None, 140 | description="Directory to save output notebooks. " 141 | "If not provided, will use `input_dir`.", 142 | ) 143 | output_repo: DataRepo | None = Field( 144 | default=None, 145 | description="If provided, will set `output_dir` to `output_repo.local_path`.", 146 | ) 147 | 148 | agent: AgentConfig 149 | max_rollout_steps: int | None = None 150 | verbose_callback: bool = True 151 | copy_workspace_callback: bool = True 152 | language: str = "python" 153 | 154 | async def run(self) -> None: 155 | output_path = Path(cast(str, self.output_dir)) 156 | agent = self.agent.construct_agent() 157 | 158 | # Sanity check to prevent misconfiguration for now - may revisit 159 | if not getattr(agent, "hide_old_env_states", True): 160 | raise RuntimeError( 161 | "It is strongly recommended that hide_old_env_states=True " 162 | "if the agent provides this option." 
163 | ) 164 | 165 | callbacks: list[Callback] = [] 166 | if self.verbose_callback: 167 | callbacks.append(VerboseCallback()) 168 | if self.copy_workspace_callback: 169 | callbacks.append(AdHocExptCallback(output_path)) 170 | 171 | rollout = RolloutManager(agent=agent, callbacks=callbacks) 172 | 173 | language = NBLanguage.PYTHON if self.language == "python" else NBLanguage.R 174 | 175 | input_path = Path(self.input_dir) 176 | env = DataAnalysisEnv( 177 | problem_id=self.problem_id, 178 | problem=self.problem, 179 | # doesn't really matter, since there's no answer 180 | eval_mode=EvalAnswerMode.EXACT, 181 | # use_tmp_work_dir=True by default, so self.data_dir will be copied 182 | nb_path=(input_path / "analysis.ipynb"), 183 | work_dir=input_path, 184 | language=language, 185 | ) 186 | 187 | await rollout.sample_trajectories( 188 | environments=[env], max_steps=self.max_rollout_steps 189 | ) 190 | 191 | await env.close() 192 | 193 | if self.output_repo is not None: 194 | self.output_repo.push(progress=True) 195 | 196 | @model_validator(mode="before") 197 | @classmethod 198 | def set_dirs(cls, data): 199 | if isinstance(data, dict): 200 | for pfx in ("input", "output"): 201 | if f"{pfx}_repo" in data: 202 | assert f"{pfx}_dir" not in data, ( 203 | f"Cannot provide both {pfx}_dir and {pfx}_repo" 204 | ) 205 | data[f"{pfx}_repo"] = DataRepo(**data[f"{pfx}_repo"]) 206 | data[f"{pfx}_dir"] = data[f"{pfx}_repo"].local_path 207 | return data 208 | 209 | @model_validator(mode="after") 210 | def post_init(self) -> Self: 211 | if self.input_repo is not None: 212 | self.input_repo.pull(progress=True) 213 | 214 | if self.output_repo is not None: 215 | self.output_repo.mkdir() 216 | 217 | if self.output_dir is None: 218 | self.output_dir = self.input_dir 219 | 220 | return self 221 | -------------------------------------------------------------------------------- /src/expts/server.py: -------------------------------------------------------------------------------- 1 | """Utilities to run TaskDatasetServers on various notebook task datasets.""" 2 | 3 | import json 4 | import logging 5 | import shutil 6 | from abc import ABC, abstractmethod 7 | from pathlib import Path 8 | from typing import Generic, TypeVar 9 | 10 | from aviary.core import TaskDataset, TaskDatasetServer 11 | from fhda.storage import DataRepo 12 | from fhda.utils import collect_notebook_stats 13 | from fhda.data_analysis_env import DataAnalysisEnv 14 | from fhda.dataset import CapsuleDataset, CapsuleDatasetConfig 15 | from scripts.configurable import ConfigurableExpt 16 | from pydantic import BaseModel, Field 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class SaveWorkspaceRequest(BaseModel): 23 | env_id: str 24 | traj_id: str 25 | workspace_repo: DataRepo 26 | exception: bool 27 | cost: float 28 | time: float 29 | 30 | 31 | class NBTaskDatasetServer(TaskDatasetServer[DataAnalysisEnv]): 32 | def _setup_routes(self) -> None: 33 | super()._setup_routes() 34 | 35 | @self.app.post("/save_workspace") 36 | async def save_workspace(req: SaveWorkspaceRequest): 37 | async with self.lock: 38 | env = self._get_env(req.env_id) 39 | 40 | problem_id = env.problem_id 41 | this_workspace_repo = DataRepo( 42 | name=f"{req.workspace_repo.name}/{problem_id.replace('/', '-')}-{req.traj_id}" 43 | ) 44 | this_workspace_repo.mkdir() 45 | out_dir = Path(this_workspace_repo.local_path) 46 | logger.info(f"Saving workspace to {this_workspace_repo.name}") 47 | 48 | # # Copy the full output directory 49 | for file in 
Path(env.state.work_dir).glob("**/*"): 50 | if file.suffix in {".ipynb", ".json"}: 51 | dest = out_dir / file.relative_to(env.state.work_dir) 52 | dest.parent.mkdir(parents=True, exist_ok=True) 53 | shutil.copy2(file, dest) 54 | res = { 55 | "problem_id": problem_id, 56 | "traj_id": req.traj_id, 57 | "reward": env.state.total_reward, 58 | "agent_answer": env.state.answer, 59 | "ideal_answer": env.answer, 60 | "problem": env.problem, 61 | "mcq_options": [q.options for q in env.mcqs] if env.mcqs else [], 62 | "mcq_question": [q.question for q in env.mcqs] if env.mcqs else [], 63 | "question_rewards": env.question_rewards, 64 | "cost": req.cost, 65 | "exception": req.exception, 66 | "notebook_stats": collect_notebook_stats(env.state.nb), 67 | "time": req.time, 68 | "actions": env.state.actions, 69 | "run_id": req.workspace_repo.name, 70 | "metadata": env.metadata, 71 | "insufficient_options": { 72 | q.question_id: q.unsure_answer_letter for q in (env.mcqs or []) 73 | }, 74 | } 75 | with (out_dir / "metadata.json").open("w") as f: 76 | json.dump( 77 | res, 78 | f, 79 | indent=4, 80 | ) 81 | 82 | # Push just this specific workspace, not the whole workspace repo 83 | this_workspace_repo.push(progress=True) 84 | # # Delete the workspace directory after pushing 85 | shutil.rmtree(out_dir) 86 | 87 | 88 | TDataset = TypeVar("TDataset", bound=TaskDataset) 89 | 90 | 91 | class DatasetServer(ConfigurableExpt, ABC, Generic[TDataset]): 92 | port: int 93 | 94 | @abstractmethod 95 | def make_dataset(self) -> TDataset: 96 | pass 97 | 98 | async def run(self) -> None: 99 | dataset = self.make_dataset() 100 | logger.info(f"Starting {dataset.__class__.__name__} server on port {self.port}") 101 | server = NBTaskDatasetServer(dataset, port=self.port) 102 | await server.astart() 103 | 104 | 105 | class CapsuleDatasetServer(DatasetServer[CapsuleDataset]): 106 | dataset: CapsuleDatasetConfig = Field(default_factory=CapsuleDatasetConfig) 107 | 108 | def make_dataset(self) -> CapsuleDataset: 109 | return CapsuleDataset(config=self.dataset) 110 | -------------------------------------------------------------------------------- /src/expts/training.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from collections.abc import Mapping, Sequence 4 | 5 | from aviary.core import ( 6 | TaskDatasetClient, 7 | ) 8 | from aviary_internal import utils 9 | from aviary_internal.agent import DQNAgentVariant 10 | from aviary_internal.agent.dqn_agent import LLMSamplingMode 11 | from aviary_internal.alg.optimizer.dqn import DQNOptimizer 12 | from aviary_internal.nn.sft_optimizer import LocalLLMSFTOptimizer 13 | from aviary_internal.serialization import disable_serialization_backend 14 | from cloning.expts.local_sft import CloningOnlineLocalTrainingExpt 15 | from gsm8k.expts.dqn.online import GSM8kDQNOnlineTrainingExpt 16 | from ldp.alg.callbacks import Callback 17 | from ldp.alg.runners import OnlineTrainerConfig 18 | from ldp.data_structures import Trajectory 19 | 20 | from .client import TaskDatasetSplit 21 | from .common import SaveWorkspaceCallback, prev_choice_rep_fn 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class EnvServerConfig(utils.ConfigModel): 27 | host: str 28 | port: int 29 | request_timeout: float | None = 300.0 30 | 31 | async def make_datasets(self) -> dict[str, TaskDatasetClient]: 32 | base_dataset = await TaskDatasetClient.create( 33 | server_url=f"http://{self.host}:{self.port}", 34 | request_timeout=self.request_timeout, 
35 | ) 36 | return { 37 | "train_dataset": TaskDatasetSplit.TRAIN.get_random_split(base_dataset), 38 | "eval_dataset": TaskDatasetSplit.EVAL.get_random_split(base_dataset), 39 | } 40 | 41 | 42 | class NBDQNOnlineTrainingExpt(GSM8kDQNOnlineTrainingExpt): 43 | env: EnvServerConfig 44 | 45 | async def make_datasets(self) -> dict[str, TaskDatasetClient]: 46 | return await self.env.make_datasets() 47 | 48 | def make_callbacks( 49 | self, 50 | agent: DQNAgentVariant, 51 | optimizer: DQNOptimizer, 52 | datasets: Mapping[str, TaskDatasetClient], 53 | ) -> list[Callback]: 54 | callbacks = super().make_callbacks(agent, optimizer, datasets) 55 | callbacks.append( 56 | SaveWorkspaceCallback( 57 | dataset_client=datasets["train_dataset"], 58 | workspace_repo=utils.DataRepo( 59 | name=f"{self.output_repo.name}-workspaces" 60 | ), 61 | ) 62 | ) 63 | return callbacks 64 | 65 | def make_agent(self, **kwargs) -> DQNAgentVariant: 66 | if self.agent.llm_sampling_mode == LLMSamplingMode.SEQUENTIAL: 67 | self.agent.llm_kwargs["prev_choice_rep_fn"] = prev_choice_rep_fn 68 | return super().make_agent(**kwargs) 69 | 70 | 71 | class NBOnlineTrainingConfig(OnlineTrainerConfig): 72 | save_all_checkpoints: bool = True 73 | num_val_trajs: int 74 | num_train_trajs: int | None = None 75 | 76 | 77 | class NBOnlineLocalTrainingExpt(CloningOnlineLocalTrainingExpt): 78 | env: EnvServerConfig 79 | trainer: NBOnlineTrainingConfig 80 | 81 | async def _get_demonstration_examples( 82 | self, opt: LocalLLMSFTOptimizer 83 | ) -> tuple[list[dict], list[dict]]: 84 | backend = await self.make_backend() 85 | trajectories = await backend.get_trajectories() 86 | 87 | random.Random(self.data_seed).shuffle(trajectories) 88 | val_trajs = self._filter_trajectories( 89 | trajectories[: self.trainer.num_val_trajs], opt 90 | ) 91 | train_trajs = self._filter_trajectories( 92 | trajectories[self.trainer.num_val_trajs :][: self.trainer.num_train_trajs], 93 | opt, 94 | ) 95 | logger.info( 96 | f"Loaded {len(train_trajs)} ({len(val_trajs)}) train (val) trajectories." 
97 | ) 98 | 99 | # Disable the backend so we don't accidentally overwrite input data 100 | disable_serialization_backend() 101 | 102 | # convert to examples 103 | train_examples = self._trajs_to_examples(train_trajs, opt) 104 | val_examples = self._trajs_to_examples(val_trajs, opt) 105 | return train_examples, val_examples 106 | 107 | def _filter_trajectories( 108 | self, trajectories: Sequence[Trajectory], opt: LocalLLMSFTOptimizer 109 | ): 110 | return [t for t in trajectories if opt.trajectory_passes(t)] 111 | -------------------------------------------------------------------------------- /src/fhda/Dockerfile.pinned: -------------------------------------------------------------------------------- 1 | # DANGER: Beware of changing this dockerfile, orchestrating the versioning in these R/python packages was very challenging 2 | FROM continuumio/miniconda3:24.9.2-0 3 | 4 | RUN mkdir /workspace && \ 5 | mkdir /envs 6 | WORKDIR /envs 7 | 8 | ENV DEBIAN_FRONTEND=noninteractive 9 | RUN apt-get update && \ 10 | apt-get install -yq --no-install-recommends \ 11 | wget \ 12 | gpg \ 13 | software-properties-common \ 14 | build-essential && \ 15 | rm -rf /var/lib/apt/lists/* 16 | 17 | RUN conda install mamba=2.0.5 -c conda-forge -y 18 | 19 | # Install R packages from conda-forge 20 | RUN mamba install -c conda-forge -y \ 21 | r-base=4.3.3 \ 22 | r-recommended=4.3 \ 23 | r-irkernel=1.3.2 \ 24 | r-factominer=2.11 \ 25 | r-rcolorbrewer=1.1_3 \ 26 | r-devtools=2.4.5 \ 27 | r-broom=1.0.7 \ 28 | r-data.table=1.15.4 \ 29 | r-enrichr=3.2 \ 30 | r-factoextra=1.0.7 \ 31 | r-ggnewscale=0.5.0 \ 32 | r-ggrepel=0.9.6 \ 33 | r-ggpubr=0.6.0 \ 34 | r-ggvenn=0.1.10 \ 35 | r-janitor=2.2.1 \ 36 | r-multcomp=1.4_26 \ 37 | r-matrix=1.6_5 \ 38 | r-pheatmap=1.0.12 \ 39 | r-tidyverse=2.0.0 \ 40 | r-readxl=1.4.3 \ 41 | r-reshape=0.8.9 \ 42 | r-rstatix=0.7.2 \ 43 | r-viridis=0.6.5 \ 44 | udocker=1.3.17 \ 45 | imbalanced-learn=0.13.0 \ 46 | ipykernel=6.29.5 \ 47 | sqlite=3.47.2 48 | 49 | RUN python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)" 50 | RUN R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")' 51 | 52 | # I separate these because not all packages need both channels, additionally, 53 | # creating multiple layers makes caching easier 54 | RUN mamba install -c conda-forge -c bioconda -y \ 55 | biokit=0.5.0 \ 56 | gseapy=1.1.4 \ 57 | blast=2.16.0 \ 58 | clipkit=2.3.0 \ 59 | fastqc=0.12.1 \ 60 | iqtree=2.3.6 \ 61 | mafft=7.526 \ 62 | metaeuk=7.bba0d80 \ 63 | mygene=3.2.2 \ 64 | perl=5.32.1 \ 65 | phykit=2.0.1 \ 66 | pydeseq2=0.4.12 \ 67 | spades=4.0.0 \ 68 | trim-galore=0.6.10 \ 69 | bioconductor-enhancedvolcano=1.20.0 \ 70 | bioconductor-deseq2=1.42.0 \ 71 | bioconductor-clusterprofiler=4.10.0 \ 72 | bioconductor-org.hs.eg.db=3.18.0 \ 73 | bioconductor-genomicranges=1.54.1 \ 74 | bioconductor-summarizedexperiment=1.32.0 \ 75 | bioconductor-apeglm=1.24.0 76 | 77 | 78 | COPY kernel_requirements.txt . 79 | 80 | # Install conda packages first 81 | RUN mamba install -c conda-forge --file kernel_requirements.txt -y 82 | 83 | # Install pip packages 84 | RUN pip install aiodocker ldp==0.26.0 fhaviary[server]==0.19.0 futurehouse-client==0.3.14 85 | 86 | # Certain tools are not easily installable via conda. 
A common practice for 87 | # bioinformaticians is to use udocker to run certain heavy duty omics processing 88 | # tools in an isolated environment 89 | # RUN udocker --allow-root install && \ 90 | # udocker --allow-root pull ezlabgva/busco:v5.8.0_cv1 91 | 92 | WORKDIR /workspace 93 | 94 | RUN mamba clean --all -f -y && \ 95 | conda clean --all -f -y && \ 96 | rm -rf /root/.cache/pip 97 | -------------------------------------------------------------------------------- /src/fhda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Future-House/data-analysis-crow/095df6e47b1cfea4361927872f02a8c40204f864/src/fhda/__init__.py -------------------------------------------------------------------------------- /src/fhda/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | USE_DOCKER = bool(os.getenv("USE_DOCKER", "false").lower() == "true") 5 | USE_R = bool(os.getenv("USE_R", "false").lower() == "true") 6 | NB_ENVIRONMENT_DOCKER_IMAGE = os.getenv( 7 | "NB_ENVIRONMENT_DOCKER_IMAGE", "futurehouse/bixbench:aviary-notebook-env" 8 | ) 9 | 10 | # Some R error messages can be 100,000 of characters 11 | NB_OUTPUT_LIMIT = 3000 # chars 12 | # Streams from a docker container. Don't set to `sys.stdout.fileno()` 13 | # because we want to differentiate from file I/O 14 | DOCKER_STREAM_TYPE_STDOUT = 1 15 | DOCKER_STREAM_TYPE_STDERR = 2 16 | 17 | STAGE = os.getenv("STAGE", "local") 18 | if STAGE == "local": 19 | DATA_STORAGE_PATH = Path("storage") 20 | else: 21 | DATA_STORAGE_PATH = Path("/storage") 22 | -------------------------------------------------------------------------------- /src/fhda/data_analysis_env.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import logging 4 | import shutil 5 | from typing import Any, cast 6 | import time 7 | from aviary.core import ( 8 | EvalAnswerMode, 9 | Frame, 10 | Message, 11 | Messages, 12 | Tool, 13 | eval_answer, 14 | ) 15 | 16 | from .notebook_env import NBEnvironment 17 | from .utils import NBLanguage, MultipleChoiceQuestion, nb_to_html 18 | from . import prompts 19 | from . import config as cfg 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | CORRECT_MSG = "Correct answer!" 24 | INCORRECT_MSG = "Incorrect answer." 
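# DataAnalysisEnv (below) extends NBEnvironment with a natural-language problem
# statement, optional multiple-choice questions, and answer grading: exact match
# for integer answers, relative tolerance for floats, and aviary's eval_answer
# (per eval_mode) for free-text and MCQ answers. It is the environment driven by
# the experiment runners in src/expts and by the from_task() constructor for
# user-submitted queries.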
25 | 26 | 27 | class DataAnalysisEnv(NBEnvironment): 28 | def __init__( 29 | self, 30 | *, 31 | problem_id: str, 32 | problem: str, 33 | answer: str | int | float | None = None, # noqa: PYI041 34 | system_prompt: str | None = None, 35 | correct_reward: float = 1.0, 36 | eval_mode: EvalAnswerMode, 37 | metadata: dict[str, Any] | None = None, # used for NBEvalExpt 38 | mcqs: list[MultipleChoiceQuestion] | None = None, 39 | **kwargs, 40 | ): 41 | super().__init__(**kwargs) 42 | 43 | self.problem_id = problem_id 44 | self.problem = problem 45 | self.mcqs = mcqs 46 | self.answer = answer 47 | self.eval_mode = eval_mode 48 | self.correct_reward = correct_reward 49 | self.system_prompt = system_prompt 50 | self.metadata = metadata 51 | self.question_rewards: dict[str, int] = {} 52 | 53 | async def reset(self) -> tuple[Messages, list[Tool]]: 54 | # Discard base class's init_obs and make our own with the problem statement 55 | _, tools = await super().reset() 56 | messages = [ 57 | Message(content=self.problem), 58 | self.get_env_state_msg(), 59 | ] 60 | if self.system_prompt: 61 | messages.append(Message(role="system", content=self.system_prompt)) 62 | init_obs = cast( 63 | Messages, 64 | messages, 65 | ) 66 | 67 | return init_obs, tools 68 | 69 | async def submit_answer(self, answer: str | float | dict[str, Any] | None) -> str: # type: ignore[override] 70 | """Submit an answer to the problem. 71 | 72 | Note that this tool may only be called once and ends the episode. 73 | 74 | Args: 75 | answer: The answer to the problem 76 | """ 77 | # TODO: support various eval modes 78 | self.state.answer = answer 79 | self.state.done = True 80 | logger.info("Submitting answer and closing environment") 81 | await self.close() 82 | correct = False 83 | logger.info("Answer: %s", answer) 84 | 85 | if self.eval_mode is None: 86 | return CORRECT_MSG 87 | 88 | if isinstance(self.answer, int): 89 | try: 90 | answer = int(answer) # type: ignore[arg-type] 91 | except ValueError: 92 | pass 93 | else: 94 | correct = answer == self.answer 95 | 96 | elif isinstance(self.answer, float): 97 | try: 98 | answer = float(answer) # type: ignore[arg-type] 99 | except ValueError: 100 | pass 101 | else: 102 | correct = abs(answer - self.answer) < 1e-4 * self.answer 103 | 104 | elif isinstance(self.answer, str): 105 | correct = bool( 106 | await eval_answer( 107 | proposed=str(answer), 108 | correct=str(self.answer), 109 | question=self.problem, 110 | eval_mode=self.eval_mode, 111 | ) 112 | ) 113 | elif isinstance(self.answer, dict): # This is for mcqs and open questions 114 | # Check if answer is a json string 115 | if isinstance(answer, str): # type: ignore[unreachable] 116 | # Process json into dictionary 117 | try: 118 | processed_answer = json.loads(answer) 119 | except json.JSONDecodeError: 120 | return INCORRECT_MSG 121 | else: 122 | processed_answer = answer if isinstance(answer, dict) else {} 123 | 124 | # Loop through each question and answer 125 | for question_id, agent_answer in processed_answer.items(): 126 | try: 127 | ideal_answer = self.answer[question_id] 128 | question = next( 129 | q 130 | for q in self.mcqs 131 | if q.question_id.lower() == question_id.lower() 132 | ) 133 | correct = bool( 134 | await eval_answer( 135 | proposed=str(agent_answer), 136 | correct=str(ideal_answer), 137 | question=question, 138 | eval_mode=self.eval_mode, 139 | ) 140 | ) 141 | self.question_rewards[question_id] = correct 142 | except KeyError: 143 | self.question_rewards[question_id] = 0 144 | average_reward = 
sum(self.question_rewards.values()) / len(self.mcqs) 145 | correct = round(average_reward) == 1.0 146 | 147 | if correct: 148 | self.state.total_reward += self.correct_reward 149 | return CORRECT_MSG 150 | return INCORRECT_MSG 151 | 152 | @classmethod 153 | def from_task( 154 | cls, task: str, gcs_artifact_path: str | None = None 155 | ) -> "DataAnalysisEnv": 156 | """ 157 | Perform data analysis on a user query. 158 | 159 | Args: 160 | task: The user query structured as | 161 | 162 | eg "CaspuleFolder-a7812fg | How many genes are differentially expressed between the two conditions?" 163 | """ 164 | logger.info("User task: %s", task) 165 | logger.info("GCS artifact path: %s", gcs_artifact_path) 166 | 167 | if ( 168 | gcs_artifact_path 169 | ): # The files are already in the GCS bucket in a job-specific directory 170 | trajectory_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path 171 | nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME 172 | query = task 173 | task_hash = gcs_artifact_path 174 | else: 175 | # Extract data path and query from task 176 | data_path, query = task.split("|") 177 | # Hash the task to get a unique identifier 178 | task_hash = hashlib.sha256(task.encode()).hexdigest() 179 | # Create temporary directory in GCP mounted storage volume 180 | trajectory_path = cfg.DATA_STORAGE_PATH / f"{task_hash}-{time.time()}" 181 | trajectory_path.mkdir(parents=True, exist_ok=True) 182 | nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME 183 | # Copy task data to trajectory path 184 | for item in (cfg.DATA_STORAGE_PATH / data_path).iterdir(): 185 | if item.is_file(): 186 | shutil.copy2(item, trajectory_path) 187 | elif item.is_dir(): 188 | shutil.copytree( 189 | item, trajectory_path / item.name, dirs_exist_ok=True 190 | ) 191 | 192 | # Augment incoming task with CoT instructions 193 | augmented_task = f"""\ 194 | Here is the user query to address: 195 | 196 | 197 | {query} 198 | 199 | 200 | {prompts.CHAIN_OF_THOUGHT_AGNOSTIC} 201 | {prompts.GENERAL_NOTEBOOK_GUIDELINES}""" 202 | 203 | language = NBLanguage.PYTHON # In future, this should be a hyperparameter 204 | if language == NBLanguage.R: 205 | augmented_task += f"\n{prompts.R_OUTPUT_RECOMMENDATION_PROMPT}" 206 | 207 | # Log all parameters being passed to constructor 208 | logger.info( 209 | "Creating DataAnalysisEnv with parameters: " 210 | "problem_id=data-analysis-task-%s, " 211 | "problem=%s, " 212 | "eval_mode=%s, " 213 | "nb_path=%s, " 214 | "work_dir=%s, " 215 | "language=%s, " 216 | "system_prompt=%s, " 217 | "use_tmp_work_dir=%s, " 218 | "gcs_artifact_path=%s", 219 | task_hash, 220 | augmented_task, 221 | EvalAnswerMode.LLM, 222 | nb_path, 223 | trajectory_path, 224 | language, 225 | prompts.CAPSULE_SYSTEM_PROMPT_QUERY, 226 | False, 227 | gcs_artifact_path, 228 | ) 229 | if trajectory_path.exists(): 230 | logger.info( 231 | "Files in directory: %s", [f.name for f in trajectory_path.iterdir()] 232 | ) 233 | 234 | return cls( 235 | problem_id=f"data-analysis-task-{task_hash}", 236 | problem=augmented_task, 237 | eval_mode=EvalAnswerMode.LLM, 238 | nb_path=nb_path, 239 | work_dir=trajectory_path, 240 | language=language, 241 | system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_QUERY, 242 | use_tmp_work_dir=False, 243 | ) 244 | 245 | def export_frame(self) -> Frame: 246 | return Frame( 247 | state={ 248 | "last_action": self.state.actions[-1], 249 | "answer": self.state.answer, 250 | "done": self.state.done, 251 | "total_reward": self.state.total_reward, 252 | "nb_state": self.state.nb, 253 | "nb_state_html": 
nb_to_html(self.state.nb), 254 | }, 255 | info={ 256 | "eval_mode": self.eval_mode, 257 | "language": self.state.language, 258 | "problem": self.problem, 259 | "problem_id": self.problem_id, 260 | }, 261 | ) 262 | -------------------------------------------------------------------------------- /src/fhda/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | from pathlib import Path 4 | from tempfile import mkdtemp 5 | 6 | from pydantic import Field 7 | 8 | from aviary.core import EvalAnswerMode, TaskDataset 9 | from .storage import DataRepo 10 | from .data_analysis_env import DataAnalysisEnv 11 | from .utils import NBLanguage, load_mcq 12 | from . import prompts 13 | from .models import ConfigModel 14 | from .notebook_env import NBEnvironment 15 | import logging 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class CapsuleDatasetConfig(ConfigModel): 22 | repo: DataRepo = Field( 23 | default_factory=lambda: DataRepo(name="baseline-envs/data-analysis/v3.1"), 24 | description="The hosted repo to use for the dataset.", 25 | ) 26 | 27 | local_repo_path: str | None = Field( 28 | default=None, 29 | description="If provided, will source the data from this local path instead of the hosted repo.", 30 | ) 31 | 32 | local_output_path: str | None = Field( 33 | default=None, 34 | description="If provided, will save the output to this local path instead of the hosted repo.", 35 | ) 36 | 37 | capsule_mode: str | None = Field( 38 | default="mcq", 39 | description="Determines whether the agent is to answer MCQs, open questions or whether a hypothesis is supported by the data", 40 | ) 41 | 42 | eval_mode: EvalAnswerMode = Field( 43 | default=EvalAnswerMode.LLM, 44 | description="If exact, the target will be 'answer' in the metadata json (i.e. T/F) " 45 | "If llm, the target will be 'result'. Contains/score not supported", 46 | ) 47 | 48 | avoid_images: bool = Field( 49 | default=False, 50 | description="If True, the agent will be prompted to avoid using images in its notebook.", 51 | ) 52 | 53 | preload_notebook: bool = Field( 54 | default=False, 55 | description=( 56 | "If False, the agent will have to start from a virgin notebook. " 57 | "If True, the agent environment will be preloaded with a notebook " 58 | "containing a portion of the capsule problem already completed " 59 | "eg package & data loading." 
60 | ), 61 | ) 62 | 63 | prompt_template_key: str = Field( 64 | default="v1.3.1", 65 | description="The key of the prompt template from the CAPSULE_PROMPT_TEMPLATES dict to use for the problem.", 66 | ) 67 | 68 | 69 | class CapsuleDataset(TaskDataset[DataAnalysisEnv]): 70 | """A dataset of tasks derived from data analysis capsules.""" 71 | 72 | def __init__(self, config: CapsuleDatasetConfig): 73 | # Load dataset from local path or hosted repo 74 | if config.local_repo_path: 75 | repo_path = config.local_repo_path 76 | else: 77 | config.repo.pull(progress=True) 78 | repo_path = config.repo.local_path 79 | self.capsules = list(Path(repo_path).rglob("CapsuleFolder*")) 80 | 81 | # Load prompt template 82 | self.prompt = prompts.CAPSULE_PROMPT_TEMPLATES[config.prompt_template_key] 83 | self.config = config 84 | 85 | def get_new_env_by_idx(self, idx: int) -> DataAnalysisEnv: 86 | capsule_path = self.capsules[idx] 87 | metadata = json.load((capsule_path / "metadata.json").open()) 88 | 89 | notebook_name = NBEnvironment.NOTEBOOK_NAME 90 | # Define local capsule directory 91 | if self.config.local_output_path: 92 | problem_dir = Path(self.config.local_output_path) / capsule_path.name 93 | else: 94 | problem_dir = Path(mkdtemp()) 95 | problem_dir.mkdir(parents=True, exist_ok=True) 96 | 97 | # Copy capsule contents to local directory 98 | for item in capsule_path.iterdir(): 99 | if self.config.preload_notebook and str(item).endswith("_stripped.ipynb"): 100 | shutil.copy(item, problem_dir) 101 | elif str(item).endswith((".ipynb", "metadata.json", "checksum")): 102 | continue 103 | elif item.is_dir(): 104 | shutil.copytree(item, problem_dir / item.name) 105 | else: 106 | shutil.copy(item, problem_dir) 107 | 108 | nb_path = problem_dir / notebook_name 109 | 110 | # Define system prompt and problem 111 | if self.config.capsule_mode == "hypothesis": 112 | system_prompt = prompts.CAPSULE_SYSTEM_PROMPT_HYPOTHESIS 113 | problem = self.prompt.replace("{{hypothesis}}", metadata["hypothesis"]) 114 | answer = metadata["answer"] 115 | processed_questions = None 116 | elif self.config.capsule_mode == "mcq": 117 | raw_mcqs = metadata["notebook_questions"]["questions"] 118 | processed_questions = [ 119 | load_mcq(i, open_question=False, question_id=i["id"]) for i in raw_mcqs 120 | ] 121 | system_prompt = prompts.CAPSULE_SYSTEM_PROMPT_MCQ 122 | problem = self.prompt.format( 123 | questions="\n-------\n".join( 124 | [i.question_prompt for i in processed_questions] 125 | ) 126 | ) 127 | answer = {i.question_id: i.ideal_answer for i in processed_questions} 128 | elif self.config.capsule_mode == "open": 129 | system_prompt = prompts.CAPSULE_SYSTEM_PROMPT_OPEN 130 | raw_open_questions = metadata["notebook_questions"]["questions"] 131 | processed_questions = [ 132 | load_mcq(i, open_question=True, question_id=i["id"]) 133 | for i in raw_open_questions 134 | ] 135 | problem = self.prompt.format( 136 | questions="\n-------\n".join( 137 | [i.question_prompt for i in processed_questions] 138 | ) 139 | ) 140 | answer = {i.question_id: i.ideal_answer for i in processed_questions} 141 | else: 142 | raise ValueError(f"Invalid capsule mode: {self.config.capsule_mode}") 143 | 144 | if self.config.avoid_images: 145 | problem += prompts.AVOID_IMAGES 146 | 147 | # Temporarily hard code language to python, but can also use R 148 | language = NBLanguage.PYTHON 149 | return DataAnalysisEnv( 150 | problem_id=capsule_path.name, 151 | problem=problem, 152 | eval_mode=self.config.eval_mode, 153 | nb_path=nb_path, 154 | work_dir=problem_dir, 
155 | language=language, 156 | system_prompt=system_prompt, 157 | metadata=metadata, 158 | answer=answer, 159 | mcqs=processed_questions, 160 | ) 161 | 162 | def __len__(self) -> int: 163 | return len(self.capsules) 164 | -------------------------------------------------------------------------------- /src/fhda/dev.yaml: -------------------------------------------------------------------------------- 1 | job: 2 | cpu: 2 3 | memory: 4Gi 4 | timeout: 1200s 5 | env: 6 | CROW_AGENT: ldp.agent.SimpleAgent 7 | CROW_ENVIRONMENT: data_analysis.env.DataAnalysisEnv 8 | OPENAI_API_KEY: gcsm:crow-openai-api-key 9 | ANTHROPIC_API_KEY: gcsm:crow-anthropic-api-key 10 | -------------------------------------------------------------------------------- /src/fhda/kernel_requirements.txt: -------------------------------------------------------------------------------- 1 | anndata==0.11.1 2 | biopython==1.84 3 | ete3==3.1.3 4 | fcsparser==0.2.8 5 | cython==3.0.12 6 | gseapy==1.1.4 7 | keras==3.7.0 8 | jupyter==1.0.0 9 | matplotlib==3.10.0 10 | matplotlib-venn==1.1.1 11 | mygene==3.2.2 12 | nbconvert==7.16.4 13 | numpy==1.26.4 # Pinned lower for fcsparser <2 14 | optuna==4.1.0 15 | openpyxl==3.1.5 16 | pandas==2.2.3 17 | plotly==5.24.1 18 | rpy2==3.5.11 19 | scipy==1.14.1 20 | scanpy==1.10.4 21 | seaborn==0.13.2 22 | scikit-learn==1.6.0 23 | statsmodels==0.14.4 24 | umap-learn==0.5.7 25 | -------------------------------------------------------------------------------- /src/fhda/models.py: -------------------------------------------------------------------------------- 1 | """Module for handling yaml config/CLI args and translating them into pydantic configs.""" 2 | 3 | import contextlib 4 | import inspect 5 | import logging 6 | import os 7 | import shutil 8 | import sys 9 | import textwrap 10 | 11 | from argparse import ArgumentParser 12 | from collections.abc import Iterable 13 | from pathlib import Path 14 | from typing import Any, TypeVar 15 | 16 | import yaml 17 | from pydantic import BaseModel, ConfigDict 18 | from pydantic_core import PydanticUndefined 19 | from ldp.utils import configure_stdout_logs 20 | from llmclient import configure_llm_logs 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def configure_logs( 27 | log_file: str | os.PathLike | None = None, 28 | stdout_level: int | str | tuple[str, int | str] | None = logging.INFO, 29 | fmt: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s", 30 | ) -> None: 31 | """Configure logs. 32 | 33 | Args: 34 | log_file: Optional log file to add to all loggers. 35 | stdout_level: If int (default) or str, it's a log level for stdout. If two-tuple 36 | of str and int, it's a logger name and log level for that logger. Otherwise, 37 | if None, don't configure stdout logs. 38 | fmt: Logging format string. 
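
    Example (illustrative only; the file name and logger name are invented):
        configure_logs(log_file="run.log", stdout_level=logging.DEBUG)
        configure_logs(stdout_level=("httpx", logging.WARNING))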
39 | """ 40 | configure_llm_logs() 41 | 42 | # Set some good default log levels to avoid too much verbosity 43 | logging.getLogger("dask").setLevel(logging.WARNING) 44 | logging.getLogger("vcr.cassette").setLevel(logging.WARNING) 45 | 46 | if stdout_level is not None: 47 | if isinstance(stdout_level, tuple): 48 | configure_stdout_logs(name=stdout_level[0], level=stdout_level[1], fmt=fmt) 49 | else: 50 | configure_stdout_logs(level=stdout_level, fmt=fmt) 51 | 52 | if log_file is not None: 53 | # Configure all loggers to write to a log file 54 | file_handler = logging.FileHandler(log_file) 55 | file_handler.setLevel(logging.DEBUG) 56 | file_handler.setFormatter(logging.Formatter(fmt)) 57 | logger.info(f"Logging to {log_file}.") 58 | 59 | # apply retroactively to root logger and all existing loggers 60 | for logger_name in ("root", *logging.root.manager.loggerDict.keys()): 61 | logging.getLogger(logger_name).addHandler(file_handler) 62 | 63 | 64 | class ConfigModel(BaseModel): 65 | model_config = ConfigDict( 66 | extra="forbid", arbitrary_types_allowed=True, populate_by_name=True 67 | ) 68 | 69 | 70 | TConfig = TypeVar("TConfig", bound=BaseModel) 71 | 72 | 73 | def load_arg_dict(argv: list[str]) -> dict[str, Any]: 74 | """Loads arguments from command line and yaml files into a dictionary. 75 | 76 | For example, if the command line args are `--foo.bar 1 --foo.baz 2`, the resulting 77 | dictionary is {'foo': {'bar': 1, 'baz': 2}}. YAML files are directly parsed as dictionaries. 78 | """ 79 | parser = ArgumentParser(add_help=False) 80 | parser.add_argument("config_files", nargs="*", type=str) 81 | 82 | if not any(a.endswith(".yaml") for a in argv): 83 | # add a dummy arg to avoid argparse error 84 | argv = ["INVALID.yaml", *argv] 85 | args, remaining_args = parser.parse_known_args(argv) 86 | 87 | config_acc: dict[str, Any] = {} 88 | for cfg in args.config_files: 89 | if cfg == "INVALID.yaml": 90 | continue 91 | with open(cfg) as fcfg: 92 | config = yaml.load(fcfg, Loader=yaml.Loader) # noqa: S506 93 | _recursive_update(config_acc, config) 94 | 95 | _parse_cli_args(remaining_args, config_acc) 96 | 97 | return config_acc 98 | 99 | 100 | def load_config( 101 | config_cls: type[TConfig], 102 | verbose: bool = True, 103 | argv: list[str] | None = None, 104 | args_to_exclude: Iterable[str] | None = None, 105 | ) -> TConfig: 106 | """Utility function for handling config and command line args supplied via command line. 107 | 108 | Args: 109 | config_cls: Config class object 110 | verbose: Boolean indicating extent of logging info 111 | argv: List of command line args. If not specified (default), will use sys.argv. 112 | args_to_exclude: Arguments to skip when constructing the config object. 113 | 114 | Returns: 115 | Config object synthesizing CLI args and supplied yaml. 
116 | """ 117 | if argv is None: 118 | argv = sys.argv[1:] 119 | 120 | if "-h" in argv or "--help" in argv: 121 | print(get_config_help_string(config_cls)) 122 | sys.exit(0) 123 | 124 | config_acc = load_arg_dict(argv) 125 | if args_to_exclude: 126 | for arg in args_to_exclude: 127 | config_acc.pop(arg, None) 128 | 129 | config = config_cls(**config_acc) 130 | 131 | if verbose: 132 | logger.info("\n%s", yaml.dump({config_cls.__name__: config.model_dump()})) 133 | 134 | return config 135 | 136 | 137 | def _parse_cli_args(remaining_args: list[str], config_acc: dict): 138 | while remaining_args: 139 | arg = remaining_args.pop(0) 140 | if not arg.startswith("--"): 141 | raise ValueError(f"Invalid argument {arg}") 142 | 143 | arg = arg[2:] 144 | try: 145 | value = remaining_args[0] 146 | if value.startswith("--"): 147 | # moved on to next arg 148 | value = "True" 149 | else: 150 | # consumed value - remove from args 151 | remaining_args.pop(0) 152 | except IndexError: 153 | # end of args, assume it was a flag 154 | value = "True" 155 | value = _resolve_value(value) 156 | 157 | arg_hierarchy = arg.split(".") 158 | update_dict: dict[str, Any] = {} 159 | current_dict = update_dict 160 | for arg in arg_hierarchy[:-1]: 161 | current_dict[arg] = {} 162 | current_dict = current_dict[arg] 163 | current_dict[arg_hierarchy[-1]] = value 164 | _recursive_update(config_acc, update_dict) 165 | 166 | 167 | def dump_config(config: BaseModel, path: os.PathLike | str) -> None: 168 | """Dump the input Pydantic config to a YAML file.""" 169 | path = Path(path) 170 | if path.is_dir(): 171 | path /= "config.yaml" 172 | with path.open("w") as f: 173 | yaml.dump(config.model_dump(), f) 174 | 175 | 176 | def get_config_help_string(config_cls: type[BaseModel], indent: int = 0) -> str: 177 | s = ( 178 | textwrap.indent(f"{config_cls.__name__}:", " " * indent) + "\n" 179 | if indent == 0 180 | else "" 181 | ) 182 | 183 | indent += 1 184 | for key, value in config_cls.model_fields.items(): 185 | annot: Any = value.annotation 186 | # Removing the description printing for now, since it's just too verbose. 187 | # TODO: see if we can format it in a more readable way. 
188 | # desc = f" # {value.description}" if value.description else "" 189 | desc = "" 190 | 191 | if inspect.isclass(annot): 192 | if issubclass(annot, BaseModel): 193 | s += textwrap.indent(f"{key}:{desc}", " " * indent) + "\n" 194 | s += get_config_help_string(annot, indent) 195 | continue 196 | 197 | annot = annot.__name__ 198 | 199 | if value.is_required(): 200 | s += textwrap.indent(f"{key}: {annot}{desc}", " " * indent) + "\n" 201 | else: 202 | default = ( 203 | value.default_factory 204 | if value.default is PydanticUndefined 205 | else value.default 206 | ) 207 | s += ( 208 | textwrap.indent(f"{key}: {annot} = {default!r}{desc}", " " * indent) 209 | + "\n" 210 | ) 211 | 212 | return s 213 | 214 | 215 | DEFAULT_OUTPUT_LOG_NAME = "output.log" 216 | 217 | 218 | def set_up_output_dir( 219 | directory_path: str | os.PathLike, 220 | config: BaseModel | None = None, 221 | log_name: str | None = DEFAULT_OUTPUT_LOG_NAME, 222 | is_main_process: bool = True, 223 | remove_existing: bool = False, 224 | ) -> Path: 225 | if remove_existing and is_main_process: 226 | shutil.rmtree(directory_path, ignore_errors=True) 227 | directory_path = Path(directory_path) 228 | directory_path.mkdir(parents=True, exist_ok=True) 229 | 230 | if log_name: 231 | configure_logs(log_file=directory_path / log_name) 232 | 233 | if config is not None and is_main_process: 234 | dump_config(config, directory_path) 235 | 236 | return directory_path 237 | 238 | 239 | def _resolve_value(value: str) -> Any: 240 | if value.lower() == "true": 241 | return True 242 | if value.lower() == "false": 243 | return False 244 | 245 | with contextlib.suppress(ValueError): 246 | return int(value) 247 | with contextlib.suppress(ValueError): 248 | return float(value) 249 | 250 | if value == "None": 251 | return None 252 | 253 | return value 254 | 255 | 256 | def configure_yaml_multiline() -> None: 257 | # copied from SWE-agent 258 | def multiline_representer(dumper, data): 259 | """Configures yaml for dumping multiline strings. 260 | 261 | Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data. 262 | """ 263 | if data.count("\n") > 0: # check for multiline string 264 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") 265 | return dumper.represent_scalar("tag:yaml.org,2002:str", data) 266 | 267 | yaml.add_representer(str, multiline_representer) 268 | 269 | 270 | def _recursive_update(d: dict, u: dict) -> dict: 271 | for k, v in u.items(): 272 | if isinstance(v, dict): 273 | d[k] = _recursive_update(d.get(k, {}), v) 274 | else: 275 | d[k] = v 276 | return d 277 | 278 | 279 | CONFIGURATION_ENABLE = {"1", "true", "yes", "on"} 280 | CONFIGURATION_DISABLE = {"0", "false", "no", "off"} 281 | -------------------------------------------------------------------------------- /src/fhda/prompts.py: -------------------------------------------------------------------------------- 1 | # System prompt for bioinformatics tasks 2 | CAPSULE_SYSTEM_PROMPT_HYPOTHESIS = """ 3 | You are an expert bioinformatician and seasoned biological data scientist tasked with 4 | creating a Jupyter notebook to analyze data relating to a hypothesis. Your goal is to 5 | validate whether the data provided supports the hypothesis or not. 6 | """ 7 | 8 | CAPSULE_SYSTEM_PROMPT_MCQ = """ 9 | You are an expert bioinformatician and seasoned biological data scientist. 
10 | Your task is to create a comprehensive Jupyter notebook named 'notebook.ipynb' that analyzes data to answer a series of Multiple Choice Questions (MCQs). 11 | The notebook should contain all necessary artifacts (plots, tables, print outputs) to fully answer these questions, structured in a way that another model could use to derive the answers. 12 | """ 13 | 14 | CAPSULE_SYSTEM_PROMPT_OPEN = """ 15 | You are an expert bioinformatician and seasoned biological data scientist. 16 | Your task is to create a comprehensive Jupyter notebook named 'notebook.ipynb' that analyzes data to answer a series of open-ended questions. 17 | The notebook should contain all necessary artifacts (plots, tables, print outputs) to fully answer these questions, structured in a way that another model could use to derive the answers. 18 | """ 19 | 20 | CAPSULE_SYSTEM_PROMPT_QUERY = """ 21 | You are an expert bioinformatician and seasoned biological data scientist. 22 | Your task is to create a comprehensive Jupyter notebook named 'notebook.ipynb' that analyzes data to answer a user query. 23 | The notebook should contain all necessary artifacts (plots, tables, print outputs) to fully answer these questions. 24 | Take your time to think through the question and the data before writing any code, explore the data rigorously and defend your conclusions rigorously. 25 | """ 26 | 27 | # Guidelines for R code output optimization 28 | R_SPECIFIC_GUIDELINES = """Guidelines for using the R programming language: 29 | 1. Load packages using this format to minimize verbose output: 30 | ```r 31 | if (!requireNamespace("package_name", quietly = TRUE)) {{ 32 | install.packages("package_name") 33 | }} 34 | suppressPackageStartupMessages(library(package_name)) 35 | ``` 36 | 2. You must use the tidyverse wherever possible: dplyr, tidyr, ggplot2, readr, stringr, forcats, purrr, tibble, and lubridate. 37 | 38 | 3. All plots must be made using ggplot2. Here is an example of how to make a plot: 39 | 40 | # Create a density scatter plot of FSC-A vs SSC-A 41 | plot_data <- as.data.frame(dmso_data[, c("FSC-A", "SSC-A")]) 42 | scatter_plot <- ggplot2::ggplot(plot_data, ggplot2::aes(x = `FSC-A`, y = `SSC-A`)) + 43 | ggplot2::geom_hex(bins = 100) + 44 | ggplot2::scale_fill_viridis_c(trans = "log10") + 45 | ggplot2::labs( 46 | title = "FSC-A vs SSC-A Density Plot (DMSO Control)", 47 | x = "FSC-A", 48 | y = "SSC-A" 49 | ) + 50 | ggplot2::theme_minimal() 51 | 52 | 3. Use explicit namespace qualification for functions. For example, use dplyr::select() instead of select(). 53 | 54 | 4. For data operations, suppress messages about column name repairs: 55 | ```r 56 | variable_name <- read_excel(".csv", col_names = FALSE, .name_repair = "minimal") 57 | ``` 58 | """ 59 | 60 | 61 | # General notebook guidelines 62 | GENERAL_NOTEBOOK_GUIDELINES = """ 63 | General Guidelines: 64 | - Write small to medium-sized cells for easier debugging. 65 | - Edit existing cells by their index number when fixing bugs, rather than creating new ones. 66 | - Check dataframe shapes before printing. Use head() for large dataframes. 67 | - Ensure each cell executes successfully before moving to the next. 68 | - Assume you already have the packages you need installed and only install new ones if you receive errors. 69 | - If you need to install packages, use pip or mamba. 70 | - All cells are by default {language} cells. Use {language} or bash tools for all analysis. 71 | - You can use bash cells by adding %%bash to the first line of the cell or running a subprocess. 
72 | - You can only create code cells, no markdown cells. 73 | """ 74 | 75 | 76 | AVOID_IMAGES = """ 77 | AVOID USING PLOTS/IMAGES. USE TABLES AND PRINT OUTPUTS INSTEAD AS MUCH AS POSSIBLE. 78 | """ 79 | 80 | BASH_TOOL_USAGE = """ 81 | If you need to use Busco, you can use it through udocker as follows: 82 | 83 | ```bash 84 | # BUSCO Guidelines: 85 | # 1. Set up the required directory structure for BUSCO: 86 | mkdir busco_downloads 87 | mkdir busco_downloads/lineages 88 | mv busco_downloads/lineages # Move your downloaded lineage database 89 | 90 | # 2. Run BUSCO analysis on protein files: 91 | for protein_file in *. ; do 92 | output_name=$(echo "$protein_file" | sed "s/.$/.busco/g") 93 | udocker --allow-root run -u $(id -u) \ 94 | -v /content/:/busco_wd \ 95 | ezlabgva/busco:v5.8.0_cv1 \ 96 | busco -i $protein_file \ 97 | -m prot \ 98 | --offline \ 99 | -o $output_name \ 100 | -l 101 | done 102 | 103 | # Note: Replace the following placeholders: 104 | # - : Your downloaded BUSCO lineage database directory 105 | # - : Your protein file extension (e.g., faa, fasta) 106 | # - : Name of the BUSCO lineage to use (e.g., eukaryota_odb10) 107 | ``` 108 | 109 | You can also use mafft, clipkit, fastqc, iqtree, metaeuk, perl, phykit through the command line. 110 | """ 111 | 112 | # Agnostic to MCQ vs hypothesis 113 | CHAIN_OF_THOUGHT_AGNOSTIC = """ 114 | Follow these steps to create your notebook, using chain-of-thought reasoning at each stage: 115 | 116 | 1. Load Data and Perform Descriptive Statistics: 117 | 118 | - Identify which data files are most relevant to resolving the task. 119 | - Plan how to load these files efficiently in {language}. 120 | - List the specific descriptive statistics you plan to use (e.g., summary(), str(), head()). 121 | - Consider potential issues like missing data or unexpected formats. How will you handle each? 122 | - Plan how to present this information clearly in the notebook. 123 | - Write down key statistics you expect to see and how you'll interpret them. 124 | - Consider potential data quality issues and how you'll address them. 125 | 126 | Execute your plan to load data and perform descriptive statistics. 127 | 128 | 2. Develop Analysis Plan: 129 | 130 | - Break down each task into testable components. List these components. 131 | - For each component, list appropriate statistical tests or visualizations. 132 | - Consider alternative approaches for each component and justify your choices. 133 | - Identify potential confounding factors and how to address them. 134 | - Plan the sequence of your analysis steps, explaining the rationale for each. 135 | - Consider how this analysis plan will be documented in the notebook. 136 | - List potential statistical assumptions for your chosen methods and how you'll test them. 137 | - Think about how your analysis plan addresses your original task. 138 | 139 | Write out your analysis plan as comments in the notebook. 140 | 141 | 3. Execute Analysis Plan: 142 | 143 | - For each step in your analysis plan, list the {language} or bash functions and libraries you'll use. 144 | - Think about how to structure your code for readability and efficiency. 145 | - Plan how to document your code with clear comments. 146 | - Consider how to present results clearly, using tables or visualizations where appropriate. 147 | - Ensure that all outputs are clearly labeled and explained in the context of the task. 148 | - Plan how you'll interpret each result in relation to the original task. 
149 | - Consider potential unexpected results and how you'll handle them. 150 | 151 | Execute your analysis plan, creating new cells as needed. 152 | 153 | 4. Conclude and Submit Answer: 154 | 155 | - Reflect on how your results relate to the original task. 156 | - Consider any limitations or uncertainties in your analysis. 157 | - Plan a concise summary of your findings. 158 | - Think about how to phrase your conclusion as clear statements. 159 | - Ensure that the notebook contains all necessary information for another model to derive these answers. 160 | - Consider any additional insights or patterns you've noticed during the analysis. 161 | - Think about potential follow-up questions or areas for further investigation. 162 | 163 | """ 164 | 165 | SUBMIT_ANSWER_HYPOTHESIS = """ 166 | [Use the submit_answer tool to submit your final answer as a single string either "True" or "False"] 167 | Remember, the final notebook should contain all necessary artifacts (plots, tables, print outputs) to solve the task provided. 168 | """ 169 | SUBMIT_ANSWER_SINGLE = """ 170 | [Use the submit_answer tool to submit your final answer as a single string] 171 | Example output: 172 | ``` 173 | submit_answer("CD94") or submit_answer("-1.23") 174 | ``` 175 | Remember, the final notebook should contain all necessary artifacts (plots, tables, print outputs) to solve the task provided. 176 | """ 177 | SUBMIT_ANSWER_OPEN = """ 178 | [Use the submit_answer tool to submit your final answer as a jsondictionary with keys as the question number and values as a short answer] 179 | Example output: 180 | ``` 181 | submit_answer({{ 182 | "q1": "Short answer to question 1", 183 | "q2": "Short answer to question 2", 184 | "q3": "Short answer to question 3", 185 | "q4": "Short answer to question 4" 186 | }}) 187 | ``` 188 | Remember, the final notebook should contain all necessary artifacts (plots, tables, print outputs) to solve the task provided. 189 | """ 190 | SUBMIT_ANSWER_MCQ = """ 191 | [Use the submit_answer tool to submit your final answer as a json dictionary with keys as the question number and values as the answer] 192 | Example output: 193 | ``` 194 | submit_answer({{ 195 | "q1": "A", 196 | "q2": "B", 197 | "q3": "C", 198 | "q4": "D" 199 | }}) 200 | Remember, the final notebook should contain all necessary artifacts (plots, tables, print outputs) to solve the task provided. 201 | """ 202 | 203 | HYPOTHESIS_PROMPT_TEMPLATE = f""" 204 | 205 | Here is the hypothesis you need to address: 206 | 207 | 208 | {{hypothesis}} 209 | 210 | 211 | {CHAIN_OF_THOUGHT_AGNOSTIC} 212 | {SUBMIT_ANSWER_HYPOTHESIS} 213 | {GENERAL_NOTEBOOK_GUIDELINES} 214 | {R_SPECIFIC_GUIDELINES} 215 | """ 216 | # MCQ 217 | MCQ_PROMPT_TEMPLATE = f""" 218 | Here are the questions you need to address: 219 | 220 | {{questions}} 221 | 222 | 223 | {CHAIN_OF_THOUGHT_AGNOSTIC} 224 | {SUBMIT_ANSWER_MCQ} 225 | {GENERAL_NOTEBOOK_GUIDELINES} 226 | {R_SPECIFIC_GUIDELINES} 227 | """ 228 | # Open answer 229 | OPEN_PROMPT_TEMPLATE = f""" 230 | Here are the questions you need to address: 231 | 232 | 233 | {{questions}} 234 | 235 | 236 | {CHAIN_OF_THOUGHT_AGNOSTIC} 237 | {SUBMIT_ANSWER_OPEN} 238 | {GENERAL_NOTEBOOK_GUIDELINES} 239 | {R_SPECIFIC_GUIDELINES} 240 | """ 241 | 242 | CONTINUATION_PROMPT_TEMPLATE = f""" 243 | {GENERAL_NOTEBOOK_GUIDELINES} 244 | 245 | You have been provided with a notebook previously generated by an agent based on a user's research question. 
246 | 247 | This was the user's research question: 248 | 249 | {{previous_research_question}} 250 | 251 | 252 | This was the final answer generated by the previous agent: 253 | 254 | {{previous_final_answer}} 255 | 256 | 257 | The user has now tasked you with addressing a new query: 258 | 259 | {{query}} 260 | 261 | 262 | Please make any edits required to the notebook and the answer to address the new query. Be extremely diligent and ensure that the notebook is fully updated to address the new query. 263 | Note you may have to run all cells one by one again if the user query involved updating one of the intermediate cells and subsequent cells depend on it. 264 | Once you have updated the notebook, use the submit_answer tool to submit your final answer once the user's query is addressed. 265 | """ 266 | -------------------------------------------------------------------------------- /src/fhda/storage.py: -------------------------------------------------------------------------------- 1 | """Module containing storage utilities for Google Cloud Platform (GCP) and Google Cloud Storage (GCS).""" 2 | 3 | import asyncio 4 | import base64 5 | import concurrent.futures 6 | import logging 7 | import os 8 | import re 9 | import shutil 10 | from typing import Self 11 | 12 | import aiofiles 13 | import google.api_core.exceptions 14 | import google.auth 15 | import httpx 16 | from google.cloud import secretmanager 17 | from google.cloud.storage import Client 18 | from google_crc32c import Checksum 19 | from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator 20 | from requests.adapters import HTTPAdapter 21 | from tqdm import tqdm 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | DEFAULT_BUCKET = "aviary-storage" 26 | DEFAULT_KEY = os.path.expanduser("~/.keys/aviary-storage-service.json") 27 | DEFAULT_STORAGE_PATH = os.path.expanduser("~/aviary_data/") 28 | DEFAULT_GCP_PROJECT_ID = "362315315966" # Corresponds to "paperqa" project 29 | MAX_THREADS = 100 30 | 31 | 32 | def validate_google_app_creds() -> None: 33 | """Validate we have a google application credential set. 34 | 35 | Priority order: 36 | 1. GOOGLE_APPLICATION_CREDENTIALS environment variable 37 | 2. Default key path 38 | 3. Fetch key from Secret Manager (and cache in default key path) 39 | """ 40 | if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ: 41 | # This code path is mostly meant for CI, which uses a GitHub secret, 42 | # not a key file. 43 | return 44 | 45 | if os.path.exists(DEFAULT_KEY): 46 | return 47 | 48 | logger.info("aviary-storage-service account key not found, attempting to fetch...") 49 | client = secretmanager.SecretManagerServiceClient() 50 | try: 51 | response = client.access_secret_version( 52 | request={ 53 | "name": f"projects/{DEFAULT_GCP_PROJECT_ID}/secrets/AVIARY-STORAGE-SERVICE-KEY/versions/latest" 54 | } 55 | ) 56 | except google.api_core.exceptions.RetryError as e: 57 | # Could use better error handling here, but it's a little confusing how they chain exceptions 58 | raise RuntimeError( 59 | "Failed to fetch 'aviary-storage-service' key from Secret Manager. " 60 | "Confirm that you are authenticated by running `gcloud auth application-default login`" 61 | ) from e 62 | 63 | payload = response.payload.data.decode("UTF-8") 64 | os.makedirs(os.path.dirname(DEFAULT_KEY), exist_ok=True) 65 | with open(DEFAULT_KEY, "w") as f: # noqa: FURB103 66 | f.write(payload) 67 | logger.info( 68 | f"Successfully stored aviary-storage-service account key in {DEFAULT_KEY}." 
69 | ) 70 | 71 | 72 | def auth_required(func): 73 | """Decorator to ensure that the user is authenticated with GCP before calling.""" 74 | 75 | def wrapper(*args, **kwargs): 76 | validate_google_app_creds() 77 | os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", DEFAULT_KEY) 78 | google.auth.default() # Check authentication 79 | return func(*args, **kwargs) 80 | 81 | return wrapper 82 | 83 | 84 | class DataRepo(BaseModel): 85 | model_config = ConfigDict(extra="forbid") 86 | 87 | name: str = Field( 88 | description=( 89 | "Subpath to the target directory within the cloud bucket `bucket`," 90 | " something like 'relative/path/to/sub/bucket'. Set to empty string to use" 91 | " the root of the bucket." 92 | ), 93 | ) 94 | 95 | local_path: str = Field( 96 | default="UNSET", 97 | description=( 98 | "Set to the target directory to mirror files. If left as the default of" 99 | " 'UNSET', it will be set to be /." 100 | ), 101 | ) 102 | bucket: str = Field( 103 | default=DEFAULT_BUCKET, 104 | description=( 105 | "Cloud bucket name like 'aviary-storage'. An analogy with a local" 106 | " filesystem is the drive (e.g. 'C:' on Windows)." 107 | ), 108 | ) 109 | 110 | validate_gcs_auth: bool = Field( 111 | default=True, 112 | description=( 113 | "Set True (default) to validate GCS authentication at construction time." 114 | ), 115 | ) 116 | 117 | def __bool__(self) -> bool: 118 | """Determines truthiness based on whether the name and local_path are set.""" 119 | return bool(self.name and self.local_path and self.local_path != "UNSET") 120 | 121 | @staticmethod 122 | def get_local_storage_path() -> str: 123 | return os.getenv("AVIARY_LOCAL_STORAGE", DEFAULT_STORAGE_PATH) 124 | 125 | @field_validator("name") 126 | @classmethod 127 | def _remove_slash(cls, value: str) -> str: 128 | return value.rstrip("/") 129 | 130 | @property 131 | def gcs_name(self) -> str: 132 | return f"{self.name}/" 133 | 134 | @model_validator(mode="after") 135 | def set_local_path(self) -> Self: 136 | if self.local_path == "UNSET": 137 | self.local_path = os.path.join(self.get_local_storage_path(), self.name) 138 | return self 139 | 140 | def mkdir(self, remove_existing: bool = False): 141 | if remove_existing: 142 | shutil.rmtree(self.local_path, ignore_errors=True) 143 | os.makedirs(self.local_path, exist_ok=True) 144 | 145 | @auth_required 146 | def push( 147 | self, 148 | overwrite: bool = False, 149 | include: re.Pattern | str | None = None, 150 | exclude: re.Pattern | str | None = None, 151 | progress: bool = False, 152 | ) -> None: 153 | logger.info(f"Pushing data repo: {self.name}") 154 | bucket = _get_gcs_client().get_bucket(self.bucket) 155 | 156 | include = _resolve_pattern(include) 157 | exclude = _resolve_pattern(exclude) 158 | 159 | # If overwrite is True, delete the contents of the bucket directory 160 | if overwrite: 161 | blobs = bucket.list_blobs(prefix=self.gcs_name) 162 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_THREADS) 163 | for blob in blobs: 164 | executor.submit(lambda b: b.delete(), blob) 165 | executor.shutdown(wait=True) 166 | 167 | def upload(local_path: str, blob_path: str): 168 | blob = bucket.blob(blob_path) 169 | 170 | # Check if the blob already exists and has the same hash 171 | if blob.exists(): 172 | blob.reload() # Ensure that the blob's metadata is up-to-date 173 | if blob.crc32c == compute_crc32c(local_path): 174 | pbar.update() 175 | return 176 | 177 | # Upload the file 178 | logger.debug(f"Pushing {local_path} to gcs://{blob_path}") 179 | 
blob.upload_from_filename(local_path) 180 | blob.patch() # Save metadata changes to GCS 181 | pbar.update() 182 | 183 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_THREADS) 184 | pbar = tqdm( 185 | disable=not progress, desc=f"Push [{self.name}]", unit="files", ncols=0 186 | ) 187 | 188 | # Walk through the local directory and upload each file 189 | count = 0 190 | for root, _, files in os.walk(self.local_path): 191 | for file in files: 192 | if file.endswith(".checksum"): 193 | continue 194 | 195 | local_path = os.path.join(root, file) 196 | blob_path = os.path.join( 197 | self.name, os.path.relpath(local_path, self.local_path) 198 | ) 199 | 200 | if not _passes_filters(include, exclude, local_path): 201 | continue 202 | 203 | executor.submit(upload, local_path, blob_path) 204 | count += 1 205 | 206 | pbar.total = count 207 | executor.shutdown(wait=True) 208 | pbar.close() 209 | 210 | @auth_required 211 | def pull( 212 | self, 213 | overwrite: bool = False, 214 | include: re.Pattern | str | None = None, 215 | exclude: re.Pattern | str | None = None, 216 | progress: bool = False, 217 | ): 218 | logger.info(f"Pulling data repo: {self.name}") 219 | bucket = _get_gcs_client().get_bucket(self.bucket) 220 | 221 | include = _resolve_pattern(include) 222 | exclude = _resolve_pattern(exclude) 223 | 224 | # If overwrite is True, delete the contents of the local directory 225 | if overwrite: 226 | shutil.rmtree(self.local_path) 227 | self.mkdir() 228 | 229 | def download(blob, local_path: str): 230 | blob.reload() 231 | if os.path.exists(local_path) and blob.crc32c == compute_crc32c(local_path): 232 | # print(f"Skipping {local_path}; no changes detected.") 233 | pbar.update() 234 | return 235 | 236 | local_dir_path = os.path.dirname(local_path) 237 | if not os.path.exists(local_dir_path): 238 | os.makedirs(local_dir_path) 239 | 240 | logger.debug(f"Pulling gcs://{blob.name} to {local_path}") 241 | blob.download_to_filename(local_path) 242 | with open(f"{local_path}.checksum", "w") as f: # noqa: FURB103 243 | f.write(blob.crc32c) 244 | pbar.update() 245 | 246 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_THREADS) 247 | pbar = tqdm( 248 | disable=not progress, desc=f"Pull [{self.name}]", unit=" files", ncols=0 249 | ) 250 | 251 | # Walk through the bucket directory and download each file 252 | blobs = bucket.list_blobs(prefix=self.gcs_name) 253 | count = 0 254 | n_name = len(self.gcs_name) 255 | for blob in blobs: 256 | local_path = os.path.join(self.local_path, blob.name[n_name:]) 257 | if local_path.endswith(".checksum"): 258 | # ??? 
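                # push() never uploads the local *.checksum sidecar files written
                # by download(), so blobs with this suffix are not expected
                # remotely; skip them defensively if they ever appear.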
259 | continue 260 | 261 | if not _passes_filters(include, exclude, local_path): 262 | continue 263 | 264 | executor.submit(download, blob, local_path) 265 | count += 1 266 | 267 | pbar.total = count 268 | executor.shutdown(wait=True) 269 | pbar.close() 270 | 271 | @auth_required 272 | def remote_exists(self) -> bool: 273 | bucket = Client().get_bucket(self.bucket) 274 | return any(True for _ in bucket.list_blobs(prefix=self.gcs_name)) 275 | 276 | @model_validator(mode="after") 277 | def check_auth(self) -> Self: 278 | if self.validate_gcs_auth: 279 | self.remote_exists() # Validate we can connect to GCS 280 | return self 281 | 282 | def local_exists(self) -> bool: 283 | return os.path.exists(self.local_path) 284 | 285 | 286 | def compute_crc32c(path: str): 287 | checksum_path = f"{path}.checksum" 288 | if os.path.exists(checksum_path) and os.path.getmtime( 289 | checksum_path 290 | ) > os.path.getmtime(path): 291 | with open(checksum_path) as f: # noqa: FURB101 292 | return f.read() 293 | else: 294 | if os.path.getsize(path) > 500 * 1024 * 1024: 295 | logger.info(f"Computing checksum of {path}...") 296 | with open(path, "rb") as f: # noqa: FURB101 297 | data = f.read() 298 | crc32c = Checksum() 299 | crc32c.update(data) 300 | checksum = base64.b64encode(crc32c.digest()).decode("utf-8") 301 | with open(checksum_path, "w") as f: # noqa: FURB103 302 | f.write(checksum) 303 | return checksum 304 | 305 | 306 | def _resolve_pattern(pat: str | re.Pattern | None) -> re.Pattern | None: 307 | if isinstance(pat, str): 308 | try: 309 | pat = re.compile(pat) 310 | except re.error as e: 311 | raise ValueError(f'Invalid regex pattern "{pat}"') from e 312 | return pat 313 | 314 | 315 | def _passes_filters( 316 | include: re.Pattern | None, exclude: re.Pattern | None, string: str 317 | ) -> bool: 318 | if include is not None and not include.match(string): 319 | return False 320 | return not (exclude is not None and exclude.match(string)) 321 | 322 | 323 | def _get_gcs_client() -> Client: 324 | # patch in a HTTPAdapter with a larger pool size 325 | # from https://stackoverflow.com/a/77740153 326 | client = Client() 327 | adapter = HTTPAdapter(pool_connections=MAX_THREADS, pool_maxsize=MAX_THREADS) 328 | client._http.mount("https://", adapter) 329 | client._http._auth_request.session.mount("https://", adapter) 330 | return client 331 | 332 | 333 | async def download_file( 334 | client: httpx.AsyncClient, 335 | download_url: str, 336 | local_path: str | os.PathLike, 337 | file_name: str, 338 | headers: dict[str, str], 339 | timeout: float | None, 340 | ) -> None: 341 | """Download a single file. 342 | 343 | Args: 344 | client: httpx.AsyncClient 345 | download_url: URL to download. 346 | local_path: Local path to download file. 347 | file_name: Name of file to download. 348 | headers: Dictionary of headers. 349 | timeout: Timeout. 350 | 351 | """ 352 | response = await client.get(download_url, headers=headers, timeout=timeout) 353 | response.raise_for_status() 354 | 355 | file_path = os.path.join(local_path, file_name) 356 | async with aiofiles.open(file_path, "wb") as f: 357 | await f.write(response.content) 358 | print(f"Downloaded {file_path}") 359 | 360 | 361 | async def download_github_subdirectory( 362 | client: httpx.AsyncClient, 363 | repo_owner: str, 364 | repo_name: str, 365 | branch: str, 366 | subdirectory: str, 367 | local_path: str | os.PathLike, 368 | timeout: float | None, 369 | ) -> None: 370 | """Download a specific subdirectory from a GitHub repository. 
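
    Files and nested directories are fetched concurrently (via asyncio.gather)
    using the GitHub contents API; a GITHUB_TOKEN environment variable must be
    set to build the Authorization header.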
371 | 372 | Args: 373 | client: httpx.AsyncClient 374 | repo_owner: GitHub repository owner. 375 | repo_name: GitHub repository name. 376 | branch: GitHub branch. 377 | subdirectory: Subdirectory to download. 378 | local_path: Local path to download to. 379 | timeout: Timeout. 380 | """ 381 | # Headers with the API version and authentication (optional) 382 | headers = { 383 | "Accept": "application/vnd.github.v3+json", 384 | "Authorization": "token " + os.environ["GITHUB_TOKEN"], 385 | } 386 | api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents/{subdirectory}?ref={branch}" 387 | response = await client.get(api_url, headers=headers, timeout=timeout) 388 | response.raise_for_status() # Check for HTTP errors 389 | items = response.json() 390 | 391 | if not os.path.exists(local_path): 392 | os.makedirs(local_path) 393 | 394 | coroutines = [] 395 | for item in items: 396 | if item["type"] == "file": 397 | coroutines.append( 398 | download_file( 399 | client, 400 | item["download_url"], 401 | local_path, 402 | item["name"], 403 | headers, 404 | timeout, 405 | ) 406 | ) 407 | elif item["type"] == "dir": 408 | new_subdir = os.path.join(subdirectory, item["name"]) 409 | new_local_path = os.path.join(local_path, item["name"]) 410 | coroutines.append( 411 | download_github_subdirectory( 412 | client, 413 | repo_owner, 414 | repo_name, 415 | branch, 416 | new_subdir, 417 | new_local_path, 418 | timeout, 419 | ) 420 | ) 421 | 422 | await asyncio.gather(*coroutines) 423 | -------------------------------------------------------------------------------- /src/fhda/templates/base/cell_id_anchor.j2: -------------------------------------------------------------------------------- 1 | {%- macro cell_id_anchor(cell) -%} 2 | {% if cell.id | length > 0 -%} 3 | id="{{ ('cell-id=' ~ cell.id) | escape_html -}}" 4 | {%- endif %} 5 | {%- endmacro %} 6 | -------------------------------------------------------------------------------- /src/fhda/templates/base/celltags.j2: -------------------------------------------------------------------------------- 1 | {%- macro celltags(cell) -%} 2 | {% if cell.metadata.tags | length > 0 -%} 3 | {% for tag in (cell.metadata.tags) -%} 4 | {{ (' celltag_' ~ tag) | escape_html -}} 5 | {%- endfor -%} 6 | {%- endif %} 7 | {%- endmacro %} 8 | -------------------------------------------------------------------------------- /src/fhda/templates/base/display_priority.j2: -------------------------------------------------------------------------------- 1 | {%- extends 'base/null.j2' -%} 2 | 3 | {#display data priority#} 4 | 5 | 6 | {%- block data_priority scoped -%} 7 | {%- for type in output.data | filter_data_type -%} 8 | {%- if type == 'application/pdf' -%} 9 | {%- block data_pdf -%} 10 | {%- endblock -%} 11 | {%- elif type == 'image/svg+xml' -%} 12 | {%- block data_svg -%} 13 | {%- endblock -%} 14 | {%- elif type == 'image/png' -%} 15 | {%- block data_png -%} 16 | {%- endblock -%} 17 | {%- elif type == 'text/html' -%} 18 | {%- block data_html -%} 19 | {%- endblock -%} 20 | {%- elif type == 'text/markdown' -%} 21 | {%- block data_markdown -%} 22 | {%- endblock -%} 23 | {%- elif type == 'image/jpeg' -%} 24 | {%- block data_jpg -%} 25 | {%- endblock -%} 26 | {%- elif type == 'text/plain' -%} 27 | {%- block data_text -%} 28 | {%- endblock -%} 29 | {%- elif type == 'text/latex' -%} 30 | {%- block data_latex -%} 31 | {%- endblock -%} 32 | {%- elif type == 'text/vnd.mermaid' -%} 33 | {%- block data_mermaid -%} 34 | {%- endblock -%} 35 | {%- elif type == 
'application/javascript' -%} 36 | {%- block data_javascript -%} 37 | {%- endblock -%} 38 | {%- elif type == 'application/vnd.jupyter.widget-view+json' -%} 39 | {%- block data_widget_view -%} 40 | {%- endblock -%} 41 | {%- elif type == resources.output_mimetype -%} 42 | {%- block data_native -%} 43 | {%- endblock -%} 44 | {%- else -%} 45 | {%- block data_other -%} 46 | {%- endblock -%} 47 | {%- endif -%} 48 | {%- endfor -%} 49 | {%- endblock data_priority -%} 50 | -------------------------------------------------------------------------------- /src/fhda/templates/base/jupyter_widgets.html.j2: -------------------------------------------------------------------------------- 1 | {%- macro jupyter_widgets(widgets_cdn_url, html_manager_semver_range, widget_renderer_url='') -%} 2 | 3 | 35 | 36 | {%- endmacro %} 37 | -------------------------------------------------------------------------------- /src/fhda/templates/base/mathjax.html.j2: -------------------------------------------------------------------------------- 1 | 2 | {%- macro mathjax(url="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML-full,Safe") -%} 3 | 4 | 5 | 6 | 37 | 38 | {%- endmacro %} 39 | -------------------------------------------------------------------------------- /src/fhda/templates/base/null.j2: -------------------------------------------------------------------------------- 1 | {# 2 | 3 | DO NOT USE THIS AS A BASE, 4 | IF YOU ARE COPY AND PASTING THIS FILE 5 | YOU ARE PROBABLY DOING THINGS INCORRECTLY. 6 | 7 | Null template, does nothing except defining a basic structure 8 | To layout the different blocks of a notebook. 9 | 10 | Subtemplates can override blocks to define their custom representation. 11 | 12 | If one of the block you do overwrite is not a leaf block, consider 13 | calling super. 14 | 15 | {%- block nonLeafBlock -%} 16 | #add stuff at beginning 17 | {{ super() }} 18 | #add stuff at end 19 | {%- endblock nonLeafBlock -%} 20 | 21 | consider calling super even if it is a leaf block, we might insert more blocks later. 
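
    For instance (illustrative only), a minimal exporter template built on this
    one could look like:

        {%- extends 'base/null.j2' -%}
        {%- block input -%}{{ cell.source }}{%- endblock input -%}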
22 | 23 | #} 24 | {%- block header -%} 25 | {%- endblock header -%} 26 | {%- block body -%} 27 | {%- block body_header -%} 28 | {%- endblock body_header -%} 29 | {%- block body_loop -%} 30 | {%- for cell in nb.cells -%} 31 | {%- block any_cell scoped -%} 32 | {%- if cell.cell_type == 'code'-%} 33 | {%- if resources.global_content_filter.include_code -%} 34 | {%- block codecell scoped -%} 35 | {%- if resources.global_content_filter.include_input and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 36 | {%- block input_group -%} 37 | {%- if resources.global_content_filter.include_input_prompt -%} 38 | {%- block in_prompt -%}{%- endblock in_prompt -%} 39 | {%- endif -%} 40 | {%- block input -%}{%- endblock input -%} 41 | {%- endblock input_group -%} 42 | {%- endif -%} 43 | {%- if cell.outputs and resources.global_content_filter.include_output -%} 44 | {%- block output_group -%} 45 | {%- if resources.global_content_filter.include_output_prompt -%} 46 | {%- block output_prompt -%}{%- endblock output_prompt -%} 47 | {%- endif -%} 48 | {%- block outputs scoped -%} 49 | {%- for output in cell.outputs -%} 50 | {%- block output scoped -%} 51 | {%- if output.output_type == 'execute_result' -%} 52 | {%- block execute_result scoped -%}{%- endblock execute_result -%} 53 | {%- elif output.output_type == 'stream' -%} 54 | {%- block stream scoped -%} 55 | {%- if output.name == 'stdout' -%} 56 | {%- block stream_stdout scoped -%} 57 | {%- endblock stream_stdout -%} 58 | {%- elif output.name == 'stderr' -%} 59 | {%- block stream_stderr scoped -%} 60 | {%- endblock stream_stderr -%} 61 | {%- elif output.name == 'stdin' -%} 62 | {%- block stream_stdin scoped -%} 63 | {%- endblock stream_stdin -%} 64 | {%- endif -%} 65 | {%- endblock stream -%} 66 | {%- elif output.output_type == 'display_data' -%} 67 | {%- block display_data scoped -%} 68 | {%- block data_priority scoped -%} 69 | {%- endblock data_priority -%} 70 | {%- endblock display_data -%} 71 | {%- elif output.output_type == 'error' -%} 72 | {%- block error scoped -%} 73 | {%- for line in output.traceback -%} 74 | {%- block traceback_line scoped -%}{%- endblock traceback_line -%} 75 | {%- endfor -%} 76 | {%- endblock error -%} 77 | {%- endif -%} 78 | {%- endblock output -%} 79 | {%- endfor -%} 80 | {%- endblock outputs -%} 81 | {%- endblock output_group -%} 82 | {%- endif -%} 83 | {%- endblock codecell -%} 84 | {%- endif -%} 85 | {%- elif cell.cell_type in ['markdown'] -%} 86 | {%- if resources.global_content_filter.include_markdown and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 87 | {%- block markdowncell scoped-%} {%- endblock markdowncell -%} 88 | {%- endif -%} 89 | {%- elif cell.cell_type in ['raw'] -%} 90 | {%- if resources.global_content_filter.include_raw and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 91 | {%- block rawcell scoped -%} 92 | {%- if cell.metadata.get('raw_mimetype', '').lower() in resources.get('raw_mimetypes', ['']) -%} 93 | {{ cell.source }} 94 | {%- endif -%} 95 | {%- endblock rawcell -%} 96 | {%- endif -%} 97 | {%- else -%} 98 | {%- if resources.global_content_filter.include_unknown and not cell.metadata.get("transient",{}).get("remove_source", false) -%} 99 | {%- block unknowncell scoped-%} 100 | {%- endblock unknowncell -%} 101 | {%- endif -%} 102 | {%- endif -%} 103 | {%- endblock any_cell -%} 104 | {%- endfor -%} 105 | {%- endblock body_loop -%} 106 | {%- block body_footer -%} 107 | {%- endblock body_footer -%} 108 | {%- endblock body -%} 109 | 110 | 
{%- block footer -%} 111 | {%- endblock footer -%} 112 | -------------------------------------------------------------------------------- /src/fhda/templates/lab/base.html.j2: -------------------------------------------------------------------------------- 1 | {%- extends 'display_priority.j2' -%} 2 | {% from 'celltags.j2' import celltags %} 3 | {% from 'cell_id_anchor.j2' import cell_id_anchor %} 4 | 5 | {% block codecell %} 6 | {%- if not cell.outputs -%} 7 | {%- set no_output_class="jp-mod-noOutputs" -%} 8 | {%- endif -%} 9 | {%- if not resources.global_content_filter.include_input -%} 10 | {%- set no_input_class="jp-mod-noInput" -%} 11 | {%- endif -%} 12 | 15 | {%- endblock codecell %} 16 | 17 | {% block input_group -%} 18 | 25 | {% endblock input_group %} 26 | 27 | {% block input %} 28 | 33 | {%- endblock input %} 34 | 35 | {% block output_group %} 36 | 41 | {% endblock output_group %} 42 | 43 | {% block outputs %} 44 | 47 | {% endblock outputs %} 48 | 49 | {% block in_prompt -%} 50 | 57 | {%- endblock in_prompt %} 58 | 59 | {% block empty_in_prompt -%} 60 | 62 | {%- endblock empty_in_prompt %} 63 | 64 | {# 65 | output_prompt doesn't do anything in HTML, 66 | because there is a prompt div in each output area (see output block) 67 | #} 68 | {% block output_prompt %} 69 | {% endblock output_prompt %} 70 | 71 | {% block output_area_prompt %} 72 | 81 | {% endblock output_area_prompt %} 82 | 83 | {% block output %} 84 | {%- if output.output_type == 'execute_result' -%} 85 |