├── .gitignore ├── LICENSE ├── README.md ├── ai_scientist ├── __init__.py ├── blank_icbinb_latex │ ├── fancyhdr.sty │ ├── iclr2025.bib │ ├── iclr2025.bst │ ├── iclr2025.sty │ ├── math_commands.tex │ ├── natbib.sty │ └── template.tex ├── blank_icml_latex │ ├── algorithm.sty │ ├── algorithmic.sty │ ├── icml2025.bst │ ├── icml2025.sty │ └── template.tex ├── fewshot_examples │ ├── 132_automated_relational.json │ ├── 132_automated_relational.pdf │ ├── 132_automated_relational.txt │ ├── 2_carpe_diem.json │ ├── 2_carpe_diem.pdf │ ├── 2_carpe_diem.txt │ ├── attention.json │ ├── attention.pdf │ └── attention.txt ├── ideas │ ├── i_cant_believe_its_not_better.json │ ├── i_cant_believe_its_not_better.md │ ├── i_cant_believe_its_not_better.py │ ├── i_cant_believe_its_not_betterrealworld.json │ └── i_cant_believe_its_not_betterrealworld.py ├── llm.py ├── perform_icbinb_writeup.py ├── perform_ideation_temp_free.py ├── perform_llm_review.py ├── perform_plotting.py ├── perform_vlm_review.py ├── perform_writeup.py ├── tools │ ├── __init__.py │ ├── base_tool.py │ └── semantic_scholar.py ├── treesearch │ ├── __init__.py │ ├── agent_manager.py │ ├── backend │ │ ├── __init__.py │ │ ├── backend_anthropic.py │ │ ├── backend_openai.py │ │ └── utils.py │ ├── bfts_utils.py │ ├── interpreter.py │ ├── journal.py │ ├── journal2report.py │ ├── log_summarization.py │ ├── parallel_agent.py │ ├── perform_experiments_bfts_with_agentmanager.py │ └── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── data_preview.py │ │ ├── metric.py │ │ ├── response.py │ │ ├── serialize.py │ │ ├── tree_export.py │ │ └── viz_templates │ │ ├── template.html │ │ └── template.js ├── utils │ └── token_tracker.py └── vlm.py ├── bfts_config.yaml ├── docs └── logo_v1.png ├── launch_scientist_bfts.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | .venv_jax 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | .idea/ 164 | .aider* 165 | *.DS_Store 166 | 167 | # VS Code 168 | .vscode/ 169 | 170 | # Misc folders 171 | data/ 172 | *ckpt.pt 173 | *.zip 174 | templates/*/run_0/ 175 | templates/*/*.png 176 | results/ 177 | # logs/ 178 | 179 | # data folders under review_iclr_bench/ 180 | review_iclr_bench/iclr_papers/ 181 | review_iclr_bench/iclr_parsed_w_img/ 182 | review_iclr_bench/generated_paper/ 183 | review_iclr_bench/iclr_parsed/ 184 | 185 | # Experiment logs 186 | experiments/ 187 | aisci_outputs/ 188 | slurm_logs/ 189 | .out 190 | 191 | # final papers 192 | final_papers/ 193 | 194 | # HF 195 | cache/ 196 | huggingface/ 197 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Rémi Louf 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | AI Scientist v2 Logo 4 | 5 |

6 | The AI Scientist-v2: Workshop-Level Automated
7 | Scientific Discovery via Agentic Tree Search 8 |

9 |
10 | 11 |

12 | 📚 [Paper] | 13 | 📝 [Blog Post] | 14 | 📂 [ICLR2025 Workshop Experiment] 15 |

16 | 17 | Fully autonomous scientific research systems are becoming increasingly capable, with AI playing a pivotal role in transforming how scientific discoveries are made. 18 | We are excited to introduce The AI Scientist-v2, a generalized end-to-end agentic system that has generated the first workshop paper written entirely by AI and accepted through peer review. 19 | 20 | This system autonomously generates hypotheses, runs experiments, analyzes data, and writes scientific manuscripts. Unlike [its predecessor (AI Scientist-v1)](https://github.com/SakanaAI/AI-Scientist), the AI Scientist-v2 removes reliance on human-authored templates, generalizes across Machine Learning (ML) domains, and employs a progressive agentic tree search, guided by an experiment manager agent. 21 | 22 | > **Note:** 23 | > The AI Scientist-v2 doesn’t necessarily produce better papers than v1, especially when a strong starting template is available. v1 follows well-defined templates, leading to high success rates, while v2 takes a broader, more exploratory approach with lower success rates. v1 works best for tasks with clear objectives and a solid foundation, whereas v2 is designed for open-ended scientific exploration. 24 | 25 | > **Caution!** 26 | > This codebase will execute Large Language Model (LLM)-written code. There are various risks and challenges associated with this autonomy, including the potential use of dangerous packages, uncontrolled web access, and the possibility of spawning unintended processes. Ensure that you run this within a controlled sandbox environment (e.g., a Docker container). Use at your own discretion. 27 | 28 | ## Table of Contents 29 | 30 | 1. [Requirements](#requirements) 31 | * [Installation](#installation) 32 | * [Supported Models and API Keys](#supported-models-and-api-keys) 33 | 2. [Generate Research Ideas](#generate-research-ideas) 34 | 3. [Run AI Scientist-v2 Paper Generation Experiments](#run-ai-scientist-v2-paper-generation-experiments) 35 | 4. [Citing The AI Scientist-v2](#citing-the-ai-scientist-v2) 36 | 5. [Frequently Asked Questions](#frequently-asked-questions) 37 | 6. [Acknowledgement](#acknowledgement) 38 | 39 | ## Requirements 40 | 41 | This code is designed to run on Linux with NVIDIA GPUs using CUDA and PyTorch. 42 | 43 | ### Installation 44 | 45 | ```bash 46 | # Create a new conda environment 47 | conda create -n ai_scientist python=3.11 48 | conda activate ai_scientist 49 | 50 | # Install PyTorch with CUDA support (adjust pytorch-cuda version for your setup) 51 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia 52 | 53 | # Install PDF and LaTeX tools 54 | conda install anaconda::poppler 55 | conda install conda-forge::chktex 56 | 57 | # Install Python package requirements 58 | pip install -r requirements.txt 59 | ``` 60 | 61 | ### Supported Models and API Keys 62 | 63 | #### OpenAI Models 64 | 65 | By default, the system uses the `OPENAI_API_KEY` environment variable for OpenAI models. 66 | 67 | #### Claude Models via AWS Bedrock 68 | 69 | To use Claude models provided by Amazon Bedrock, install the necessary additional packages: 70 | ```bash 71 | pip install anthropic[bedrock] 72 | ``` 73 | Next, configure valid [AWS Credentials](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html) and the target [AWS Region](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-regions.html) by setting the following environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION_NAME`. 
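To sanity-check your Bedrock configuration before launching a full run, you can make one small request with the `anthropic` SDK. The snippet below is a minimal sketch and is not part of this repository: it assumes the `anthropic[bedrock]` extra is installed and the AWS environment variables above are set, and the model ID shown is only an example (use whichever Claude model is enabled for your AWS account).

```python
import os

from anthropic import AnthropicBedrock

# Region is bridged from the AWS_REGION_NAME variable used in this README;
# access key and secret are picked up from the standard AWS environment variables.
client = AnthropicBedrock(aws_region=os.environ.get("AWS_REGION_NAME", "us-east-1"))

# Example Bedrock model ID -- replace with a Claude model enabled for your account.
response = client.messages.create(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    max_tokens=32,
    messages=[{"role": "user", "content": "Reply with the single word: ok"}],
)
print(response.content[0].text)
```

If this call fails with an authorization or region error, re-check your AWS credentials and the Bedrock model access settings for your account before starting an experiment.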
74 | 75 | #### Semantic Scholar API (Literature Search) 76 | 77 | Our code can optionally use a Semantic Scholar API Key (`S2_API_KEY`) for higher throughput during literature search [if you have one](https://www.semanticscholar.org/product/api). This is used during both the ideation and paper writing stages. The system should work without it, though you might encounter rate limits or reduced novelty checking during ideation. If you experience issues with Semantic Scholar, you can skip the citation phase during paper generation. 78 | 79 | #### Setting API Keys 80 | 81 | Ensure you provide the necessary API keys as environment variables for the models you intend to use. For example: 82 | ```bash 83 | export OPENAI_API_KEY="YOUR_OPENAI_KEY_HERE" 84 | export S2_API_KEY="YOUR_S2_KEY_HERE" 85 | # Set AWS credentials if using Bedrock 86 | # export AWS_ACCESS_KEY_ID="YOUR_AWS_ACCESS_KEY_ID" 87 | # export AWS_SECRET_ACCESS_KEY="YOUR_AWS_SECRET_KEY" 88 | # export AWS_REGION_NAME="your-aws-region" 89 | ``` 90 | 91 | ## Generate Research Ideas 92 | 93 | Before running the full AI Scientist-v2 experiment pipeline, you first use the `ai_scientist/perform_ideation_temp_free.py` script to generate potential research ideas. This script uses an LLM to brainstorm and refine ideas based on a high-level topic description you provide, interacting with tools like Semantic Scholar to check for novelty. 94 | 95 | 1. **Prepare a Topic Description:** Create a Markdown file (e.g., `my_research_topic.md`) describing the research area or theme you want the AI to explore. This file should contain sections like `Title`, `Keywords`, `TL;DR`, and `Abstract` to define the scope of the research. Refer to the example file `ai_scientist/ideas/i_cant_believe_its_not_better.md` for the expected structure and content format. Place your file in a location accessible by the script (e.g., the `ai_scientist/ideas/` directory). 96 | 97 | 2. **Run the Ideation Script:** Execute the script from the main project directory, pointing it to your topic description file and specifying the desired LLM. 98 | 99 | ```bash 100 | python ai_scientist/perform_ideation_temp_free.py \ 101 | --workshop-file "ai_scientist/ideas/my_research_topic.md" \ 102 | --model gpt-4o-2024-05-13 \ 103 | --max-num-generations 20 \ 104 | --num-reflections 5 105 | ``` 106 | * `--workshop-file`: Path to your topic description Markdown file. 107 | * `--model`: The LLM to use for generating ideas (ensure you have the corresponding API key set). 108 | * `--max-num-generations`: How many distinct research ideas to attempt generating. 109 | * `--num-reflections`: How many refinement steps the LLM should perform for each idea. 110 | 111 | 3. **Output:** The script will generate a JSON file named after your input Markdown file (e.g., `ai_scientist/ideas/my_research_topic.json`). This file will contain a list of structured research ideas, including hypotheses, proposed experiments, and related work analysis. 112 | 113 | 4. **Proceed to Experiments:** Once you have the generated JSON file containing research ideas, you can proceed to the next section to run the experiments. 114 | 115 | This ideation step guides the AI Scientist towards specific areas of interest and produces concrete research directions to be tested in the main experimental pipeline. 116 | 117 | ## Run AI Scientist-v2 Paper Generation Experiments 118 | 119 | Using the JSON file generated in the previous ideation step, you can now launch the main AI Scientist-v2 pipeline. 
This involves running experiments via agentic tree search, analyzing results, and generating a paper draft. 120 | 121 | Specify the models used for the write-up and review phases via command-line arguments. 122 | The configuration for the best-first tree search (BFTS) is located in `bfts_config.yaml`. Adjust parameters in this file as needed. 123 | 124 | Key tree search configuration parameters in `bfts_config.yaml`: 125 | 126 | - `agent` config: 127 | - Set `num_workers` (number of parallel exploration paths) and `steps` (maximum number of nodes to explore). For example, if `num_workers=3` and `steps=21`, the tree search will explore up to 21 nodes, expanding 3 nodes concurrently at each step. 128 | - `num_seeds`: Should generally be the same as `num_workers` if `num_workers` is less than 3. Otherwise, set `num_seeds` to 3. 129 | - Note: Other agent parameters like `k_fold_validation`, `expose_prediction`, and `data_preview` are not used in the current version. 130 | - `search` config: 131 | - `max_debug_depth`: The maximum number of times the agent will attempt to debug a failing node before abandoning that search path. 132 | - `debug_prob`: The probability of attempting to debug a failing node. 133 | - `num_drafts`: The number of initial root nodes (i.e., the number of independent trees to grow) during Stage 1. 134 | 135 | Example command to run AI-Scientist-v2 using a generated idea file (e.g., `my_research_topic.json`). Please review `bfts_config.yaml` for detailed tree search parameters (the default config includes `claude-3-5-sonnet` for experiments). Do not set `load_code` if you do not want to initialize experimentation with a code snippet. 136 | 137 | ```bash 138 | python launch_scientist_bfts.py \ 139 | --load_ideas "ai_scientist/ideas/my_research_topic.json" \ 140 | --load_code \ 141 | --add_dataset_ref \ 142 | --model_writeup o1-preview-2024-09-12 \ 143 | --model_citation gpt-4o-2024-11-20 \ 144 | --model_review gpt-4o-2024-11-20 \ 145 | --model_agg_plots o3-mini-2025-01-31 \ 146 | --num_cite_rounds 20 147 | ``` 148 | 149 | Once the initial experimental stage is complete, you will find a timestamped log folder inside the `experiments/` directory. Navigate to `experiments/"timestamp_ideaname"/logs/0-run/` within that folder to find the tree visualization file `unified_tree_viz.html`. 150 | 151 | ## Citing The AI Scientist-v2 152 | 153 | If you use **The AI Scientist-v2** in your research, please cite our work as follows: 154 | 155 | ```bibtex 156 | @article{aiscientist_v2, 157 | title={The AI Scientist-v2: Workshop-Level Automated Scientific Discovery via Agentic Tree Search}, 158 | author={Yamada, Yutaro and Lange, Robert Tjarko and Lu, Cong and Hu, Shengran and Lu, Chris and Foerster, Jakob and Clune, Jeff and Ha, David}, 159 | journal={arXiv preprint arXiv:2504.08066}, 160 | year={2025} 161 | } 162 | ``` 163 | 164 | ## Frequently Asked Questions 165 | 166 | **Why wasn't a PDF or a review generated for my experiment?** 167 | 168 | The AI Scientist-v2 completes experiments with a success rate that depends on the chosen foundation model, and the complexity of the idea. Higher success rates are generally observed when using powerful models like Claude 3.5 Sonnet for the experimentation phase. 169 | 170 | **What is the estimated cost per experiment?** 171 | 172 | The ideation step cost depends on the LLM used and the number of generations/reflections, but is generally low (a few dollars). 
For the main experiment pipeline, using Claude 3.5 Sonnet for the experimentation phase typically costs around $15–$20 per run. The subsequent writing phase adds approximately $5 when using the default models specified in the example command. Using GPT-4o for `model_citation` is recommended as it can help reduce writing costs. 173 | 174 | **How do I run The AI Scientist-v2 for different subject fields?** 175 | 176 | First, perform the [Generate Research Ideas](#generate-research-ideas) step. Create a new Markdown file describing your desired subject field or topic, following the structure of the example `ai_scientist/ideas/i_cant_believe_its_not_better.md`. Run the `perform_ideation_temp_free.py` script with this file to generate a corresponding JSON idea file. Then, proceed to the [Run AI Scientist-v2 Paper Generation Experiments](#run-ai-scientist-v2-paper-generation-experiments) step, using this JSON file with the `launch_scientist_bfts.py` script via the `--load_ideas` argument. 177 | 178 | **What should I do if I have problems accessing the Semantic Scholar API?** 179 | 180 | The Semantic Scholar API is used to assess the novelty of generated ideas and to gather citations during the paper write-up phase. If you don't have an API key or encounter rate limits, you may be able to skip these phases. 181 | 182 | **I encountered a "CUDA Out of Memory" error. What can I do?** 183 | 184 | This error typically occurs when the AI Scientist-v2 attempts to load or run a model that requires more GPU memory than is available on your system. To resolve this, you can try updating your ideation prompt file (`ai_scientist/ideas/my_research_topic.md`) to suggest using smaller models for the experiments. 185 | 186 | ## Acknowledgement 187 | 188 | The tree search component implemented within the `ai_scientist` directory is built on top of the [AIDE](https://github.com/WecoAI/aideml) project. We thank the AIDE developers for their valuable contributions and for making their work publicly available. 189 | 190 | 191 | ## Star History 192 | 193 | [![Star History Chart](https://api.star-history.com/svg?repos=SakanaAI/AI-Scientist-v2&type=Date)](https://star-history.com/#SakanaAI/AI-Scientist-v2&Date) 194 | 195 | -------------------------------------------------------------------------------- /ai_scientist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/ai_scientist/__init__.py -------------------------------------------------------------------------------- /ai_scientist/blank_icbinb_latex/iclr2025.bib: -------------------------------------------------------------------------------- 1 | @incollection{Bengio+chapter2007, 2 | author = {Bengio, Yoshua and LeCun, Yann}, 3 | booktitle = {Large Scale Kernel Machines}, 4 | publisher = {MIT Press}, 5 | title = {Scaling Learning Algorithms Towards {AI}}, 6 | year = {2007} 7 | } 8 | 9 | @article{Hinton06, 10 | author = {Hinton, Geoffrey E.
and Osindero, Simon and Teh, Yee Whye}, 11 | journal = {Neural Computation}, 12 | pages = {1527--1554}, 13 | title = {A Fast Learning Algorithm for Deep Belief Nets}, 14 | volume = {18}, 15 | year = {2006} 16 | } 17 | 18 | @book{goodfellow2016deep, 19 | title={Deep learning}, 20 | author={Goodfellow, Ian and Bengio, Yoshua and Courville, Aaron and Bengio, Yoshua}, 21 | volume={1}, 22 | year={2016}, 23 | publisher={MIT Press} 24 | } -------------------------------------------------------------------------------- /ai_scientist/blank_icbinb_latex/iclr2025.sty: -------------------------------------------------------------------------------- 1 | %%%% ICLR Macros (LaTex) 2 | %%%% Adapted by Hugo Larochelle from the NIPS stylefile Macros 3 | %%%% Style File 4 | %%%% Dec 12, 1990 Rev Aug 14, 1991; Sept, 1995; April, 1997; April, 1999; October 2014 5 | 6 | % This file can be used with Latex2e whether running in main mode, or 7 | % 2.09 compatibility mode. 8 | % 9 | % If using main mode, you need to include the commands 10 | % \documentclass{article} 11 | % \usepackage{iclr14submit_e,times} 12 | % 13 | 14 | % Change the overall width of the page. If these parameters are 15 | % changed, they will require corresponding changes in the 16 | % maketitle section. 17 | % 18 | \usepackage{eso-pic} % used by \AddToShipoutPicture 19 | \RequirePackage{fancyhdr} 20 | \RequirePackage{natbib} 21 | 22 | % modification to natbib citations 23 | \setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}} 24 | 25 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 26 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 27 | 28 | % Define iclrfinal, set to true if iclrfinalcopy is defined 29 | \newif\ificlrfinal 30 | \iclrfinalfalse 31 | \def\iclrfinalcopy{\iclrfinaltrue} 32 | \font\iclrtenhv = phvb at 8pt 33 | 34 | % Specify the dimensions of each page 35 | 36 | \setlength{\paperheight}{11in} 37 | \setlength{\paperwidth}{8.5in} 38 | 39 | 40 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 41 | \evensidemargin .5in 42 | \marginparwidth 0.07 true in 43 | %\marginparwidth 0.75 true in 44 | %\topmargin 0 true pt % Nominal distance from top of page to top of 45 | %\topmargin 0.125in 46 | \topmargin -0.625in 47 | \addtolength{\headsep}{0.25in} 48 | \textheight 9.0 true in % Height of text (including footnotes & figures) 49 | \textwidth 5.5 true in % Width of text line. 50 | \widowpenalty=10000 51 | \clubpenalty=10000 52 | 53 | % \thispagestyle{empty} \pagestyle{empty} 54 | \flushbottom \sloppy 55 | 56 | % We're never going to need a table of contents, so just flush it to 57 | % save space --- suggested by drstrip@sandia-2 58 | \def\addcontentsline#1#2#3{} 59 | 60 | % Title stuff, taken from deproc. 
61 | \def\maketitle{\par 62 | \begingroup 63 | \def\thefootnote{\fnsymbol{footnote}} 64 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 65 | % name centering 66 | % The footnote-mark was overlapping the footnote-text, 67 | % added the following to fix this problem (MK) 68 | \long\def\@makefntext##1{\parindent 1em\noindent 69 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 70 | \@maketitle \@thanks 71 | \endgroup 72 | \setcounter{footnote}{0} 73 | \let\maketitle\relax \let\@maketitle\relax 74 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 75 | 76 | % The toptitlebar has been raised to top-justify the first page 77 | 78 | \usepackage{fancyhdr} 79 | \pagestyle{fancy} 80 | \fancyhead{} 81 | 82 | % Title (includes both anonimized and non-anonimized versions) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | %\linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\sc \@title\par} 86 | %\bottomtitlebar % \vskip 0.1in % minus 87 | \ificlrfinal 88 | \lhead{I Can't Believe It's Not Better Workshop @ ICLR 2025} 89 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 90 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 91 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 92 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 93 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 94 | \else 95 | \lhead{Under review as a workshop paper at ICLR 2025} 96 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 97 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 98 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 99 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 100 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}Anonymous authors\\Paper under double-blind review\end{tabular}% 101 | \fi 102 | \vskip 0.3in minus 0.1in}} 103 | 104 | \renewenvironment{abstract}{\vskip.075in\centerline{\large\sc 105 | Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 106 | 107 | % sections with less space 108 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 109 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 110 | minus0.2ex}{\large\sc\raggedright}} 111 | 112 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 113 | -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\sc\raggedright}} 114 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 115 | plus -0.5ex minus -.2ex}{0.5ex plus 116 | .2ex}{\normalsize\sc\raggedright}} 117 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 118 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 119 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 120 | 0.5ex minus .2ex}{-1em}{\normalsize\sc}} 121 | \def\subsubsubsection{\vskip 122 | 5pt{\noindent\normalsize\rm\raggedright}} 123 | 124 | 125 | % Footnotes 126 | \footnotesep 6.65pt % 127 | \skip\footins 9pt plus 4pt minus 2pt 128 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 129 | \setcounter{footnote}{0} 130 | 131 | % Lists and paragraphs 132 | \parindent 0pt 133 | \topsep 4pt plus 1pt minus 2pt 134 | \partopsep 1pt plus 0.5pt minus 0.5pt 135 | \itemsep 2pt plus 1pt minus 0.5pt 136 | \parsep 2pt plus 1pt minus 0.5pt 137 | \parskip .5pc 138 | 139 | 140 | %\leftmargin2em 141 | \leftmargin3pc 142 | \leftmargini\leftmargin \leftmarginii 2em 143 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 144 | 145 | %\labelsep \labelsep 5pt 146 | 147 | \def\@listi{\leftmargin\leftmargini} 148 | \def\@listii{\leftmargin\leftmarginii 149 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 
150 | \topsep 2pt plus 1pt minus 0.5pt 151 | \parsep 1pt plus 0.5pt minus 0.5pt 152 | \itemsep \parsep} 153 | \def\@listiii{\leftmargin\leftmarginiii 154 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 155 | \topsep 1pt plus 0.5pt minus 0.5pt 156 | \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt 157 | \itemsep \topsep} 158 | \def\@listiv{\leftmargin\leftmarginiv 159 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 160 | \def\@listv{\leftmargin\leftmarginv 161 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 162 | \def\@listvi{\leftmargin\leftmarginvi 163 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 164 | 165 | \abovedisplayskip 7pt plus2pt minus5pt% 166 | \belowdisplayskip \abovedisplayskip 167 | \abovedisplayshortskip 0pt plus3pt% 168 | \belowdisplayshortskip 4pt plus3pt minus3pt% 169 | 170 | % Less leading in most fonts (due to the narrow columns) 171 | % The choices were between 1-pt and 1.5-pt leading 172 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 173 | \def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 174 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 175 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 176 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 177 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 178 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 179 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 180 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 181 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 182 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 183 | 184 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 185 | 186 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 187 | .09in} % 188 | %Reduced second vskip to compensate for adding the strut in \@author 189 | 190 | 191 | 192 | %% % Vertical Ruler 193 | %% % This code is, largely, from the CVPR 2010 conference style file 194 | %% % ----- define vruler 195 | \makeatletter 196 | \newbox\iclrrulerbox 197 | \newcount\iclrrulercount 198 | \newdimen\iclrruleroffset 199 | \newdimen\cv@lineheight 200 | \newdimen\cv@boxheight 201 | \newbox\cv@tmpbox 202 | \newcount\cv@refno 203 | \newcount\cv@tot 204 | % NUMBER with left flushed zeros \fillzeros[] 205 | \newcount\cv@tmpc@ \newcount\cv@tmpc 206 | \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 207 | \cv@tmpc=1 % 208 | \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 209 | \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 210 | \ifnum#2<0\advance\cv@tmpc1\relax-\fi 211 | \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 212 | \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 213 | % \makevruler[][][][][] 214 | \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 215 | \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 216 | \global\setbox\iclrrulerbox=\vbox to \textheight{% 217 | {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 218 | \cv@lineheight=#1\global\iclrrulercount=#2% 219 | \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 220 | \cv@refno1\vskip-\cv@lineheight\vskip1ex% 221 | \loop\setbox\cv@tmpbox=\hbox to0cm{{\iclrtenhv\hfil\fillzeros[#4]\iclrrulercount}}% 222 | \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 223 | \advance\cv@refno1\global\advance\iclrrulercount#3\relax 224 | \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 225 | \makeatother 226 | % 
----- end of vruler 227 | 228 | % \makevruler[][][][][] 229 | \def\iclrruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\iclrrulerbox}} 230 | \AddToShipoutPicture{% 231 | \ificlrfinal\else 232 | \iclrruleroffset=\textheight 233 | \advance\iclrruleroffset by -3.7pt 234 | \color[rgb]{.7,.7,.7} 235 | \AtTextUpperLeft{% 236 | \put(\LenToUnit{-35pt},\LenToUnit{-\iclrruleroffset}){%left ruler 237 | \iclrruler{\iclrrulercount}} 238 | } 239 | \fi 240 | } 241 | % %% To add a vertical bar on the side 242 | % \AddToShipoutPicture{ 243 | % \AtTextLowerLeft{ 244 | % \hspace*{-1.8cm} 245 | % \colorbox[rgb]{0.7,0.7,0.7}{\small \parbox[b][\textheight]{0.1cm}{}}} 246 | % } 247 | -------------------------------------------------------------------------------- /ai_scientist/blank_icbinb_latex/template.tex: -------------------------------------------------------------------------------- 1 | \documentclass{article} % For LaTeX2e 2 | \usepackage{iclr2025,times} 3 | 4 | % Optional math commands from https://github.com/goodfeli/dlbook_notation. 5 | \input{math_commands.tex} 6 | 7 | \usepackage{hyperref} 8 | \usepackage{url} 9 | \usepackage{graphicx} 10 | \usepackage{subfigure} 11 | \usepackage{booktabs} 12 | 13 | % For theorems and such 14 | \usepackage{amsmath} 15 | \usepackage{amssymb} 16 | \usepackage{mathtools} 17 | \usepackage{amsthm} 18 | 19 | % Custom 20 | \usepackage{multirow} 21 | \usepackage{color} 22 | \usepackage{colortbl} 23 | \usepackage[capitalize,noabbrev]{cleveref} 24 | \usepackage{xspace} 25 | 26 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 27 | % THEOREMS 28 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 29 | \theoremstyle{plain} 30 | \newtheorem{theorem}{Theorem}[section] 31 | \newtheorem{proposition}[theorem]{Proposition} 32 | \newtheorem{lemma}[theorem]{Lemma} 33 | \newtheorem{corollary}[theorem]{Corollary} 34 | \theoremstyle{definition} 35 | \newtheorem{definition}[theorem]{Definition} 36 | \newtheorem{assumption}[theorem]{Assumption} 37 | \theoremstyle{remark} 38 | \newtheorem{remark}[theorem]{Remark} 39 | 40 | \graphicspath{{../figures/}} % To reference your generated figures, name the PNGs directly. DO NOT CHANGE THIS. 41 | 42 | \begin{filecontents}{references.bib} 43 | @book{goodfellow2016deep, 44 | title={Deep learning}, 45 | author={Goodfellow, Ian and Bengio, Yoshua and Courville, Aaron and Bengio, Yoshua}, 46 | volume={1}, 47 | year={2016}, 48 | publisher={MIT Press} 49 | } 50 | \end{filecontents} 51 | 52 | \title{ 53 | %%%%%%%%%TITLE%%%%%%%%% 54 | TITLE HERE 55 | %%%%%%%%%TITLE%%%%%%%%% 56 | } 57 | 58 | % Authors must not appear in the submitted version. They should be hidden 59 | % as long as the \iclrfinalcopy macro remains commented out below. 60 | % Non-anonymous submissions will be rejected without review. 61 | 62 | \author{Anonymous} 63 | 64 | % The \author macro works with any number of authors. There are two commands 65 | % used to separate the names and addresses of multiple authors: \And and \AND. 66 | % 67 | % Using \And between authors leaves it to \LaTeX{} to determine where to break 68 | % the lines. Using \AND forces a linebreak at that point. So, if \LaTeX{} 69 | % puts 3 of 4 authors names on the first line, and the last on the second 70 | % line, try using \AND instead of \And before the third author name. 71 | 72 | \newcommand{\fix}{\marginpar{FIX}} 73 | \newcommand{\new}{\marginpar{NEW}} 74 | 75 | %\iclrfinalcopy % Uncomment for camera-ready version, but NOT for submission. 
76 | \begin{document} 77 | 78 | 79 | \maketitle 80 | 81 | \begin{abstract} 82 | %%%%%%%%%ABSTRACT%%%%%%%%% 83 | ABSTRACT HERE 84 | %%%%%%%%%ABSTRACT%%%%%%%%% 85 | \end{abstract} 86 | 87 | \section{Introduction} 88 | \label{sec:intro} 89 | %%%%%%%%%INTRODUCTION%%%%%%%%% 90 | INTRO HERE 91 | %%%%%%%%%INTRODUCTION%%%%%%%%% 92 | 93 | \section{Related Work} 94 | \label{sec:related} 95 | %%%%%%%%%RELATED WORK%%%%%%%%% 96 | RELATED WORK HERE 97 | %%%%%%%%%RELATED WORK%%%%%%%%% 98 | 99 | \section{Background} 100 | \label{sec:background} 101 | %%%%%%%%%BACKGROUND%%%%%%%%% 102 | BACKGROUND HERE 103 | %%%%%%%%%BACKGROUND%%%%%%%%% 104 | 105 | \section{Method} 106 | \label{sec:method} 107 | %%%%%%%%%METHOD%%%%%%%%% 108 | METHOD HERE 109 | %%%%%%%%%METHOD%%%%%%%%% 110 | 111 | \section{Experimental Setup} 112 | \label{sec:experimental_setup} 113 | %%%%%%%%%EXPERIMENTAL SETUP%%%%%%%%% 114 | EXPERIMENTAL SETUP HERE 115 | %%%%%%%%%EXPERIMENTAL SETUP%%%%%%%%% 116 | 117 | \section{Experiments} 118 | \label{sec:experiments} 119 | %%%%%%%%%EXPERIMENTS%%%%%%%%% 120 | RESULTS HERE 121 | 122 | % EXAMPLE FIGURE: REPLACE AND ADD YOUR OWN FIGURES / CAPTIONS 123 | \begin{figure}[h!] 124 | \centering 125 | \includegraphics[width=0.5\textwidth]{example-image-a} 126 | \caption{PLEASE FILL IN CAPTION HERE} 127 | \label{fig:first_figure} 128 | \end{figure} 129 | %%%%%%%%%EXPERIMENTS%%%%%%%%% 130 | 131 | \section{Conclusion} 132 | \label{sec:conclusion} 133 | %%%%%%%%%CONCLUSION%%%%%%%%% 134 | CONCLUSIONS HERE 135 | %%%%%%%%%CONCLUSION%%%%%%%%% 136 | 137 | \bibliography{iclr2025} 138 | \bibliographystyle{iclr2025} 139 | 140 | \appendix 141 | 142 | \section*{\LARGE Supplementary Material} 143 | \label{sec:appendix} 144 | 145 | %%%%%%%%%APPENDIX%%%%%%%%% 146 | \section{Appendix Section} 147 | APPENDIX TEXT 148 | %%%%%%%%%APPENDIX%%%%%%%%% 149 | 150 | \end{document} 151 | -------------------------------------------------------------------------------- /ai_scientist/blank_icml_latex/algorithm.sty: -------------------------------------------------------------------------------- 1 | % ALGORITHM STYLE -- Released 8 April 1996 2 | % for LaTeX-2e 3 | % Copyright -- 1994 Peter Williams 4 | % E-mail Peter.Williams@dsto.defence.gov.au 5 | \NeedsTeXFormat{LaTeX2e} 6 | \ProvidesPackage{algorithm} 7 | \typeout{Document Style `algorithm' - floating environment} 8 | 9 | \RequirePackage{float} 10 | \RequirePackage{ifthen} 11 | \newcommand{\ALG@within}{nothing} 12 | \newboolean{ALG@within} 13 | \setboolean{ALG@within}{false} 14 | \newcommand{\ALG@floatstyle}{ruled} 15 | \newcommand{\ALG@name}{Algorithm} 16 | \newcommand{\listalgorithmname}{List of \ALG@name s} 17 | 18 | % Declare Options 19 | % first appearance 20 | \DeclareOption{plain}{ 21 | \renewcommand{\ALG@floatstyle}{plain} 22 | } 23 | \DeclareOption{ruled}{ 24 | \renewcommand{\ALG@floatstyle}{ruled} 25 | } 26 | \DeclareOption{boxed}{ 27 | \renewcommand{\ALG@floatstyle}{boxed} 28 | } 29 | % then numbering convention 30 | \DeclareOption{part}{ 31 | \renewcommand{\ALG@within}{part} 32 | \setboolean{ALG@within}{true} 33 | } 34 | \DeclareOption{chapter}{ 35 | \renewcommand{\ALG@within}{chapter} 36 | \setboolean{ALG@within}{true} 37 | } 38 | \DeclareOption{section}{ 39 | \renewcommand{\ALG@within}{section} 40 | \setboolean{ALG@within}{true} 41 | } 42 | \DeclareOption{subsection}{ 43 | \renewcommand{\ALG@within}{subsection} 44 | \setboolean{ALG@within}{true} 45 | } 46 | \DeclareOption{subsubsection}{ 47 | \renewcommand{\ALG@within}{subsubsection} 48 | \setboolean{ALG@within}{true} 49 | } 50 
| \DeclareOption{nothing}{ 51 | \renewcommand{\ALG@within}{nothing} 52 | \setboolean{ALG@within}{true} 53 | } 54 | \DeclareOption*{\edef\ALG@name{\CurrentOption}} 55 | 56 | % ALGORITHM 57 | % 58 | \ProcessOptions 59 | \floatstyle{\ALG@floatstyle} 60 | \ifthenelse{\boolean{ALG@within}}{ 61 | \ifthenelse{\equal{\ALG@within}{part}} 62 | {\newfloat{algorithm}{htbp}{loa}[part]}{} 63 | \ifthenelse{\equal{\ALG@within}{chapter}} 64 | {\newfloat{algorithm}{htbp}{loa}[chapter]}{} 65 | \ifthenelse{\equal{\ALG@within}{section}} 66 | {\newfloat{algorithm}{htbp}{loa}[section]}{} 67 | \ifthenelse{\equal{\ALG@within}{subsection}} 68 | {\newfloat{algorithm}{htbp}{loa}[subsection]}{} 69 | \ifthenelse{\equal{\ALG@within}{subsubsection}} 70 | {\newfloat{algorithm}{htbp}{loa}[subsubsection]}{} 71 | \ifthenelse{\equal{\ALG@within}{nothing}} 72 | {\newfloat{algorithm}{htbp}{loa}}{} 73 | }{ 74 | \newfloat{algorithm}{htbp}{loa} 75 | } 76 | \floatname{algorithm}{\ALG@name} 77 | 78 | \newcommand{\listofalgorithms}{\listof{algorithm}{\listalgorithmname}} 79 | 80 | -------------------------------------------------------------------------------- /ai_scientist/blank_icml_latex/algorithmic.sty: -------------------------------------------------------------------------------- 1 | % ALGORITHMIC STYLE -- Released 8 APRIL 1996 2 | % for LaTeX version 2e 3 | % Copyright -- 1994 Peter Williams 4 | % E-mail PeterWilliams@dsto.defence.gov.au 5 | % 6 | % Modified by Alex Smola (08/2000) 7 | % E-mail Alex.Smola@anu.edu.au 8 | % 9 | \NeedsTeXFormat{LaTeX2e} 10 | \ProvidesPackage{algorithmic} 11 | \typeout{Document Style `algorithmic' - environment} 12 | % 13 | \RequirePackage{ifthen} 14 | \RequirePackage{calc} 15 | \newboolean{ALC@noend} 16 | \setboolean{ALC@noend}{false} 17 | \newcounter{ALC@line} 18 | \newcounter{ALC@rem} 19 | \newlength{\ALC@tlm} 20 | % 21 | \DeclareOption{noend}{\setboolean{ALC@noend}{true}} 22 | % 23 | \ProcessOptions 24 | % 25 | % ALGORITHMIC 26 | \newcommand{\algorithmicrequire}{\textbf{Require:}} 27 | \newcommand{\algorithmicensure}{\textbf{Ensure:}} 28 | \newcommand{\algorithmiccomment}[1]{\{#1\}} 29 | \newcommand{\algorithmicend}{\textbf{end}} 30 | \newcommand{\algorithmicif}{\textbf{if}} 31 | \newcommand{\algorithmicthen}{\textbf{then}} 32 | \newcommand{\algorithmicelse}{\textbf{else}} 33 | \newcommand{\algorithmicelsif}{\algorithmicelse\ \algorithmicif} 34 | \newcommand{\algorithmicendif}{\algorithmicend\ \algorithmicif} 35 | \newcommand{\algorithmicfor}{\textbf{for}} 36 | \newcommand{\algorithmicforall}{\textbf{for all}} 37 | \newcommand{\algorithmicdo}{\textbf{do}} 38 | \newcommand{\algorithmicendfor}{\algorithmicend\ \algorithmicfor} 39 | \newcommand{\algorithmicwhile}{\textbf{while}} 40 | \newcommand{\algorithmicendwhile}{\algorithmicend\ \algorithmicwhile} 41 | \newcommand{\algorithmicloop}{\textbf{loop}} 42 | \newcommand{\algorithmicendloop}{\algorithmicend\ \algorithmicloop} 43 | \newcommand{\algorithmicrepeat}{\textbf{repeat}} 44 | \newcommand{\algorithmicuntil}{\textbf{until}} 45 | 46 | %changed by alex smola 47 | \newcommand{\algorithmicinput}{\textbf{input}} 48 | \newcommand{\algorithmicoutput}{\textbf{output}} 49 | \newcommand{\algorithmicset}{\textbf{set}} 50 | \newcommand{\algorithmictrue}{\textbf{true}} 51 | \newcommand{\algorithmicfalse}{\textbf{false}} 52 | \newcommand{\algorithmicand}{\textbf{and\ }} 53 | \newcommand{\algorithmicor}{\textbf{or\ }} 54 | \newcommand{\algorithmicfunction}{\textbf{function}} 55 | \newcommand{\algorithmicendfunction}{\algorithmicend\ \algorithmicfunction} 56 
| \newcommand{\algorithmicmain}{\textbf{main}} 57 | \newcommand{\algorithmicendmain}{\algorithmicend\ \algorithmicmain} 58 | %end changed by alex smola 59 | 60 | \def\ALC@item[#1]{% 61 | \if@noparitem \@donoparitem 62 | \else \if@inlabel \indent \par \fi 63 | \ifhmode \unskip\unskip \par \fi 64 | \if@newlist \if@nobreak \@nbitem \else 65 | \addpenalty\@beginparpenalty 66 | \addvspace\@topsep \addvspace{-\parskip}\fi 67 | \else \addpenalty\@itempenalty \addvspace\itemsep 68 | \fi 69 | \global\@inlabeltrue 70 | \fi 71 | \everypar{\global\@minipagefalse\global\@newlistfalse 72 | \if@inlabel\global\@inlabelfalse \hskip -\parindent \box\@labels 73 | \penalty\z@ \fi 74 | \everypar{}}\global\@nobreakfalse 75 | \if@noitemarg \@noitemargfalse \if@nmbrlist \refstepcounter{\@listctr}\fi \fi 76 | \sbox\@tempboxa{\makelabel{#1}}% 77 | \global\setbox\@labels 78 | \hbox{\unhbox\@labels \hskip \itemindent 79 | \hskip -\labelwidth \hskip -\ALC@tlm 80 | \ifdim \wd\@tempboxa >\labelwidth 81 | \box\@tempboxa 82 | \else \hbox to\labelwidth {\unhbox\@tempboxa}\fi 83 | \hskip \ALC@tlm}\ignorespaces} 84 | % 85 | \newenvironment{algorithmic}[1][0]{ 86 | \let\@item\ALC@item 87 | \newcommand{\ALC@lno}{% 88 | \ifthenelse{\equal{\arabic{ALC@rem}}{0}} 89 | {{\footnotesize \arabic{ALC@line}:}}{}% 90 | } 91 | \let\@listii\@listi 92 | \let\@listiii\@listi 93 | \let\@listiv\@listi 94 | \let\@listv\@listi 95 | \let\@listvi\@listi 96 | \let\@listvii\@listi 97 | \newenvironment{ALC@g}{ 98 | \begin{list}{\ALC@lno}{ \itemsep\z@ \itemindent\z@ 99 | \listparindent\z@ \rightmargin\z@ 100 | \topsep\z@ \partopsep\z@ \parskip\z@\parsep\z@ 101 | \leftmargin 1em 102 | \addtolength{\ALC@tlm}{\leftmargin} 103 | } 104 | } 105 | {\end{list}} 106 | \newcommand{\ALC@it}{\addtocounter{ALC@line}{1}\addtocounter{ALC@rem}{1}\ifthenelse{\equal{\arabic{ALC@rem}}{#1}}{\setcounter{ALC@rem}{0}}{}\item} 107 | \newcommand{\ALC@com}[1]{\ifthenelse{\equal{##1}{default}}% 108 | {}{\ \algorithmiccomment{##1}}} 109 | \newcommand{\REQUIRE}{\item[\algorithmicrequire]} 110 | \newcommand{\ENSURE}{\item[\algorithmicensure]} 111 | \newcommand{\STATE}{\ALC@it} 112 | \newcommand{\COMMENT}[1]{\algorithmiccomment{##1}} 113 | %changes by alex smola 114 | \newcommand{\INPUT}{\item[\algorithmicinput]} 115 | \newcommand{\OUTPUT}{\item[\algorithmicoutput]} 116 | \newcommand{\SET}{\item[\algorithmicset]} 117 | % \newcommand{\TRUE}{\algorithmictrue} 118 | % \newcommand{\FALSE}{\algorithmicfalse} 119 | \newcommand{\AND}{\algorithmicand} 120 | \newcommand{\OR}{\algorithmicor} 121 | \newenvironment{ALC@func}{\begin{ALC@g}}{\end{ALC@g}} 122 | \newenvironment{ALC@main}{\begin{ALC@g}}{\end{ALC@g}} 123 | %end changes by alex smola 124 | \newenvironment{ALC@if}{\begin{ALC@g}}{\end{ALC@g}} 125 | \newenvironment{ALC@for}{\begin{ALC@g}}{\end{ALC@g}} 126 | \newenvironment{ALC@whl}{\begin{ALC@g}}{\end{ALC@g}} 127 | \newenvironment{ALC@loop}{\begin{ALC@g}}{\end{ALC@g}} 128 | \newenvironment{ALC@rpt}{\begin{ALC@g}}{\end{ALC@g}} 129 | \renewcommand{\\}{\@centercr} 130 | \newcommand{\IF}[2][default]{\ALC@it\algorithmicif\ ##2\ \algorithmicthen% 131 | \ALC@com{##1}\begin{ALC@if}} 132 | \newcommand{\SHORTIF}[2]{\ALC@it\algorithmicif\ ##1\ 133 | \algorithmicthen\ {##2}} 134 | \newcommand{\ELSE}[1][default]{\end{ALC@if}\ALC@it\algorithmicelse% 135 | \ALC@com{##1}\begin{ALC@if}} 136 | \newcommand{\ELSIF}[2][default]% 137 | {\end{ALC@if}\ALC@it\algorithmicelsif\ ##2\ \algorithmicthen% 138 | \ALC@com{##1}\begin{ALC@if}} 139 | \newcommand{\FOR}[2][default]{\ALC@it\algorithmicfor\ ##2\ 
\algorithmicdo% 140 | \ALC@com{##1}\begin{ALC@for}} 141 | \newcommand{\FORALL}[2][default]{\ALC@it\algorithmicforall\ ##2\ % 142 | \algorithmicdo% 143 | \ALC@com{##1}\begin{ALC@for}} 144 | \newcommand{\SHORTFORALL}[2]{\ALC@it\algorithmicforall\ ##1\ % 145 | \algorithmicdo\ {##2}} 146 | \newcommand{\WHILE}[2][default]{\ALC@it\algorithmicwhile\ ##2\ % 147 | \algorithmicdo% 148 | \ALC@com{##1}\begin{ALC@whl}} 149 | \newcommand{\LOOP}[1][default]{\ALC@it\algorithmicloop% 150 | \ALC@com{##1}\begin{ALC@loop}} 151 | %changed by alex smola 152 | \newcommand{\FUNCTION}[2][default]{\ALC@it\algorithmicfunction\ ##2\ % 153 | \ALC@com{##1}\begin{ALC@func}} 154 | \newcommand{\MAIN}[2][default]{\ALC@it\algorithmicmain\ ##2\ % 155 | \ALC@com{##1}\begin{ALC@main}} 156 | %end changed by alex smola 157 | \newcommand{\REPEAT}[1][default]{\ALC@it\algorithmicrepeat% 158 | \ALC@com{##1}\begin{ALC@rpt}} 159 | \newcommand{\UNTIL}[1]{\end{ALC@rpt}\ALC@it\algorithmicuntil\ ##1} 160 | \ifthenelse{\boolean{ALC@noend}}{ 161 | \newcommand{\ENDIF}{\end{ALC@if}} 162 | \newcommand{\ENDFOR}{\end{ALC@for}} 163 | \newcommand{\ENDWHILE}{\end{ALC@whl}} 164 | \newcommand{\ENDLOOP}{\end{ALC@loop}} 165 | \newcommand{\ENDFUNCTION}{\end{ALC@func}} 166 | \newcommand{\ENDMAIN}{\end{ALC@main}} 167 | }{ 168 | \newcommand{\ENDIF}{\end{ALC@if}\ALC@it\algorithmicendif} 169 | \newcommand{\ENDFOR}{\end{ALC@for}\ALC@it\algorithmicendfor} 170 | \newcommand{\ENDWHILE}{\end{ALC@whl}\ALC@it\algorithmicendwhile} 171 | \newcommand{\ENDLOOP}{\end{ALC@loop}\ALC@it\algorithmicendloop} 172 | \newcommand{\ENDFUNCTION}{\end{ALC@func}\ALC@it\algorithmicendfunction} 173 | \newcommand{\ENDMAIN}{\end{ALC@main}\ALC@it\algorithmicendmain} 174 | } 175 | \renewcommand{\@toodeep}{} 176 | \begin{list}{\ALC@lno}{\setcounter{ALC@line}{0}\setcounter{ALC@rem}{0}% 177 | \itemsep\z@ \itemindent\z@ \listparindent\z@% 178 | \partopsep\z@ \parskip\z@ \parsep\z@% 179 | \labelsep 0.5em \topsep 0.2em% 180 | \ifthenelse{\equal{#1}{0}} 181 | {\labelwidth 0.5em } 182 | {\labelwidth 1.2em } 183 | \leftmargin\labelwidth \addtolength{\leftmargin}{\labelsep} 184 | \ALC@tlm\labelsep 185 | } 186 | } 187 | {\end{list}} 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /ai_scientist/blank_icml_latex/template.tex: -------------------------------------------------------------------------------- 1 | %%%%%%%% ICML 2025 LATEX SUBMISSION FILE %%%%%%%%%%%%%%%%% 2 | 3 | \documentclass{article} 4 | \usepackage{microtype} 5 | \usepackage{graphicx} 6 | \usepackage{subfigure} 7 | \usepackage{booktabs} % for professional tables 8 | \usepackage{hyperref} 9 | % Attempt to make hyperref and algorithmic work together better: 10 | \newcommand{\theHalgorithm}{\arabic{algorithm}} 11 | 12 | % Use the following line for the initial blind version submitted for review: 13 | \usepackage{icml2025} 14 | 15 | % For theorems and such 16 | \usepackage{amsmath} 17 | \usepackage{amssymb} 18 | \usepackage{mathtools} 19 | \usepackage{amsthm} 20 | 21 | % Custom 22 | \usepackage{multirow} 23 | \usepackage{color} 24 | \usepackage{colortbl} 25 | \usepackage[capitalize,noabbrev]{cleveref} 26 | \usepackage{xspace} 27 | 28 | \DeclareMathOperator*{\argmin}{arg\,min} 29 | \DeclareMathOperator*{\argmax}{arg\,max} 30 | 31 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 32 | % THEOREMS 33 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 34 | \theoremstyle{plain} 35 | \newtheorem{theorem}{Theorem}[section] 36 | 
\newtheorem{proposition}[theorem]{Proposition} 37 | \newtheorem{lemma}[theorem]{Lemma} 38 | \newtheorem{corollary}[theorem]{Corollary} 39 | \theoremstyle{definition} 40 | \newtheorem{definition}[theorem]{Definition} 41 | \newtheorem{assumption}[theorem]{Assumption} 42 | \theoremstyle{remark} 43 | \newtheorem{remark}[theorem]{Remark} 44 | 45 | \graphicspath{{../figures/}} % To reference your generated figures, name the PNGs directly. DO NOT CHANGE THIS. 46 | 47 | \begin{filecontents}{references.bib} 48 | @book{goodfellow2016deep, 49 | title={Deep learning}, 50 | author={Goodfellow, Ian and Bengio, Yoshua and Courville, Aaron and Bengio, Yoshua}, 51 | volume={1}, 52 | year={2016}, 53 | publisher={MIT Press} 54 | } 55 | \end{filecontents} 56 | 57 | % The \icmltitle you define below is probably too long as a header. 58 | % Therefore, a short form for the running title is supplied here: 59 | \icmltitlerunning{ 60 | %%%%%%%%%TITLE%%%%%%%%% 61 | TITLE HERE 62 | %%%%%%%%%TITLE%%%%%%%%% 63 | } 64 | 65 | \begin{document} 66 | 67 | \twocolumn[ 68 | \icmltitle{ 69 | %%%%%%%%%TITLE%%%%%%%%% 70 | TITLE HERE 71 | %%%%%%%%%TITLE%%%%%%%%% 72 | } 73 | 74 | \icmlsetsymbol{equal}{*} 75 | 76 | \begin{icmlauthorlist} 77 | \icmlauthor{Anonymous}{yyy} 78 | \icmlauthor{Firstname2 Lastname2}{equal,yyy,comp} 79 | \end{icmlauthorlist} 80 | 81 | \icmlaffiliation{yyy}{Department of XXX, University of YYY, Location, Country} 82 | 83 | \icmlcorrespondingauthor{Anonymous}{first1.last1@xxx.edu} 84 | 85 | % You may provide any keywords that you 86 | % find helpful for describing your paper; these are used to populate 87 | % the "keywords" metadata in the PDF but will not be shown in the document 88 | \icmlkeywords{Machine Learning, ICML} 89 | 90 | \vskip 0.3in 91 | ] 92 | 93 | \printAffiliationsAndNotice{} % leave blank if no need to mention equal contribution 94 | 95 | \begin{abstract} 96 | %%%%%%%%%ABSTRACT%%%%%%%%% 97 | ABSTRACT HERE 98 | %%%%%%%%%ABSTRACT%%%%%%%%% 99 | \end{abstract} 100 | 101 | \section{Introduction} 102 | \label{sec:intro} 103 | %%%%%%%%%INTRODUCTION%%%%%%%%% 104 | INTRO HERE 105 | %%%%%%%%%INTRODUCTION%%%%%%%%% 106 | 107 | \section{Related Work} 108 | \label{sec:related} 109 | %%%%%%%%%RELATED WORK%%%%%%%%% 110 | RELATED WORK HERE 111 | %%%%%%%%%RELATED WORK%%%%%%%%% 112 | 113 | \section{Background} 114 | \label{sec:background} 115 | %%%%%%%%%BACKGROUND%%%%%%%%% 116 | BACKGROUND HERE 117 | %%%%%%%%%BACKGROUND%%%%%%%%% 118 | 119 | \section{Method} 120 | \label{sec:method} 121 | %%%%%%%%%METHOD%%%%%%%%% 122 | METHOD HERE 123 | %%%%%%%%%METHOD%%%%%%%%% 124 | 125 | \section{Experimental Setup} 126 | \label{sec:experimental_setup} 127 | %%%%%%%%%EXPERIMENTAL SETUP%%%%%%%%% 128 | EXPERIMENTAL SETUP HERE 129 | %%%%%%%%%EXPERIMENTAL SETUP%%%%%%%%% 130 | 131 | \section{Experiments} 132 | \label{sec:experiments} 133 | %%%%%%%%%EXPERIMENTS%%%%%%%%% 134 | RESULTS HERE 135 | 136 | % EXAMPLE FIGURE: REPLACE AND ADD YOUR OWN FIGURES / CAPTIONS 137 | \begin{figure}[t] 138 | \centering 139 | \includegraphics[width=\columnwidth]{example-image-a} 140 | \caption{PLEASE FILL IN CAPTION HERE} 141 | \label{fig:first_figure} 142 | \end{figure} 143 | %%%%%%%%%EXPERIMENTS%%%%%%%%% 144 | 145 | \section{Conclusion} 146 | \label{sec:conclusion} 147 | %%%%%%%%%CONCLUSION%%%%%%%%% 148 | CONCLUSIONS HERE 149 | %%%%%%%%%CONCLUSION%%%%%%%%% 150 | 151 | 152 | % Authors are \textbf{required} to include a statement of the potential 153 | % broader impact of their work, including its ethical aspects and future 154 | % societal 
consequences. This statement should be in an unnumbered 155 | % section at the end of the paper (co-located with Acknowledgements -- 156 | % the two may appear in either order, but both must be before References), 157 | % and does not count toward the paper page limit. In many cases, where 158 | % the ethical impacts and expected societal implications are those that 159 | % are well established when advancing the field of Machine Learning, 160 | % substantial discussion is not required, and a simple statement such 161 | % as the following will suffice: 162 | \section*{Impact Statement} 163 | This paper presents work whose goal is to advance the field of 164 | Machine Learning. There are many potential societal consequences 165 | of our work, none which we feel must be specifically highlighted here. 166 | 167 | \bibliography{references} 168 | \bibliographystyle{icml2025} 169 | 170 | 171 | % APPENDIX 172 | \newpage 173 | \appendix 174 | \onecolumn 175 | 176 | \section*{\LARGE Supplementary Material} 177 | \label{sec:appendix} 178 | 179 | %%%%%%%%%APPENDIX%%%%%%%%% 180 | \section{Appendix Section} 181 | APPENDIX TEXT 182 | %%%%%%%%%APPENDIX%%%%%%%%% 183 | 184 | \end{document} 185 | -------------------------------------------------------------------------------- /ai_scientist/fewshot_examples/132_automated_relational.json: -------------------------------------------------------------------------------- 1 | { 2 | "review": "{\n \"Summary\": \"The paper provides an interesting direction in the meta-learning field. In particular, it proposes to enhance meta learning performance by fully exploring relations across multiple tasks. To capture such information, the authors develop a heterogeneity-aware meta-learning framework by introducing a novel architecture--meta-knowledge graph, which can dynamically find the most relevant structure for new tasks.\",\n \"Strengths\": [\n \"The paper takes one of the most important issues of meta-learning: task heterogeneity. For me, the problem itself is real and practical.\",\n \"The proposed meta-knowledge graph is novel for capturing the relation between tasks and addressing the problem of task heterogeneity. Graph structure provides a more flexible way of modeling relations. The design for using the prototype-based relational graph to query the meta-knowledge graph is reasonable and interesting.\",\n \"This paper provides comprehensive experiments, including both qualitative analysis and quantitative results, to show the effectiveness of the proposed framework. The newly constructed Art-Multi dataset further enhances the difficulty of tasks and makes the performance more convincing.\"\n ],\n \"Weaknesses\": [\n \"Although the proposed method provides several ablation studies, I still suggest the authors conduct the following ablation studies to enhance the quality of the paper: (1) It might be valuable to investigate the modulation function. In the paper, the authors compare sigmoid, tanh, and Film layer. Can the authors analyze the results by reducing the number of gating parameters in Eq. 10 by sharing the gate value of each filter in Conv layers? (2) What is the performance of the proposed model by changing the type of aggregators?\",\n \"For the autoencoder aggregator, it would be better to provide more details about it, which seems not very clear to me.\",\n \"In the qualitative analysis (i.e., Figure 2 and Figure 3), the authors provide one visualization for each task. 
It would be more convincing if the authors can provide more cases in the rebuttal period.\"\n ],\n \"Originality\": 3,\n \"Quality\": 3,\n \"Clarity\": 3,\n \"Significance\": 4,\n \"Questions\": [\n \"Please address and clarify the cons above.\"\n ],\n \"Limitations\": [\n \"My major concern is about the clarity of the paper and some additional ablation models (see cons below). Hopefully the authors can address my concern in the rebuttal period.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 3,\n \"Presentation\": 3,\n \"Contribution\": 3,\n \"Overall\": 7,\n \"Confidence\": 5,\n \"Decision\": \"Accept\"\n}" 3 | } -------------------------------------------------------------------------------- /ai_scientist/fewshot_examples/132_automated_relational.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/ai_scientist/fewshot_examples/132_automated_relational.pdf -------------------------------------------------------------------------------- /ai_scientist/fewshot_examples/2_carpe_diem.json: -------------------------------------------------------------------------------- 1 | { 2 | "review": "{\n \"Summary\": \"This paper proposes Recency Bias, an adaptive mini batch selection method for training deep neural networks. To select informative minibatches for training, the proposed method maintains a fixed size sliding window of past model predictions for each data sample. At a given iteration, samples which have highly inconsistent predictions within the sliding window are added to the minibatch. The main contribution of this paper is the introduction of a sliding window to remember past model predictions, as an improvement over the SOTA approach: Active Bias, which maintains a growing window of model predictions. Empirical studies are performed to show the superiority of Recency Bias over two SOTA approaches. Results are shown on the task of (1) image classification from scratch and (2) image classification by fine-tuning pretrained networks.\",\n \"Strengths\": [\n \"The idea of using a sliding window over a growing window in active batch selection is interesting.\",\n \"Overall, the paper is well written. In particular, the Related Work section has a nice flow and puts the proposed method into context. Despite the method having limited novelty (sliding window instead of a growing window), the method has been well motivated by pointing out the limitations in SOTA methods.\",\n \"The results section is well structured. It's nice to see hyperparameter tuning results; and loss convergence graphs in various learning settings for each dataset.\"\n ],\n \"Weaknesses\": [\n \"The key concern about the paper is the lack of rigorous experimentation to study the usefulness of the proposed method. Despite the paper stating that there have been earlier work (Joseph et al, 2019 and Wang et al, 2019) that attempt mini-batch selection, the paper does not compare with them. This is limiting. Further, since the proposed method is not specific to the domain of images, evaluating it on tasks other than image classification, such as text classification for instance, would have helped validate its applicability across domains.\",\n \"Considering the limited results, a deeper analysis of the proposed method would have been nice. 
The idea of a sliding window over a growing window is a generic one, and there have been many efforts to theoretically analyze active learning over the last two decades. How does the proposed method fit in there? (For e.g., how does the expected model variance change in this setting?) Some form of theoretical/analytical reasoning behind the effectiveness of recency bias (which is missing) would provide greater insights to the community and facilitate further research in this direction.\",\n \"The claim of 20.5% reduction in test error mentioned in the abstract has not been clearly addressed and pointed out in the results section of the paper.\",\n \"The results would have been more complete if results were shown in a setting where just recency bias is used without the use of the selection pressure parameter. In other words, an ablation study on the effect of the selection pressure parameter would have been very useful.\",\n \"The intuition behind the method is described well, however, the proposed method would have been really solidified if it were analysed in the context of a simple machine learning problem (such as logistic regression). As an example, verifying if the chosen minibatch samples are actually close to the decision boundary of a model (even if the model is very simple) would have helped analyze the proposed method well.\"\n ],\n \"Originality\": 3,\n \"Quality\": 2,\n \"Clarity\": 4,\n \"Significance\": 2,\n \"Questions\": [\n \"How important is the warm-up phase to the proposed method? Considering the paper states that this is required to get good estimates of the quantization index of the samples, some ablation studies on reducing/increasing the warm-up phase and showing the results would have been useful to understand this.\",\n \"Fig 4: Why are there sharp dips periodically in all the graphs? What do these correspond to?\",\n \"The results are not conclusively in favor of the proposed method, and only is marginally better than the competitors. Why does online batch perform consistently than the proposed method? There is no discussion of these inferences from the results.\"\n ],\n \"Limitations\": [\n \"The primary concern is about the strength of the experimental results, which showed only a modest benefit on relatively simple datasets.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 2,\n \"Presentation\": 3,\n \"Contribution\": 2,\n \"Overall\": 4,\n \"Confidence\": 3,\n \"Decision\": \"Reject\"\n}" 3 | } -------------------------------------------------------------------------------- /ai_scientist/fewshot_examples/2_carpe_diem.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/ai_scientist/fewshot_examples/2_carpe_diem.pdf -------------------------------------------------------------------------------- /ai_scientist/fewshot_examples/attention.json: -------------------------------------------------------------------------------- 1 | { 2 | "review": "{\n \"Summary\": \"The paper proposes the Transformer, a novel neural network architecture that relies entirely on self-attention mechanisms, eschewing traditional recurrent and convolutional layers. This innovation allows the model to achieve state-of-the-art results in machine translation tasks with significant improvements in both training efficiency and translation quality. 
The paper includes detailed descriptions of the model architecture, including multi-head attention and positional encodings, as well as extensive experimental results to validate the model's performance.\",\n \"Questions\": [\n \"Could the authors provide more detailed comparisons with other recent models not included in Table 2?\",\n \"What is the impact of varying the number of layers (N) in both the encoder and decoder stacks?\",\n \"Can the authors provide more insights into the choice of hyperparameters, especially the learning rate schedule and warmup steps?\"\n ],\n \"Limitations\": [\n \"The paper does not explore the application of the Transformer to tasks beyond machine translation, such as image or audio processing.\",\n \"The discussion on the potential negative societal impacts of the model is minimal and could be expanded.\"\n ],\n \"Ethical Concerns\": false,\n \"Soundness\": 4,\n \"Presentation\": 3,\n \"Contribution\": 4,\n \"Overall\": 8,\n \"Confidence\": 5,\n \"Strengths\": [\n \"The Transformer model introduces a highly innovative use of self-attention mechanisms, replacing traditional recurrent and convolutional layers.\",\n \"Comprehensive experimental validation showing state-of-the-art performance in machine translation tasks.\",\n \"Clear and detailed description of the model architecture and its components, facilitating reproducibility and further research.\"\n ],\n \"Weaknesses\": [\n \"Limited discussion on the application of the model to other domains beyond machine translation.\",\n \"The paper could benefit from a deeper analysis of the potential negative societal impacts of the model.\"\n ],\n \"Originality\": 4,\n \"Quality\": 4,\n \"Clarity\": 4,\n \"Significance\": 4,\n \"Decision\": \"Accept\"\n}" 3 | } -------------------------------------------------------------------------------- /ai_scientist/fewshot_examples/attention.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/ai_scientist/fewshot_examples/attention.pdf -------------------------------------------------------------------------------- /ai_scientist/ideas/i_cant_believe_its_not_better.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Name": "compositional_regularization_nn", 4 | "Title": "Enhancing Compositional Generalization in Neural Networks via Compositional Regularization", 5 | "Short Hypothesis": "Introducing a compositional regularization term during training can encourage neural networks to develop compositional representations, thereby improving their ability to generalize to novel combinations of known components.", 6 | "Related Work": "Previous work has highlighted the challenges neural networks face in achieving compositional generalization. Studies such as 'Compositional Generalization through Abstract Representations in Human and Artificial Neural Networks' (Ito et al., NeurIPS 2022) have explored abstract representations to tackle this issue. However, limited research focuses on directly incorporating explicit regularization terms into the training objective to enforce compositional structures. 
Our proposal distinguishes itself by introducing a novel regularization approach that penalizes deviations from predefined compositional patterns during training, encouraging the network to internalize compositional rules.", 7 | "Abstract": "Neural networks excel in many tasks but often struggle with compositional generalization\u2014the ability to understand and generate novel combinations of familiar components. This limitation hampers their performance on tasks requiring systematic generalization beyond the training data. In this proposal, we introduce a novel training method that incorporates an explicit compositional regularization term into the loss function of neural networks. This regularization term is designed to encourage the formation of compositional representations by penalizing the network when its internal representations deviate from expected compositional structures. We hypothesize that this approach will enhance the network's ability to generalize to unseen combinations, mimicking human-like compositional reasoning. We will test our method on synthetic benchmarks like the SCAN and COGS datasets, which are specifically designed to evaluate compositional generalization, as well as on real-world tasks such as machine translation and semantic parsing. By comparing our method to baseline models and existing approaches, we aim to demonstrate significant improvements in generalization performance. This work offers a new avenue for enforcing compositionality in neural networks through regularization, potentially bridging the gap between neural network capabilities and human cognitive flexibility.", 8 | "Experiments": [ 9 | "Implement the compositional regularization term and integrate it into the loss function of standard sequence-to-sequence neural network architectures with attention mechanisms.", 10 | "Train models on synthetic datasets like SCAN and COGS, evaluating performance on compositional generalization tasks with and without the regularization term.", 11 | "Apply the method to real-world tasks such as machine translation using the IWSLT dataset and semantic parsing with the GeoQuery dataset, assessing improvements in generalization to new language constructs.", 12 | "Analyze the learned representations by visualizing embedding spaces and utilizing compositionality metrics to assess how the regularization affects internal representations.", 13 | "Conduct ablation studies to determine the impact of different strengths of the regularization term, identifying the optimal balance between enforcing compositionality and maintaining overall performance.", 14 | "Compare the proposed method against other approaches aimed at improving compositional generalization, such as meta-learning techniques and specialized architectures." 15 | ], 16 | "Risk Factors and Limitations": [ 17 | "The effectiveness of the compositional regularization may vary across different datasets and tasks, potentially limiting its generalizability.", 18 | "An improperly balanced regularization term could negatively impact model performance on the primary task, leading to lower accuracy.", 19 | "Additional computational overhead from calculating the regularization term may increase training time and resource requirements.", 20 | "Defining appropriate compositional structures for complex or less-understood domains may be challenging, affecting the applicability of the method.", 21 | "The approach may face scalability issues when applied to very large models or datasets common in industrial applications." 
22 | ] 23 | }, 24 | { 25 | "Name": "interpretability_failure_modes", 26 | "Title": "When Interpretability Fails: Investigating the Limitations of Explanation Methods in Deep Learning", 27 | "Short Hypothesis": "Explanation methods for deep learning models may not always provide accurate or reliable interpretations of model behavior; identifying when these methods fail can inform better application and development of interpretability techniques.", 28 | "Related Work": "Previous studies, such as 'Sanity Checks for Saliency Maps' (Adebayo et al., 2018) and 'The (Un)reliability of Saliency Methods' (Kindermans et al., 2017), have highlighted issues with the reliability of certain interpretability methods. However, there is a lack of systematic investigation into the failure modes of explanation techniques across different models and tasks, especially in practical, real-world settings. Our proposal distinguishes itself by providing a comprehensive analysis of when and why interpretability methods fail, and how they may mislead users.", 29 | "Abstract": "Interpretability methods are essential for understanding and trusting deep learning models, especially in critical applications where insight into model decisions is necessary. However, these methods may not always provide accurate or meaningful explanations of a model's behavior, and can sometimes be misleading. This proposal aims to investigate the limitations and failure modes of popular interpretability techniques in deep learning, such as saliency maps, attribution methods, and concept activation vectors. We hypothesize that these methods can fail due to factors like model architecture, data biases, or adversarial manipulations, leading to explanations that do not reflect the true decision-making processes of the models. Through systematic experiments across various tasks and models, we will analyze the reliability of interpretability methods, identify conditions under which they fail, and propose metrics to assess explanation fidelity. The findings will inform practitioners about the potential pitfalls of relying on current interpretability techniques and guide the development of more robust methods to enhance the trustworthiness of AI systems.", 30 | "Experiments": [ 31 | "Select a diverse set of pretrained models across different domains (e.g., image classification models like ResNet, NLP models like BERT).", 32 | "Apply various interpretability methods (e.g., Grad-CAM, Integrated Gradients, LIME, SHAP) to these models on standard benchmarks.", 33 | "Design controlled experiments where we introduce factors that may cause interpretability methods to fail, such as adding irrelevant but salient features, using adversarial examples, or modifying model architectures.", 34 | "Evaluate the explanations generated by interpretability methods for accuracy and consistency, comparing them to known model behaviors or ground truth where available.", 35 | "Develop quantitative metrics to assess the fidelity and reliability of explanations, such as explanation invariance and sensitivity.", 36 | "Conduct user studies to understand how misleading or unreliable explanations impact human trust and decision-making.", 37 | "Propose guidelines or improvements for the application of interpretability methods based on the findings." 
38 | ], 39 | "Risk Factors and Limitations": [ 40 | "Defining ground truth for explanations is challenging due to the complexity of deep models.", 41 | "Findings may be specific to the selected models or tasks, limiting generalizability.", 42 | "Evaluating interpretability is inherently subjective, and metrics may not capture all aspects of explanation quality.", 43 | "User studies require careful design and sufficient participants to yield meaningful results.", 44 | "Computational resource constraints may limit the scale of experiments." 45 | ] 46 | }, 47 | { 48 | "Name": "real_world_pest_detection", 49 | "Title": "Real-World Challenges in Pest Detection Using Deep Learning: An Investigation into Failures and Solutions", 50 | "Short Hypothesis": "Deep learning models for pest detection often fail to generalize in real-world agricultural settings due to data quality issues, environmental variability, and model limitations. Investigating these failures can lead to more robust solutions.", 51 | "Related Work": "Several studies, such as those by Agarwal et al. (2023) and Dong et al. (2024), have explored deep learning for pest detection in agriculture. These studies generally report high accuracy in controlled settings but often do not address real-world deployment challenges. Our proposal distinguishes itself by focusing on the negative outcomes and the underlying reasons behind these failures.", 52 | "Abstract": "Accurate pest detection is vital for protecting crops and ensuring food security. While deep learning models have shown promise in controlled environments, their performance often degrades in real-world applications. This proposal aims to investigate the reasons behind these failures. We hypothesize that data quality issues, environmental variability, and model limitations are significant factors. By conducting a series of experiments, we will explore these challenges in depth and propose robust solutions to improve the generalizability of deep learning models for pest detection. Our research will provide valuable insights for the agricultural community and contribute to the development of more reliable AI tools for precision farming.", 53 | "Experiments": [ 54 | "1. **Data Quality Analysis**: Collect a diverse dataset of pest images from different agricultural environments and analyze its quality. Identify common issues such as label noise, class imbalance, and distribution shift.", 55 | "2. **Model Robustness Testing**: Train state-of-the-art deep learning models (e.g., YOLOv8, EfficientNetB3) on the collected dataset and evaluate their performance in controlled vs. real-world settings. Metrics: Mean Average Precision (mAP), F1 Score.", 56 | "3. **Environmental Variability Study**: Evaluate model performance under different environmental conditions (e.g., lighting, weather). Identify which conditions most significantly impact model accuracy.", 57 | "4. **Failure Mode Analysis**: Conduct a detailed analysis of misclassifications to identify common patterns and potential causes (e.g., feature overlap between pests and background).", 58 | "5. **Improvement Strategies**: Implement and test various strategies to mitigate identified challenges, such as data augmentation, domain adaptation, and model ensembling. Evaluate their effectiveness in improving model robustness." 
59 | ], 60 | "Risk Factors and Limitations": "Potential risks include the availability and quality of real-world data, the computational demands of training and testing multiple deep learning models, and the generalizability of the findings to different types of pests and crops. Additionally, environmental factors may introduce variability that is challenging to control." 61 | } 62 | ] 63 | -------------------------------------------------------------------------------- /ai_scientist/ideas/i_cant_believe_its_not_better.md: -------------------------------------------------------------------------------- 1 | # Title: I Can't Believe It's Not Better: Challenges in Applied Deep Learning 2 | 3 | ## Keywords 4 | negative results, deep learning, failure modes 5 | 6 | ## TL;DR 7 | Why don't deep learning approaches always deliver as expected in the real world? Dive deep into the pitfalls and challenges of applied deep learning. 8 | 9 | ## Abstract 10 | The goal of the I Can’t Believe It’s Not Better (ICBINB) workshop series is to promote slow science and build a community to discuss surprising and negative results, thereby encouraging a culture of transparency and shared learning. In recent years, we have witnessed a remarkable rise of Deep Learning (DL), whose impressive performance on benchmark tasks has led to increasing ambitions to deploy DL in real-world applications across all fields and disciplines. However, despite its potential, DL still faces many challenges during deployment in dynamic, real-world conditions, thus exposing practical limitations that are often overlooked in controlled benchmarks. Therefore, in this year’s ICBINB workshop, we aim to explore the challenges, unexpected outcomes, and common principles underlying similar issues and failure modes encountered across various fields and disciplines when deploying DL models in real-world scenarios. We will invite contributions and discussions from diverse fields including, but not limited to, healthcare, scientific discovery, robotics, education, equality & fairness, and social sciences. The failure modes may include suboptimal performance, concerns with the safety and reliability of applying DL models in unpredictable real-world applications, as well as ethical and societal challenges. More importantly, we aim to discuss common reasons or patterns in challenges and failure modes across disciplines. By creating a platform for researchers from different domains to interact and share insights, we hope to accelerate research by translating findings from one field to another, and also deepen DL researchers’ understanding of the universal fundamental issues that should be addressed within the current theoretical and empirical research paradigms. Embracing negative results as valuable learning opportunities will, therefore, help the community learn from past mistakes, and drive the development of more robust, reliable, and applicable AI models. 
11 | -------------------------------------------------------------------------------- /ai_scientist/ideas/i_cant_believe_its_not_betterrealworld.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Name": "real_world_pest_detection", 4 | "Title": "Real-World Challenges in Pest Detection Using Deep Learning: An Investigation into Failures and Solutions", 5 | "Short Hypothesis": "Deep learning models for pest detection often fail to generalize in real-world agricultural settings due to data quality issues, environmental variability, and model limitations. Investigating these failures can lead to more robust solutions.", 6 | "Related Work": "Several studies, such as those by Agarwal et al. (2023) and Dong et al. (2024), have explored deep learning for pest detection in agriculture. These studies generally report high accuracy in controlled settings but often do not address real-world deployment challenges. Our proposal distinguishes itself by focusing on the negative outcomes and the underlying reasons behind these failures.", 7 | "Abstract": "Accurate pest detection is vital for protecting crops and ensuring food security. While deep learning models have shown promise in controlled environments, their performance often degrades in real-world applications. This proposal aims to investigate the reasons behind these failures. We hypothesize that data quality issues, environmental variability, and model limitations are significant factors. By conducting a series of experiments, we will explore these challenges in depth and propose robust solutions to improve the generalizability of deep learning models for pest detection. Our research will provide valuable insights for the agricultural community and contribute to the development of more reliable AI tools for precision farming.", 8 | "Experiments": [ 9 | "1. **Data Quality Analysis**: Collect a diverse dataset of pest images from different agricultural environments and analyze its quality. Identify common issues such as label noise, class imbalance, and distribution shift.", 10 | "2. **Model Robustness Testing**: Train state-of-the-art deep learning models (e.g., YOLOv8, EfficientNetB3) on the collected dataset and evaluate their performance in controlled vs. real-world settings. Metrics: Mean Average Precision (mAP), F1 Score.", 11 | "3. **Environmental Variability Study**: Evaluate model performance under different environmental conditions (e.g., lighting, weather). Identify which conditions most significantly impact model accuracy.", 12 | "4. **Failure Mode Analysis**: Conduct a detailed analysis of misclassifications to identify common patterns and potential causes (e.g., feature overlap between pests and background).", 13 | "5. **Improvement Strategies**: Implement and test various strategies to mitigate identified challenges, such as data augmentation, domain adaptation, and model ensembling. Evaluate their effectiveness in improving model robustness." 14 | ], 15 | "Risk Factors and Limitations": "Potential risks include the availability and quality of real-world data, the computational demands of training and testing multiple deep learning models, and the generalizability of the findings to different types of pests and crops. Additionally, environmental factors may introduce variability that is challenging to control." 
16 | } 17 | ] 18 | -------------------------------------------------------------------------------- /ai_scientist/perform_ideation_temp_free.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os.path as osp 4 | import re 5 | import traceback 6 | from typing import Any, Dict, List 7 | 8 | import sys 9 | 10 | sys.path.append(osp.join(osp.dirname(__file__), "..")) 11 | from ai_scientist.llm import ( 12 | AVAILABLE_LLMS, 13 | create_client, 14 | get_response_from_llm, 15 | ) 16 | 17 | from ai_scientist.tools.semantic_scholar import SemanticScholarSearchTool 18 | from ai_scientist.tools.base_tool import BaseTool 19 | 20 | # Create tool instances 21 | semantic_scholar_tool = SemanticScholarSearchTool() 22 | 23 | # Define tools at the top of the file 24 | tools = [ 25 | semantic_scholar_tool, 26 | { 27 | "name": "FinalizeIdea", 28 | "description": """Finalize your idea by providing the idea details. 29 | 30 | The IDEA JSON should include the following fields: 31 | - "Name": A short descriptor of the idea. Lowercase, no spaces, underscores allowed. 32 | - "Title": A catchy and informative title for the proposal. 33 | - "Short Hypothesis": A concise statement of the main hypothesis or research question. Clarify the need for this specific direction, ensure this is the best setting to investigate this idea, and there are not obvious other simpler ways to answer the question. 34 | - "Related Work": A brief discussion of the most relevant related work and how the proposal clearly distinguishes from it, and is not a trivial extension. 35 | - "Abstract": An abstract that summarizes the proposal in conference format (approximately 250 words). 36 | - "Experiments": A list of experiments that would be conducted to validate the proposal. Ensure these are simple and feasible. Be specific in exactly how you would test the hypothesis, and detail precise algorithmic changes. Include the evaluation metrics you would use. 37 | - "Risk Factors and Limitations": A list of potential risks and limitations of the proposal.""", 38 | }, 39 | ] 40 | 41 | # Create a tools dictionary for easy lookup 42 | tools_dict = {tool.name: tool for tool in tools if isinstance(tool, BaseTool)} 43 | 44 | # Create a string with the tool descriptions 45 | tool_descriptions = "\n\n".join( 46 | ( 47 | f"- **{tool.name}**: {tool.description}" 48 | if isinstance(tool, BaseTool) 49 | else f"- **{tool['name']}**: {tool['description']}" 50 | ) 51 | for tool in tools 52 | ) 53 | 54 | # Extract tool names for the prompt 55 | tool_names = [ 56 | f'"{tool.name}"' if isinstance(tool, BaseTool) else f'"{tool["name"]}"' 57 | for tool in tools 58 | ] 59 | tool_names_str = ", ".join(tool_names) 60 | 61 | system_prompt = f"""You are an experienced AI researcher who aims to propose high-impact research ideas resembling exciting grant proposals. Feel free to propose any novel ideas or experiments; make sure they are novel. Be very creative and think out of the box. Each proposal should stem from a simple and elegant question, observation, or hypothesis about the topic. For example, they could involve very interesting and simple interventions or investigations that explore new possibilities or challenge existing assumptions. Clearly clarify how the proposal distinguishes from the existing literature. 62 | 63 | Ensure that the proposal does not require resources beyond what an academic lab could afford. 
These proposals should lead to papers that are publishable at top ML conferences. 64 | 65 | You have access to the following tools: 66 | 67 | {tool_descriptions} 68 | 69 | Respond in the following format: 70 | 71 | ACTION: 72 | 73 | 74 | ARGUMENTS: 75 | 76 | 77 | If you choose to finalize your idea, provide the IDEA JSON in the arguments: 78 | 79 | IDEA JSON: 80 | ```json 81 | {{ 82 | "Name": "...", 83 | "Title": "...", 84 | "Short Hypothesis": "...", 85 | "Related Work": "...", 86 | "Abstract": "...", 87 | "Experiments": "...", 88 | "Risk Factors and Limitations": "..." 89 | }} 90 | ``` 91 | 92 | Ensure the JSON is properly formatted for automatic parsing. 93 | 94 | Note: You should perform at least one literature search before finalizing your idea to ensure it is well-informed by existing research.""" 95 | 96 | # Define the initial idea generation prompt 97 | idea_generation_prompt = """{workshop_description} 98 | 99 | Here are the proposals that you have already generated: 100 | 101 | ''' 102 | {prev_ideas_string} 103 | ''' 104 | 105 | Begin by generating an interestingly new high-level research proposal that differs from what you have previously proposed. 106 | """ 107 | 108 | # Define the reflection prompt 109 | idea_reflection_prompt = """Round {current_round}/{num_reflections}. 110 | 111 | In your thoughts, first carefully consider the quality, novelty, and feasibility of the proposal you just created. 112 | Include any other factors that you think are important in evaluating the proposal. 113 | Ensure the proposal is clear and concise, and the JSON is in the correct format. 114 | Do not make things overly complicated. 115 | In the next attempt, try to refine and improve your proposal. 116 | Stick to the spirit of the original idea unless there are glaring issues. 117 | 118 | If you have new information from tools, such as literature search results, incorporate them into your reflection and refine your proposal accordingly. 119 | 120 | Results from your last action (if any): 121 | 122 | {last_tool_results} 123 | """ 124 | 125 | 126 | def generate_temp_free_idea( 127 | idea_fname: str, 128 | client: Any, 129 | model: str, 130 | workshop_description: str, 131 | max_num_generations: int = 20, 132 | num_reflections: int = 5, 133 | reload_ideas: bool = True, 134 | ) -> List[Dict]: 135 | idea_str_archive = [] 136 | # load ideas from file 137 | if reload_ideas and osp.exists(idea_fname): 138 | with open(idea_fname, "r") as f: 139 | idea_str_content = json.load(f) 140 | for idea in idea_str_content: 141 | idea_str_archive.append(json.dumps(idea)) 142 | print(f"Loaded {len(idea_str_archive)} ideas from {idea_fname}") 143 | else: 144 | print(f"No ideas found in {idea_fname}. 
Starting from scratch.") 145 | 146 | for gen_idx in range(max_num_generations): 147 | print() 148 | print(f"Generating proposal {gen_idx + 1}/{max_num_generations}") 149 | try: 150 | prev_ideas_string = "\n\n".join(idea_str_archive) 151 | 152 | last_tool_results = "" 153 | idea_finalized = False 154 | msg_history = [] 155 | 156 | for reflection_round in range(num_reflections): 157 | if reflection_round == 0: 158 | # Use the initial idea generation prompt 159 | prompt_text = idea_generation_prompt.format( 160 | workshop_description=workshop_description, 161 | prev_ideas_string=prev_ideas_string, 162 | ) 163 | else: 164 | # Use the reflection prompt, including tool results if any 165 | prompt_text = idea_reflection_prompt.format( 166 | current_round=reflection_round + 1, 167 | num_reflections=num_reflections, 168 | last_tool_results=last_tool_results or "No new results.", 169 | ) 170 | 171 | response_text, msg_history = get_response_from_llm( 172 | prompt=prompt_text, 173 | client=client, 174 | model=model, 175 | system_message=system_prompt, 176 | msg_history=msg_history, 177 | ) 178 | 179 | # Parse the LLM's response 180 | try: 181 | # Use regular expressions to extract the components 182 | action_pattern = r"ACTION:\s*(.*?)\s*ARGUMENTS:" 183 | arguments_pattern = r"ARGUMENTS:\s*(.*?)(?:$|\nTHOUGHT:|\n$)" 184 | 185 | action_match = re.search( 186 | action_pattern, response_text, re.DOTALL | re.IGNORECASE 187 | ) 188 | arguments_match = re.search( 189 | arguments_pattern, response_text, re.DOTALL | re.IGNORECASE 190 | ) 191 | 192 | if not all([action_match, arguments_match]): 193 | raise ValueError("Failed to parse the LLM response.") 194 | 195 | action = action_match.group(1).strip() 196 | arguments_text = arguments_match.group(1).strip() 197 | print(f"Action: {action}") 198 | print(f"Arguments: {arguments_text}") 199 | 200 | # If arguments are wrapped in ```json blocks, extract the content 201 | if arguments_text.startswith("```json"): 202 | arguments_text = re.search( 203 | r"```json\s*(.*?)\s*```", arguments_text, re.DOTALL 204 | ).group(1) 205 | 206 | # Process the action and arguments 207 | if action in tools_dict: 208 | # It's a tool we have defined 209 | tool = tools_dict[action] 210 | # Parse arguments 211 | try: 212 | arguments_json = json.loads(arguments_text) 213 | except json.JSONDecodeError: 214 | raise ValueError(f"Invalid arguments JSON for {action}.") 215 | 216 | # Use the tool 217 | try: 218 | # Assuming the arguments match the parameters of the tool 219 | result = tool.use_tool(**arguments_json) 220 | last_tool_results = result 221 | except Exception as e: 222 | last_tool_results = f"Error using tool {action}: {str(e)}" 223 | elif action == "FinalizeIdea": 224 | # Parse arguments 225 | try: 226 | arguments_json = json.loads(arguments_text) 227 | idea = arguments_json.get("idea") 228 | if not idea: 229 | raise ValueError("Missing 'idea' in arguments.") 230 | 231 | # Append the idea to the archive 232 | idea_str_archive.append(json.dumps(idea)) 233 | print(f"Proposal finalized: {idea}") 234 | idea_finalized = True 235 | break 236 | except json.JSONDecodeError: 237 | raise ValueError("Invalid arguments JSON for FinalizeIdea.") 238 | else: 239 | print( 240 | "Invalid action. Please specify one of the available tools." 241 | ) 242 | print(f"Available actions are: {tool_names_str}") 243 | except Exception as e: 244 | print( 245 | f"Failed to parse LLM response. 
Response text:\n{response_text}" 246 | ) 247 | traceback.print_exc() 248 | break # Exit the loop if parsing fails 249 | 250 | if idea_finalized: 251 | continue # Move to the next idea 252 | 253 | except Exception as e: 254 | print("Failed to generate proposal:") 255 | traceback.print_exc() 256 | continue 257 | 258 | # Save ideas 259 | ideas = [json.loads(idea_str) for idea_str in idea_str_archive] 260 | 261 | with open(idea_fname, "w") as f: 262 | json.dump(ideas, f, indent=4) 263 | print(f"Stored {len(ideas)} ideas in {idea_fname}") 264 | return ideas 265 | 266 | 267 | if __name__ == "__main__": 268 | parser = argparse.ArgumentParser( 269 | description="Generate AI scientist proposals - template free" 270 | ) 271 | parser.add_argument( 272 | "--model", 273 | type=str, 274 | default="gpt-4o-2024-05-13", 275 | choices=AVAILABLE_LLMS, 276 | help="Model to use for AI Scientist.", 277 | ) 278 | parser.add_argument( 279 | "--max-num-generations", 280 | type=int, 281 | default=1, 282 | help="Maximum number of proposal generations.", 283 | ) 284 | parser.add_argument( 285 | "--workshop-file", 286 | type=str, 287 | default="ideas/i_cant_believe_its_not_better.md", 288 | help="Path to the workshop description file.", 289 | ) 290 | parser.add_argument( 291 | "--num-reflections", 292 | type=int, 293 | default=5, 294 | help="Number of reflection rounds per proposal.", 295 | ) 296 | args = parser.parse_args() 297 | 298 | # Create the LLM client 299 | client, client_model = create_client(args.model) 300 | 301 | with open(args.workshop_file, "r") as f: 302 | workshop_description = f.read() 303 | print(f"Using workshop description from {args.workshop_file} for idea generation.") 304 | print(f"Workshop description:\n{workshop_description}") 305 | 306 | # Create output filename by replacing .md extension with .json 307 | idea_fname = args.workshop_file.replace(".md", ".json") 308 | print("Starting idea generation for", idea_fname) 309 | ideas = generate_temp_free_idea( 310 | idea_fname=idea_fname, 311 | client=client, 312 | model=client_model, 313 | workshop_description=workshop_description, 314 | max_num_generations=args.max_num_generations, 315 | num_reflections=args.num_reflections, 316 | ) 317 | print(f"{args.workshop_file} generated {len(ideas)} ideas.") 318 | -------------------------------------------------------------------------------- /ai_scientist/perform_plotting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import shutil 6 | import subprocess 7 | import sys 8 | import traceback 9 | from rich import print 10 | 11 | from ai_scientist.llm import create_client, get_response_from_llm 12 | from ai_scientist.utils.token_tracker import token_tracker 13 | from ai_scientist.perform_icbinb_writeup import ( 14 | load_idea_text, 15 | load_exp_summaries, 16 | filter_experiment_summaries, 17 | ) 18 | 19 | MAX_FIGURES = 12 20 | 21 | AGGREGATOR_SYSTEM_MSG = f"""You are an ambitious AI researcher who is preparing final plots for a scientific paper submission. 22 | You have multiple experiment summaries (baseline, research, ablation), each possibly containing references to different plots or numerical insights. 23 | There is also a top-level 'research_idea.md' file that outlines the overarching research direction. 24 | Your job is to produce ONE Python script that fully aggregates and visualizes the final results for a comprehensive research paper. 
25 | 26 | Key points: 27 | 1) Combine or replicate relevant existing plotting code, referencing how data was originally generated (from code references) to ensure correctness. 28 | 2) Create a complete set of final scientific plots, stored in 'figures/' only (since only those are used in the final paper). 29 | 3) Make sure to use existing .npy data for analysis; do NOT hallucinate data. If single numeric results are needed, these may be copied from the JSON summaries. 30 | 4) Only create plots where the data is best presented as a figure and not as a table. E.g. don't use bar plots if the data is hard to visually compare. 31 | 5) The final aggregator script must be in triple backticks and stand alone so it can be dropped into a codebase and run. 32 | 6) If there are plots based on synthetic data, include them in the appendix. 33 | 34 | Implement best practices: 35 | - Do not produce extraneous or irrelevant plots. 36 | - Maintain clarity, minimal but sufficient code. 37 | - Demonstrate thoroughness for a final research paper submission. 38 | - Do NOT reference non-existent files or images. 39 | - Use the .npy files to get data for the plots and key numbers from the JSON summaries. 40 | - Demarcate each individual plot, and put them in separate try-catch blocks so that the failure of one plot does not affect the others. 41 | - Make sure to only create plots that are unique and needed for the final paper and appendix. A good number could be around {MAX_FIGURES} plots in total. 42 | - Aim to aggregate multiple figures into one plot if suitable, i.e. if they are all related to the same topic. You can place up to 3 plots in one row. 43 | - Provide well-labeled plots (axes, legends, titles) that highlight main findings. Use informative names everywhere, including in the legend for referencing them in the final paper. Make sure the legend is always visible. 44 | - Make the plots look professional (if applicable, no top and right spines, dpi of 300, adequate ylim, etc.). 45 | - Do not use labels with underscores, e.g. "loss_vs_epoch" should be "loss vs epoch". 46 | - For image examples, select a few categories/classes to showcase the diversity of results instead of showing a single category/class. Some can be included in the main paper, while the rest can go in the appendix. 47 | 48 | Your output should be the entire Python aggregator script in triple backticks. 49 | """ 50 | 51 | 52 | def build_aggregator_prompt(combined_summaries_str, idea_text): 53 | return f""" 54 | We have three JSON summaries of scientific experiments: baseline, research, ablation. 55 | They may contain lists of figure descriptions, code to generate the figures, and paths to the .npy files containing the numerical results. 56 | Our goal is to produce final, publishable figures. 57 | 58 | --- RESEARCH IDEA --- 59 | ``` 60 | {idea_text} 61 | ``` 62 | 63 | IMPORTANT: 64 | - The aggregator script must load existing .npy experiment data from the "exp_results_npy_files" fields (ONLY using full and exact file paths in the summary JSONs) for thorough plotting. 65 | - It should call os.makedirs("figures", exist_ok=True) before saving any plots. 66 | - Aim for a balance of empirical results, ablations, and diverse, informative visuals in 'figures/' that comprehensively showcase the finalized research outcomes. 67 | - If you need .npy paths from the summary, only copy those paths directly (rather than copying and parsing the entire summary). 
68 | 69 | Your generated Python script must: 70 | 1) Load or refer to relevant data and .npy files from these summaries. Use the full and exact file paths in the summary JSONs. 71 | 2) Synthesize or directly create final, scientifically meaningful plots for a final research paper (comprehensive and complete), referencing the original code if needed to see how the data was generated. 72 | 3) Carefully combine or replicate relevant existing plotting code to produce these final aggregated plots in 'figures/' only, since only those are used in the final paper. 73 | 4) Do not hallucinate data. Data must either be loaded from .npy files or copied from the JSON summaries. 74 | 5) The aggregator script must be fully self-contained, and place the final plots in 'figures/'. 75 | 6) This aggregator script should produce a comprehensive and final set of scientific plots for the final paper, reflecting all major findings from the experiment data. 76 | 7) Make sure that every plot is unique and not duplicated from the original plots. Delete any duplicate plots if necessary. 77 | 8) Each figure can have up to 3 subplots using fig, ax = plt.subplots(1, 3). 78 | 9) Use a font size larger than the default for plot labels and titles to ensure they are readable in the final PDF paper. 79 | 80 | 81 | Below are the summaries in JSON: 82 | 83 | {combined_summaries_str} 84 | 85 | Respond with a Python script in triple backticks. 86 | """ 87 | 88 | 89 | def extract_code_snippet(text: str) -> str: 90 | """ 91 | Look for a Python code block in triple backticks in the LLM response. 92 | Return only that code. If no code block is found, return the entire text. 93 | """ 94 | pattern = r"```(?:python)?(.*?)```" 95 | matches = re.findall(pattern, text, flags=re.DOTALL) 96 | return matches[0].strip() if matches else text.strip() 97 | 98 | 99 | def run_aggregator_script( 100 | aggregator_code, aggregator_script_path, base_folder, script_name 101 | ): 102 | if not aggregator_code.strip(): 103 | print("No aggregator code was provided. Skipping aggregator script run.") 104 | return "" 105 | with open(aggregator_script_path, "w") as f: 106 | f.write(aggregator_code) 107 | 108 | print( 109 | f"Aggregator script written to '{aggregator_script_path}'. Attempting to run it..." 
110 | ) 111 | 112 | aggregator_out = "" 113 | try: 114 | result = subprocess.run( 115 | [sys.executable, script_name], 116 | cwd=base_folder, 117 | check=True, 118 | stdout=subprocess.PIPE, 119 | stderr=subprocess.PIPE, 120 | text=True, 121 | ) 122 | aggregator_out = result.stdout + "\n" + result.stderr 123 | print("Aggregator script ran successfully.") 124 | except subprocess.CalledProcessError as e: 125 | aggregator_out = (e.stdout or "") + "\n" + (e.stderr or "") 126 | print("Error: aggregator script returned a non-zero exit code.") 127 | print(e) 128 | except Exception as e: 129 | aggregator_out = str(e) 130 | print("Error while running aggregator script.") 131 | print(e) 132 | 133 | return aggregator_out 134 | 135 | 136 | def aggregate_plots( 137 | base_folder: str, model: str = "o1-2024-12-17", n_reflections: int = 5 138 | ) -> None: 139 | filename = "auto_plot_aggregator.py" 140 | aggregator_script_path = os.path.join(base_folder, filename) 141 | figures_dir = os.path.join(base_folder, "figures") 142 | 143 | # Clean up previous files 144 | if os.path.exists(aggregator_script_path): 145 | os.remove(aggregator_script_path) 146 | if os.path.exists(figures_dir): 147 | shutil.rmtree(figures_dir) 148 | print(f"Cleaned up previous figures directory") 149 | 150 | idea_text = load_idea_text(base_folder) 151 | exp_summaries = load_exp_summaries(base_folder) 152 | filtered_summaries_for_plot_agg = filter_experiment_summaries( 153 | exp_summaries, step_name="plot_aggregation" 154 | ) 155 | # Convert them to one big JSON string for context 156 | combined_summaries_str = json.dumps(filtered_summaries_for_plot_agg, indent=2) 157 | 158 | # Build aggregator prompt 159 | aggregator_prompt = build_aggregator_prompt(combined_summaries_str, idea_text) 160 | 161 | # Call LLM 162 | client, model_name = create_client(model) 163 | response, msg_history = None, [] 164 | try: 165 | response, msg_history = get_response_from_llm( 166 | prompt=aggregator_prompt, 167 | client=client, 168 | model=model_name, 169 | system_message=AGGREGATOR_SYSTEM_MSG, 170 | print_debug=False, 171 | msg_history=msg_history, 172 | ) 173 | except Exception: 174 | traceback.print_exc() 175 | print("Failed to get aggregator script from LLM.") 176 | return 177 | 178 | aggregator_code = extract_code_snippet(response) 179 | if not aggregator_code.strip(): 180 | print( 181 | "No Python code block was found in LLM response. Full response:\n", response 182 | ) 183 | return 184 | 185 | # First run of aggregator script 186 | aggregator_out = run_aggregator_script( 187 | aggregator_code, aggregator_script_path, base_folder, filename 188 | ) 189 | 190 | # Multiple reflection loops 191 | for i in range(n_reflections): 192 | # Check number of figures 193 | figure_count = 0 194 | if os.path.exists(figures_dir): 195 | figure_count = len( 196 | [ 197 | f 198 | for f in os.listdir(figures_dir) 199 | if os.path.isfile(os.path.join(figures_dir, f)) 200 | ] 201 | ) 202 | print(f"[{i + 1} / {n_reflections}]: Number of figures: {figure_count}") 203 | # Reflection prompt with reminder for common checks and early exit 204 | reflection_prompt = f"""We have run your aggregator script and it produced {figure_count} figure(s). The script's output is: 205 | ``` 206 | {aggregator_out} 207 | ``` 208 | 209 | Please criticize the current script for any flaws including but not limited to: 210 | - Are these enough plots for a final paper submission? Don't create more than {MAX_FIGURES} plots. 
211 | - Have you made sure to both use key numbers and generate more detailed plots from .npy files? 212 | - Does the figure title and legend have informative and descriptive names? These plots are the final versions, ensure there are no comments or other notes. 213 | - Can you aggregate multiple plots into one figure if suitable? 214 | - Do the labels have underscores? If so, replace them with spaces. 215 | - Make sure that every plot is unique and not duplicated from the original plots. 216 | 217 | If you believe you are done, simply say: "I am done". Otherwise, please provide an updated aggregator script in triple backticks.""" 218 | 219 | print("[green]Reflection prompt:[/green] ", reflection_prompt) 220 | try: 221 | reflection_response, msg_history = get_response_from_llm( 222 | prompt=reflection_prompt, 223 | client=client, 224 | model=model_name, 225 | system_message=AGGREGATOR_SYSTEM_MSG, 226 | print_debug=False, 227 | msg_history=msg_history, 228 | ) 229 | 230 | except Exception: 231 | traceback.print_exc() 232 | print("Failed to get reflection from LLM.") 233 | return 234 | 235 | # Early-exit check 236 | if figure_count > 0 and "I am done" in reflection_response: 237 | print("LLM indicated it is done with reflections. Exiting reflection loop.") 238 | break 239 | 240 | aggregator_new_code = extract_code_snippet(reflection_response) 241 | 242 | # If new code is provided and differs, run again 243 | if ( 244 | aggregator_new_code.strip() 245 | and aggregator_new_code.strip() != aggregator_code.strip() 246 | ): 247 | aggregator_code = aggregator_new_code 248 | aggregator_out = run_aggregator_script( 249 | aggregator_code, aggregator_script_path, base_folder, filename 250 | ) 251 | else: 252 | print( 253 | f"No new aggregator script was provided or it was identical. Reflection step {i+1} complete." 254 | ) 255 | 256 | 257 | def main(): 258 | parser = argparse.ArgumentParser( 259 | description="Generate and execute a final plot aggregation script with LLM assistance." 260 | ) 261 | parser.add_argument( 262 | "--folder", 263 | required=True, 264 | help="Path to the experiment folder with summary JSON files.", 265 | ) 266 | parser.add_argument( 267 | "--model", 268 | default="o1-2024-12-17", 269 | help="LLM model to use (default: o1-2024-12-17).", 270 | ) 271 | parser.add_argument( 272 | "--reflections", 273 | type=int, 274 | default=5, 275 | help="Number of reflection steps to attempt (default: 5).", 276 | ) 277 | args = parser.parse_args() 278 | aggregate_plots( 279 | base_folder=args.folder, model=args.model, n_reflections=args.reflections 280 | ) 281 | 282 | 283 | if __name__ == "__main__": 284 | main() 285 | -------------------------------------------------------------------------------- /ai_scientist/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/ai_scientist/tools/__init__.py -------------------------------------------------------------------------------- /ai_scientist/tools/base_tool.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List 3 | 4 | 5 | class BaseTool(ABC): 6 | """ 7 | An abstract base class for defining custom tools. 8 | 9 | Attributes: 10 | ----------- 11 | - name (str): The name of the tool. 12 | - description (str): A short description of what the tool does. 
13 | - parameters (list): A list of parameters that the tool requires, each parameter should be a dictionary with 'name', 'type', and 'description' key/value pairs. 14 | 15 | Usage: 16 | ------ 17 | To use this class, you should subclass it and provide an implementation for the `use_tool` abstract method. 18 | """ 19 | 20 | def __init__(self, name: str, description: str, parameters: List[Dict[str, Any]]): 21 | self.name = name 22 | self.description = description 23 | self.parameters = parameters 24 | 25 | @abstractmethod 26 | def use_tool(self, **kwargs) -> Any: 27 | """Abstract method that should be implemented by subclasses to define the functionality of the tool.""" 28 | pass 29 | -------------------------------------------------------------------------------- /ai_scientist/tools/semantic_scholar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import time 4 | import warnings 5 | from typing import Dict, List, Optional, Union 6 | 7 | import backoff 8 | 9 | from ai_scientist.tools.base_tool import BaseTool 10 | 11 | 12 | def on_backoff(details: Dict) -> None: 13 | print( 14 | f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries " 15 | f"calling function {details['target'].__name__} at {time.strftime('%X')}" 16 | ) 17 | 18 | 19 | class SemanticScholarSearchTool(BaseTool): 20 | def __init__( 21 | self, 22 | name: str = "SearchSemanticScholar", 23 | description: str = ( 24 | "Search for relevant literature using Semantic Scholar. " 25 | "Provide a search query to find relevant papers." 26 | ), 27 | max_results: int = 10, 28 | ): 29 | parameters = [ 30 | { 31 | "name": "query", 32 | "type": "str", 33 | "description": "The search query to find relevant papers.", 34 | } 35 | ] 36 | super().__init__(name, description, parameters) 37 | self.max_results = max_results 38 | self.S2_API_KEY = os.getenv("S2_API_KEY") 39 | if not self.S2_API_KEY: 40 | warnings.warn( 41 | "No Semantic Scholar API key found. Requests will be subject to stricter rate limits. " 42 | "Set the S2_API_KEY environment variable for higher limits." 43 | ) 44 | 45 | def use_tool(self, query: str) -> Optional[str]: 46 | papers = self.search_for_papers(query) 47 | if papers: 48 | return self.format_papers(papers) 49 | else: 50 | return "No papers found." 
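# Usage sketch (illustrative only; assumes the import path below and, optionally, an S2_API_KEY exported in the environment for higher rate limits):
#
#     from ai_scientist.tools.semantic_scholar import SemanticScholarSearchTool
#     tool = SemanticScholarSearchTool(max_results=5)
#     print(tool.use_tool(query="tree search for automated scientific discovery"))
#     # -> formatted paper summaries sorted by citation count, or "No papers found."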
51 | 52 | @backoff.on_exception( 53 | backoff.expo, 54 | (requests.exceptions.HTTPError, requests.exceptions.ConnectionError), 55 | on_backoff=on_backoff, 56 | ) 57 | def search_for_papers(self, query: str) -> Optional[List[Dict]]: 58 | if not query: 59 | return None 60 | 61 | headers = {} 62 | if self.S2_API_KEY: 63 | headers["X-API-KEY"] = self.S2_API_KEY 64 | 65 | rsp = requests.get( 66 | "https://api.semanticscholar.org/graph/v1/paper/search", 67 | headers=headers, 68 | params={ 69 | "query": query, 70 | "limit": self.max_results, 71 | "fields": "title,authors,venue,year,abstract,citationCount", 72 | }, 73 | ) 74 | print(f"Response Status Code: {rsp.status_code}") 75 | print(f"Response Content: {rsp.text[:500]}") 76 | rsp.raise_for_status() 77 | results = rsp.json() 78 | total = results.get("total", 0) 79 | if total == 0: 80 | return None 81 | 82 | papers = results.get("data", []) 83 | # Sort papers by citationCount in descending order 84 | papers.sort(key=lambda x: x.get("citationCount", 0), reverse=True) 85 | return papers 86 | 87 | def format_papers(self, papers: List[Dict]) -> str: 88 | paper_strings = [] 89 | for i, paper in enumerate(papers): 90 | authors = ", ".join( 91 | [author.get("name", "Unknown") for author in paper.get("authors", [])] 92 | ) 93 | paper_strings.append( 94 | f"""{i + 1}: {paper.get("title", "Unknown Title")}. {authors}. {paper.get("venue", "Unknown Venue")}, {paper.get("year", "Unknown Year")}. 95 | Number of citations: {paper.get("citationCount", "N/A")} 96 | Abstract: {paper.get("abstract", "No abstract available.")}""" 97 | ) 98 | return "\n\n".join(paper_strings) 99 | 100 | 101 | @backoff.on_exception( 102 | backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff 103 | ) 104 | def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]: 105 | S2_API_KEY = os.getenv("S2_API_KEY") 106 | headers = {} 107 | if not S2_API_KEY: 108 | warnings.warn( 109 | "No Semantic Scholar API key found. Requests will be subject to stricter rate limits." 110 | ) 111 | else: 112 | headers["X-API-KEY"] = S2_API_KEY 113 | 114 | if not query: 115 | return None 116 | 117 | rsp = requests.get( 118 | "https://api.semanticscholar.org/graph/v1/paper/search", 119 | headers=headers, 120 | params={ 121 | "query": query, 122 | "limit": result_limit, 123 | "fields": "title,authors,venue,year,abstract,citationStyles,citationCount", 124 | }, 125 | ) 126 | print(f"Response Status Code: {rsp.status_code}") 127 | print( 128 | f"Response Content: {rsp.text[:500]}" 129 | ) # Print the first 500 characters of the response content 130 | rsp.raise_for_status() 131 | results = rsp.json() 132 | total = results["total"] 133 | time.sleep(1.0) 134 | if not total: 135 | return None 136 | 137 | papers = results["data"] 138 | return papers 139 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/ai_scientist/treesearch/__init__.py -------------------------------------------------------------------------------- /ai_scientist/treesearch/backend/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import backend_anthropic, backend_openai 2 | from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md 3 | 4 | 5 | def query( 6 | system_message: PromptType | None, 7 | user_message: PromptType | None, 8 | model: str, 9 | temperature: float | None = None, 10 | max_tokens: int | None = None, 11 | func_spec: FunctionSpec | None = None, 12 | **model_kwargs, 13 | ) -> OutputType: 14 | """ 15 | General LLM query for various backends with a single system and user message. 16 | Supports function calling for some backends. 17 | 18 | Args: 19 | system_message (PromptType | None): Uncompiled system message (will generate a message following the OpenAI/Anthropic format) 20 | user_message (PromptType | None): Uncompiled user message (will generate a message following the OpenAI/Anthropic format) 21 | model (str): string identifier for the model to use (e.g. "gpt-4-turbo") 22 | temperature (float | None, optional): Temperature to sample at. Defaults to the model-specific default. 23 | max_tokens (int | None, optional): Maximum number of tokens to generate. Defaults to the model-specific max tokens. 24 | func_spec (FunctionSpec | None, optional): Optional FunctionSpec object defining a function call. If given, the return value will be a dict. 25 | 26 | Returns: 27 | OutputType: A string completion if func_spec is None, otherwise a dict with the function call details. 28 | """ 29 | 30 | model_kwargs = model_kwargs | { 31 | "model": model, 32 | "temperature": temperature, 33 | } 34 | 35 | # Handle models with beta limitations 36 | # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations 37 | if model.startswith("o1"): 38 | if system_message and user_message is None: 39 | user_message = system_message 40 | elif system_message is None and user_message: 41 | pass 42 | elif system_message and user_message: 43 | system_message["Main Instructions"] = {} 44 | system_message["Main Instructions"] |= user_message 45 | user_message = system_message 46 | system_message = None 47 | # model_kwargs["temperature"] = 0.5 48 | model_kwargs["reasoning_effort"] = "high" 49 | model_kwargs["max_completion_tokens"] = 100000 # max_tokens 50 | # remove 'temperature' from model_kwargs 51 | model_kwargs.pop("temperature", None) 52 | else: 53 | model_kwargs["max_tokens"] = max_tokens 54 | 55 | query_func = backend_anthropic.query if "claude-" in model else backend_openai.query 56 | output, req_time, in_tok_count, out_tok_count, info = query_func( 57 | system_message=compile_prompt_to_md(system_message) if system_message else None, 58 | user_message=compile_prompt_to_md(user_message) if user_message else None, 59 | func_spec=func_spec, 60 | **model_kwargs, 61 | ) 62 | 63 | return output 64 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/backend/backend_anthropic.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | 4 | from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create 5 | from funcy import notnone, once, select_values 6 | import anthropic 7 | 8 | # _client: anthropic.Anthropic = None # type: ignore 9 | _client: anthropic.AnthropicBedrock = None # type: ignore 10 | 11 | ANTHROPIC_TIMEOUT_EXCEPTIONS = ( 12 | anthropic.RateLimitError, 13 | anthropic.APIConnectionError, 14 | anthropic.APITimeoutError, 15 | anthropic.InternalServerError, 16 | anthropic.APIStatusError, 17 | ) 18 | 19 | 20 | @once 21 | def _setup_anthropic_client(): 22 | global 
_client 23 | # _client = anthropic.Anthropic(max_retries=0) 24 | _client = anthropic.AnthropicBedrock(max_retries=0) 25 | 26 | 27 | def query( 28 | system_message: str | None, 29 | user_message: str | None, 30 | func_spec: FunctionSpec | None = None, 31 | **model_kwargs, 32 | ) -> tuple[OutputType, float, int, int, dict]: 33 | _setup_anthropic_client() 34 | 35 | filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore 36 | if "max_tokens" not in filtered_kwargs: 37 | filtered_kwargs["max_tokens"] = 8192 # default for Claude models 38 | 39 | if func_spec is not None: 40 | raise NotImplementedError( 41 | "Anthropic does not support function calling for now." 42 | ) 43 | 44 | # Anthropic doesn't allow not having a user messages 45 | # if we only have system msg -> use it as user msg 46 | if system_message is not None and user_message is None: 47 | system_message, user_message = user_message, system_message 48 | 49 | # Anthropic passes the system messages as a separate argument 50 | if system_message is not None: 51 | filtered_kwargs["system"] = system_message 52 | 53 | messages = opt_messages_to_list(None, user_message) 54 | 55 | t0 = time.time() 56 | message = backoff_create( 57 | _client.messages.create, 58 | ANTHROPIC_TIMEOUT_EXCEPTIONS, 59 | messages=messages, 60 | **filtered_kwargs, 61 | ) 62 | req_time = time.time() - t0 63 | print(filtered_kwargs) 64 | 65 | if "thinking" in filtered_kwargs: 66 | assert ( 67 | len(message.content) == 2 68 | and message.content[0].type == "thinking" 69 | and message.content[1].type == "text" 70 | ) 71 | output: str = message.content[1].text 72 | else: 73 | assert len(message.content) == 1 and message.content[0].type == "text" 74 | output: str = message.content[0].text 75 | 76 | in_tokens = message.usage.input_tokens 77 | out_tokens = message.usage.output_tokens 78 | 79 | info = { 80 | "stop_reason": message.stop_reason, 81 | } 82 | 83 | return output, req_time, in_tokens, out_tokens, info 84 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/backend/backend_openai.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | 5 | from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create 6 | from funcy import notnone, once, select_values 7 | import openai 8 | from rich import print 9 | 10 | logger = logging.getLogger("ai-scientist") 11 | 12 | _client: openai.OpenAI = None # type: ignore 13 | 14 | OPENAI_TIMEOUT_EXCEPTIONS = ( 15 | openai.RateLimitError, 16 | openai.APIConnectionError, 17 | openai.APITimeoutError, 18 | openai.InternalServerError, 19 | ) 20 | 21 | 22 | @once 23 | def _setup_openai_client(): 24 | global _client 25 | _client = openai.OpenAI(max_retries=0) 26 | 27 | 28 | def query( 29 | system_message: str | None, 30 | user_message: str | None, 31 | func_spec: FunctionSpec | None = None, 32 | **model_kwargs, 33 | ) -> tuple[OutputType, float, int, int, dict]: 34 | _setup_openai_client() 35 | filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore 36 | 37 | messages = opt_messages_to_list(system_message, user_message) 38 | 39 | if func_spec is not None: 40 | filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict] 41 | # force the model to use the function 42 | filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict 43 | 44 | t0 = time.time() 45 | completion = backoff_create( 46 | _client.chat.completions.create, 47 | 
OPENAI_TIMEOUT_EXCEPTIONS, 48 | messages=messages, 49 | **filtered_kwargs, 50 | ) 51 | req_time = time.time() - t0 52 | 53 | choice = completion.choices[0] 54 | 55 | if func_spec is None: 56 | output = choice.message.content 57 | else: 58 | assert ( 59 | choice.message.tool_calls 60 | ), f"function_call is empty, it is not a function call: {choice.message}" 61 | assert ( 62 | choice.message.tool_calls[0].function.name == func_spec.name 63 | ), "Function name mismatch" 64 | try: 65 | print(f"[cyan]Raw func call response: {choice}[/cyan]") 66 | output = json.loads(choice.message.tool_calls[0].function.arguments) 67 | except json.JSONDecodeError as e: 68 | logger.error( 69 | f"Error decoding the function arguments: {choice.message.tool_calls[0].function.arguments}" 70 | ) 71 | raise e 72 | 73 | in_tokens = completion.usage.prompt_tokens 74 | out_tokens = completion.usage.completion_tokens 75 | 76 | info = { 77 | "system_fingerprint": completion.system_fingerprint, 78 | "model": completion.model, 79 | "created": completion.created, 80 | } 81 | 82 | return output, req_time, in_tokens, out_tokens, info 83 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/backend/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import jsonschema 4 | from dataclasses_json import DataClassJsonMixin 5 | 6 | PromptType = str | dict | list 7 | FunctionCallType = dict 8 | OutputType = str | FunctionCallType 9 | 10 | 11 | import backoff 12 | import logging 13 | from typing import Callable 14 | 15 | logger = logging.getLogger("ai-scientist") 16 | 17 | 18 | @backoff.on_predicate( 19 | wait_gen=backoff.expo, 20 | max_value=60, 21 | factor=1.5, 22 | ) 23 | def backoff_create( 24 | create_fn: Callable, retry_exceptions: list[Exception], *args, **kwargs 25 | ): 26 | try: 27 | return create_fn(*args, **kwargs) 28 | except retry_exceptions as e: 29 | logger.info(f"Backoff exception: {e}") 30 | return False 31 | 32 | 33 | def opt_messages_to_list( 34 | system_message: str | None, user_message: str | None 35 | ) -> list[dict[str, str]]: 36 | messages = [] 37 | if system_message: 38 | messages.append({"role": "system", "content": system_message}) 39 | if user_message: 40 | messages.append({"role": "user", "content": user_message}) 41 | return messages 42 | 43 | 44 | def compile_prompt_to_md(prompt: PromptType, _header_depth: int = 1) -> str: 45 | """Convert a prompt into markdown format""" 46 | try: 47 | logger.debug(f"compile_prompt_to_md input: type={type(prompt)}") 48 | if isinstance(prompt, (list, dict)): 49 | logger.debug(f"prompt content: {prompt}") 50 | 51 | if prompt is None: 52 | return "" 53 | 54 | if isinstance(prompt, str): 55 | return prompt.strip() + "\n" 56 | 57 | if isinstance(prompt, list): 58 | # Handle empty list case 59 | if not prompt: 60 | return "" 61 | # Special handling for multi-modal messages 62 | if all(isinstance(item, dict) and "type" in item for item in prompt): 63 | # For multi-modal messages, just pass through without modification 64 | return prompt 65 | 66 | try: 67 | result = "\n".join([f"- {s.strip()}" for s in prompt] + ["\n"]) 68 | return result 69 | except Exception as e: 70 | logger.error(f"Error processing list items: {e}") 71 | logger.error("List contents:") 72 | for i, item in enumerate(prompt): 73 | logger.error(f" Item {i}: type={type(item)}, value={item}") 74 | raise 75 | 76 | if isinstance(prompt, dict): 77 | # Check if this is a single 
multi-modal message 78 | if "type" in prompt: 79 | return prompt 80 | 81 | # Regular dict processing 82 | try: 83 | out = [] 84 | header_prefix = "#" * _header_depth 85 | for k, v in prompt.items(): 86 | logger.debug(f"Processing dict key: {k}") 87 | out.append(f"{header_prefix} {k}\n") 88 | out.append(compile_prompt_to_md(v, _header_depth=_header_depth + 1)) 89 | return "\n".join(out) 90 | except Exception as e: 91 | logger.error(f"Error processing dict: {e}") 92 | logger.error(f"Dict contents: {prompt}") 93 | raise 94 | 95 | raise ValueError(f"Unsupported prompt type: {type(prompt)}") 96 | 97 | except Exception as e: 98 | logger.error("Error in compile_prompt_to_md:") 99 | logger.error(f"Input type: {type(prompt)}") 100 | logger.error(f"Input content: {prompt}") 101 | logger.error(f"Error: {str(e)}") 102 | raise 103 | 104 | 105 | @dataclass 106 | class FunctionSpec(DataClassJsonMixin): 107 | name: str 108 | json_schema: dict # JSON schema 109 | description: str 110 | 111 | def __post_init__(self): 112 | # validate the schema 113 | jsonschema.Draft7Validator.check_schema(self.json_schema) 114 | 115 | @property 116 | def as_openai_tool_dict(self): 117 | return { 118 | "type": "function", 119 | "function": { 120 | "name": self.name, 121 | "description": self.description, 122 | "parameters": self.json_schema, 123 | }, 124 | } 125 | 126 | @property 127 | def openai_tool_choice_dict(self): 128 | return { 129 | "type": "function", 130 | "function": {"name": self.name}, 131 | } 132 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/bfts_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import yaml 5 | 6 | 7 | def idea_to_markdown(data: dict, output_path: str, load_code: str) -> None: 8 | """ 9 | Convert a dictionary into a markdown file. 10 | 11 | Args: 12 | data: Dictionary containing the data to convert 13 | output_path: Path where the markdown file will be saved 14 | load_code: Path to a code file to include in the markdown 15 | """ 16 | with open(output_path, "w", encoding="utf-8") as f: 17 | for key, value in data.items(): 18 | # Convert key to title format and make it a header 19 | header = key.replace("_", " ").title() 20 | f.write(f"## {header}\n\n") 21 | 22 | # Handle different value types 23 | if isinstance(value, (list, tuple)): 24 | for item in value: 25 | f.write(f"- {item}\n") 26 | f.write("\n") 27 | elif isinstance(value, dict): 28 | for sub_key, sub_value in value.items(): 29 | f.write(f"### {sub_key}\n") 30 | f.write(f"{sub_value}\n\n") 31 | else: 32 | f.write(f"{value}\n\n") 33 | 34 | # Add the code to the markdown file 35 | if load_code: 36 | # Assert that the code file exists before trying to open it 37 | assert os.path.exists(load_code), f"Code path at {load_code} must exist if using the 'load_code' flag. This is an optional code prompt that you may choose to include; if not, please do not set 'load_code'." 
38 | f.write(f"## Code To Potentially Use\n\n") 39 | f.write(f"Use the following code as context for your experiments:\n\n") 40 | with open(load_code, "r") as code_file: 41 | code = code_file.read() 42 | f.write(f"```python\n{code}\n```\n\n") 43 | 44 | 45 | def edit_bfts_config_file(config_path: str, idea_dir: str, idea_path: str) -> str: 46 | """ 47 | Edit the bfts_config.yaml file to point to the idea.md file 48 | 49 | Args: 50 | config_path: Path to the bfts_config.yaml file 51 | idea_dir: Directory where the idea.md file is located 52 | idea_path: Path to the idea.md file 53 | 54 | Returns: 55 | Path to the edited bfts_config.yaml file 56 | """ 57 | run_config_path = osp.join(idea_dir, "bfts_config.yaml") 58 | shutil.copy(config_path, run_config_path) 59 | with open(run_config_path, "r") as f: 60 | config = yaml.load(f, Loader=yaml.FullLoader) 61 | config["desc_file"] = idea_path 62 | config["workspace_dir"] = idea_dir 63 | 64 | # make an empty data directory 65 | data_dir = osp.join(idea_dir, "data") 66 | os.makedirs(data_dir, exist_ok=True) 67 | config["data_dir"] = data_dir 68 | 69 | # make an empty log directory 70 | log_dir = osp.join(idea_dir, "logs") 71 | os.makedirs(log_dir, exist_ok=True) 72 | config["log_dir"] = log_dir 73 | 74 | with open(run_config_path, "w") as f: 75 | yaml.dump(config, f) 76 | return run_config_path 77 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/interpreter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python interpreter for executing code snippets and capturing their output. 3 | Supports: 4 | - captures stdout and stderr 5 | - captures exceptions and stack traces 6 | - limits execution time 7 | """ 8 | 9 | import logging 10 | import os 11 | import queue 12 | import signal 13 | import sys 14 | import time 15 | import traceback 16 | from dataclasses import dataclass 17 | from multiprocessing import Process, Queue 18 | from pathlib import Path 19 | 20 | import humanize 21 | from dataclasses_json import DataClassJsonMixin 22 | 23 | logger = logging.getLogger("ai-scientist") 24 | 25 | 26 | @dataclass 27 | class ExecutionResult(DataClassJsonMixin): 28 | """ 29 | Result of executing a code snippet in the interpreter. 30 | Contains the output, execution time, and exception information. 
31 | """ 32 | 33 | term_out: list[str] 34 | exec_time: float 35 | exc_type: str | None 36 | exc_info: dict | None = None 37 | exc_stack: list[tuple] | None = None 38 | 39 | 40 | def exception_summary(e, working_dir, exec_file_name, format_tb_ipython): 41 | """Generates a string that summarizes an exception and its stack trace (either in standard python repl or in IPython format).""" 42 | if format_tb_ipython: 43 | import IPython.core.ultratb 44 | 45 | tb = IPython.core.ultratb.VerboseTB(tb_offset=1, color_scheme="NoColor") 46 | tb_str = str(tb.text(*sys.exc_info())) 47 | else: 48 | tb_lines = traceback.format_exception(e) 49 | # skip parts of stack trace in weflow code 50 | tb_str = "".join( 51 | [l for l in tb_lines if "treesearch/" not in l and "importlib" not in l] 52 | ) 53 | 54 | # replace whole path to file with just filename (to remove agent workspace dir) 55 | tb_str = tb_str.replace(str(working_dir / exec_file_name), exec_file_name) 56 | 57 | exc_info = {} 58 | if hasattr(e, "args"): 59 | exc_info["args"] = [str(i) for i in e.args] 60 | for att in ["name", "msg", "obj"]: 61 | if hasattr(e, att): 62 | exc_info[att] = str(getattr(e, att)) 63 | 64 | tb = traceback.extract_tb(e.__traceback__) 65 | exc_stack = [(t.filename, t.lineno, t.name, t.line) for t in tb] 66 | 67 | return tb_str, e.__class__.__name__, exc_info, exc_stack 68 | 69 | 70 | class RedirectQueue: 71 | def __init__(self, queue): 72 | self.queue = queue 73 | 74 | def write(self, msg): 75 | self.queue.put(msg) 76 | 77 | def flush(self): 78 | pass 79 | 80 | 81 | class Interpreter: 82 | def __init__( 83 | self, 84 | working_dir: Path | str, 85 | timeout: int = 3600, 86 | format_tb_ipython: bool = False, 87 | agent_file_name: str = "runfile.py", 88 | env_vars: dict[str, str] = {}, 89 | ): 90 | """ 91 | Simulates a standalone Python REPL with an execution time limit. 92 | 93 | Args: 94 | working_dir (Path | str): working directory of the agent 95 | timeout (int, optional): Timeout for each code execution step. Defaults to 3600. 96 | format_tb_ipython (bool, optional): Whether to use IPython or default python REPL formatting for exceptions. Defaults to False. 97 | agent_file_name (str, optional): The name for the agent's code file. Defaults to "runfile.py". 98 | env_vars (dict[str, str], optional): Environment variables to set in the child process. Defaults to {}. 
99 | """ 100 | # this really needs to be a path, otherwise it causes issues that don't raise an exception 101 | self.working_dir = Path(working_dir).resolve() 102 | assert ( 103 | self.working_dir.exists() 104 | ), f"Working directory {self.working_dir} does not exist" 105 | self.timeout = timeout 106 | self.format_tb_ipython = format_tb_ipython 107 | self.agent_file_name = agent_file_name 108 | self.process: Process = None # type: ignore 109 | self.env_vars = env_vars 110 | 111 | def child_proc_setup(self, result_outq: Queue) -> None: 112 | # disable all warnings (before importing anything) 113 | import shutup 114 | 115 | shutup.mute_warnings() 116 | 117 | for key, value in self.env_vars.items(): 118 | os.environ[key] = value 119 | 120 | os.chdir(str(self.working_dir)) 121 | 122 | # this seems to only be necessary because we're exec'ing code from a string, 123 | # a .py file should be able to import modules from the cwd anyway 124 | sys.path.append(str(self.working_dir)) 125 | 126 | # capture stdout and stderr 127 | # trunk-ignore(mypy/assignment) 128 | sys.stdout = sys.stderr = RedirectQueue(result_outq) 129 | 130 | def _run_session( 131 | self, code_inq: Queue, result_outq: Queue, event_outq: Queue 132 | ) -> None: 133 | self.child_proc_setup(result_outq) 134 | 135 | global_scope: dict = {} 136 | while True: 137 | code = code_inq.get() 138 | os.chdir(str(self.working_dir)) 139 | with open(self.agent_file_name, "w") as f: 140 | f.write(code) 141 | 142 | event_outq.put(("state:ready",)) 143 | try: 144 | exec(compile(code, self.agent_file_name, "exec"), global_scope) 145 | except BaseException as e: 146 | tb_str, e_cls_name, exc_info, exc_stack = exception_summary( 147 | e, 148 | self.working_dir, 149 | self.agent_file_name, 150 | self.format_tb_ipython, 151 | ) 152 | result_outq.put(tb_str) 153 | if e_cls_name == "KeyboardInterrupt": 154 | e_cls_name = "TimeoutError" 155 | 156 | event_outq.put(("state:finished", e_cls_name, exc_info, exc_stack)) 157 | else: 158 | event_outq.put(("state:finished", None, None, None)) 159 | 160 | # put EOF marker to indicate that we're done 161 | result_outq.put("<|EOF|>") 162 | 163 | def create_process(self) -> None: 164 | # we use three queues to communicate with the child process: 165 | # - code_inq: send code to child to execute 166 | # - result_outq: receive stdout/stderr from child 167 | # - event_outq: receive events from child (e.g.
state:ready, state:finished) 168 | # trunk-ignore(mypy/var-annotated) 169 | self.code_inq, self.result_outq, self.event_outq = Queue(), Queue(), Queue() 170 | self.process = Process( 171 | target=self._run_session, 172 | args=(self.code_inq, self.result_outq, self.event_outq), 173 | ) 174 | self.process.start() 175 | 176 | def _drain_queues(self): 177 | """Quickly drain all in-flight messages to prevent blocking.""" 178 | while not self.result_outq.empty(): 179 | try: 180 | self.result_outq.get_nowait() 181 | except Exception: 182 | break 183 | 184 | while not self.event_outq.empty(): 185 | try: 186 | self.event_outq.get_nowait() 187 | except Exception: 188 | break 189 | 190 | while not self.code_inq.empty(): 191 | try: 192 | self.code_inq.get_nowait() 193 | except Exception: 194 | break 195 | 196 | def cleanup_session(self): 197 | if self.process is None: 198 | return 199 | # give the child process a chance to terminate gracefully 200 | self.process.terminate() 201 | self._drain_queues() 202 | self.process.join(timeout=2) 203 | # kill the child process if it's still alive 204 | if self.process.exitcode is None: 205 | logger.warning("Child process failed to terminate gracefully, killing it..") 206 | self.process.kill() 207 | self._drain_queues() 208 | self.process.join(timeout=2) 209 | # don't wait for gc, clean up immediately 210 | self.process.close() 211 | self.process = None # type: ignore 212 | 213 | def run(self, code: str, reset_session=True) -> ExecutionResult: 214 | """ 215 | Execute the provided Python command in a separate process and return its output. 216 | 217 | Parameters: 218 | code (str): Python code to execute. 219 | reset_session (bool, optional): Whether to reset the interpreter session before executing the code. Defaults to True. 220 | 221 | Returns: 222 | ExecutionResult: Object containing the output and metadata of the code execution. 
223 | 224 | """ 225 | 226 | logger.debug(f"REPL is executing code (reset_session={reset_session})") 227 | 228 | if reset_session: 229 | if self.process is not None: 230 | # terminate and clean up previous process 231 | self.cleanup_session() 232 | self.create_process() 233 | else: 234 | # reset_session needs to be True on first exec 235 | assert self.process is not None 236 | 237 | assert self.process.is_alive() 238 | 239 | self.code_inq.put(code) 240 | 241 | # wait for child to actually start execution (we don't want to interrupt child setup) 242 | try: 243 | state = self.event_outq.get(timeout=10) 244 | except queue.Empty: 245 | msg = "REPL child process failed to start execution" 246 | logger.critical(msg) 247 | while not self.result_outq.empty(): 248 | logger.error(f"REPL output queue dump: {self.result_outq.get()}") 249 | raise RuntimeError(msg) from None 250 | assert state[0] == "state:ready", state 251 | start_time = time.time() 252 | 253 | # this flag indicates that the child has exceeded the time limit and an interrupt was sent 254 | # if the child process dies without this flag being set, it's an unexpected termination 255 | child_in_overtime = False 256 | 257 | while True: 258 | try: 259 | # check if the child is done 260 | state = self.event_outq.get(timeout=1) # wait for state:finished 261 | assert state[0] == "state:finished", state 262 | exec_time = time.time() - start_time 263 | break 264 | except queue.Empty: 265 | # we haven't heard back from the child -> check if it's still alive (assuming overtime interrupt wasn't sent yet) 266 | if not child_in_overtime and not self.process.is_alive(): 267 | msg = "REPL child process died unexpectedly" 268 | logger.critical(msg) 269 | while not self.result_outq.empty(): 270 | logger.error( 271 | f"REPL output queue dump: {self.result_outq.get()}" 272 | ) 273 | raise RuntimeError(msg) from None 274 | 275 | # child is alive and still executing -> check if we should sigint.. 276 | if self.timeout is None: 277 | continue 278 | running_time = time.time() - start_time 279 | if running_time > self.timeout: 280 | # [TODO] handle this in a better way 281 | assert reset_session, "Timeout occurred in interactive session" 282 | 283 | # send interrupt to child 284 | os.kill(self.process.pid, signal.SIGINT) # type: ignore 285 | child_in_overtime = True 286 | # terminate if we're overtime by more than a minute 287 | if running_time > self.timeout + 60: 288 | logger.warning("Child failed to terminate, killing it..") 289 | self.cleanup_session() 290 | 291 | state = (None, "TimeoutError", {}, []) 292 | exec_time = self.timeout 293 | break 294 | 295 | output: list[str] = [] 296 | # read all stdout/stderr from child up to the EOF marker 297 | # waiting until the queue is empty is not enough since 298 | # the feeder thread in child might still be adding to the queue 299 | while not self.result_outq.empty() or not output or output[-1] != "<|EOF|>": 300 | output.append(self.result_outq.get()) 301 | output.pop() # remove the EOF marker 302 | 303 | e_cls_name, exc_info, exc_stack = state[1:] 304 | 305 | if e_cls_name == "TimeoutError": 306 | output.append( 307 | f"TimeoutError: Execution exceeded the time limit of {humanize.naturaldelta(self.timeout)}" 308 | ) 309 | else: 310 | output.append( 311 | f"Execution time: {humanize.naturaldelta(exec_time)} (time limit is {humanize.naturaldelta(self.timeout)})."
312 | ) 313 | return ExecutionResult(output, exec_time, e_cls_name, exc_info, exc_stack) 314 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/journal2report.py: -------------------------------------------------------------------------------- 1 | from .backend import query 2 | from .journal import Journal 3 | from .utils.config import StageConfig 4 | 5 | 6 | def journal2report(journal: Journal, task_desc: dict, rcfg: StageConfig): 7 | """ 8 | Generate a report from a journal, the report will be in markdown format. 9 | """ 10 | report_input = journal.generate_summary(include_code=True) 11 | system_prompt_dict = { 12 | "Role": "You are a research assistant that always uses concise language.", 13 | "Goal": "The goal is to write a technical report summarising the empirical findings and technical decisions.", 14 | "Input": "You are given a raw research journal with list of design attempts and their outcomes, and a research idea description.", 15 | "Output": [ 16 | "Your output should be a single markdown document.", 17 | "Your report should have the following sections: Introduction, Preprocessing, Methods, Results Discussion, Future Work", 18 | "You can include subsections if needed.", 19 | ], 20 | } 21 | context_prompt = ( 22 | f"Here is the research journal of the agent: {report_input}<\\journal>, " 23 | f"and the research idea description is: {task_desc}<\\research_proposal>." 24 | ) 25 | return query( 26 | system_message=system_prompt_dict, 27 | user_message=context_prompt, 28 | model=rcfg.model, 29 | temperature=rcfg.temp, 30 | max_tokens=4096, 31 | ) 32 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/perform_experiments_bfts_with_agentmanager.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import logging 3 | import shutil 4 | import json 5 | import pickle 6 | from . 
import backend 7 | from .journal import Journal, Node 8 | from .journal2report import journal2report 9 | from rich.columns import Columns 10 | from rich.console import Group 11 | from rich.live import Live 12 | from rich.padding import Padding 13 | from rich.panel import Panel 14 | from rich.progress import ( 15 | BarColumn, 16 | MofNCompleteColumn, 17 | Progress, 18 | TextColumn, 19 | TimeRemainingColumn, 20 | ) 21 | from rich.text import Text 22 | from rich.status import Status 23 | from rich.tree import Tree 24 | from .utils.config import load_task_desc, prep_agent_workspace, save_run, load_cfg 25 | from .agent_manager import AgentManager 26 | from pathlib import Path 27 | from .agent_manager import Stage 28 | from .log_summarization import overall_summarize 29 | 30 | 31 | logger = logging.getLogger("ai-scientist") 32 | 33 | 34 | def journal_to_rich_tree(journal: Journal): 35 | best_node = journal.get_best_node() 36 | 37 | def append_rec(node: Node, tree): 38 | if node.is_buggy: 39 | s = "[red]◍ bug" 40 | else: 41 | style = "bold " if node is best_node else "" 42 | 43 | if node is best_node: 44 | s = f"[{style}green]● {node.metric.value:.3f} (best)" 45 | else: 46 | s = f"[{style}green]● {node.metric.value:.3f}" 47 | 48 | subtree = tree.add(s) 49 | for child in node.children: 50 | append_rec(child, subtree) 51 | 52 | tree = Tree("[bold blue]Solution tree") 53 | for n in journal.draft_nodes: 54 | append_rec(n, tree) 55 | return tree 56 | 57 | 58 | def perform_experiments_bfts(config_path: str): 59 | # turn config path string into a path object 60 | config_path = Path(config_path) 61 | cfg = load_cfg(config_path) 62 | logger.info(f'Starting run "{cfg.exp_name}"') 63 | 64 | task_desc = load_task_desc(cfg) 65 | print(task_desc) 66 | task_desc_str = backend.compile_prompt_to_md(task_desc) 67 | 68 | global_step = 0 69 | 70 | with Status("Preparing agent workspace (copying and extracting files) ..."): 71 | prep_agent_workspace(cfg) 72 | 73 | def cleanup(): 74 | if global_step == 0: 75 | shutil.rmtree(cfg.workspace_dir) 76 | 77 | atexit.register(cleanup) 78 | 79 | manager = AgentManager( 80 | task_desc=task_desc, 81 | cfg=cfg, 82 | workspace_dir=Path(cfg.workspace_dir), 83 | ) 84 | 85 | prog = Progress( 86 | TextColumn("[progress.description]{task.description}"), 87 | BarColumn(bar_width=20), 88 | MofNCompleteColumn(), 89 | TimeRemainingColumn(), 90 | ) 91 | status = Status("[green]Running experiments...") 92 | prog.add_task("Progress:", total=cfg.agent.steps, completed=global_step) 93 | 94 | def create_exec_callback(status_obj): 95 | def exec_callback(*args, **kwargs): 96 | status_obj.update("[magenta]Executing code...") 97 | res = interpreter.run(*args, **kwargs) 98 | status_obj.update("[green]Generating code...") 99 | return res 100 | 101 | return exec_callback 102 | 103 | def step_callback(stage, journal): 104 | print("Step complete") 105 | try: 106 | # Generate and save notes for this step 107 | notes_dir = cfg.log_dir / f"stage_{stage.name}" / "notes" 108 | notes_dir.mkdir(parents=True, exist_ok=True) 109 | 110 | # Save latest node summary 111 | if journal.nodes: 112 | latest_node = journal.nodes[-1] 113 | if hasattr(latest_node, "_agent"): 114 | summary = latest_node._agent._generate_node_summary(latest_node) 115 | with open( 116 | notes_dir / f"node_{latest_node.id}_summary.json", "w" 117 | ) as f: 118 | json.dump(summary, f, indent=2) 119 | 120 | # Generate and save stage progress summary 121 | stage_summary = { 122 | "stage": stage.name, 123 | "total_nodes": len(journal.nodes), 124 | 
"buggy_nodes": len(journal.buggy_nodes), 125 | "good_nodes": len(journal.good_nodes), 126 | "best_metric": ( 127 | str(journal.get_best_node().metric) 128 | if journal.get_best_node() 129 | else "None" 130 | ), 131 | "current_findings": journal.generate_summary(include_code=False), 132 | } 133 | 134 | with open(notes_dir / "stage_progress.json", "w") as f: 135 | json.dump(stage_summary, f, indent=2) 136 | 137 | # Save the run as before 138 | save_run(cfg, journal, stage_name=f"stage_{stage.name}") 139 | 140 | except Exception as e: 141 | print(f"Error in step callback: {e}") 142 | 143 | print(f"Run saved at {cfg.log_dir / f'stage_{stage.name}'}") 144 | print(f"Step {len(journal)}/{stage.max_iterations} at stage_{stage.name}") 145 | print(f"Run saved at {cfg.log_dir / f'stage_{stage.name}'}") 146 | 147 | def generate_live(manager): 148 | current_stage = manager.current_stage 149 | current_journal = manager.journals.get( 150 | current_stage.name if current_stage else None, None 151 | ) 152 | 153 | if current_journal: 154 | tree = journal_to_rich_tree(current_journal) 155 | else: 156 | tree = Tree("[bold blue]No results yet") 157 | 158 | file_paths = [ 159 | f"Result visualization:\n[yellow]▶ {str((cfg.log_dir / 'tree_plot.html'))}", 160 | f"Agent workspace directory:\n[yellow]▶ {str(cfg.workspace_dir)}", 161 | f"Experiment log directory:\n[yellow]▶ {str(cfg.log_dir)}", 162 | ] 163 | 164 | stage_info = [ 165 | "[bold]Experiment Progress:", 166 | f"Current Stage: [cyan]{current_stage.name if current_stage else 'None'}[/cyan]", 167 | f"Completed Stages: [green]{', '.join(manager.completed_stages)}[/green]", 168 | ] 169 | 170 | left = Group( 171 | Panel(Text(task_desc_str.strip()), title="Task description"), 172 | Panel(Text("\n".join(stage_info)), title="Stage Progress"), 173 | prog, 174 | status, 175 | ) 176 | right = tree 177 | wide = Group(*file_paths) 178 | 179 | return Panel( 180 | Group( 181 | Padding(wide, (1, 1, 1, 1)), 182 | Columns( 183 | [Padding(left, (1, 2, 1, 1)), Padding(right, (1, 1, 1, 2))], 184 | equal=True, 185 | ), 186 | ), 187 | title=f'[b]AIDE is working on experiment: [bold green]"{cfg.exp_name}[/b]"', 188 | subtitle="Press [b]Ctrl+C[/b] to stop the run", 189 | ) 190 | 191 | live = Live( 192 | generate_live(manager), 193 | refresh_per_second=16, 194 | screen=True, 195 | ) 196 | 197 | manager.run(exec_callback=create_exec_callback(status), step_callback=step_callback) 198 | 199 | manager_pickle_path = cfg.log_dir / "manager.pkl" 200 | try: 201 | with open(manager_pickle_path, "wb") as f: 202 | pickle.dump(manager, f) 203 | logger.info(f"Saved manager state to: {manager_pickle_path}") 204 | except Exception as e: 205 | logger.warning(f"Failed to save full manager state: {e}") 206 | try: 207 | with open(manager_pickle_path, "wb") as f: 208 | pickle.dump(manager.journals.items(), f) 209 | logger.info(f"Saved manager journals to: {manager_pickle_path}") 210 | except Exception as e: 211 | logger.error(f"Failed to save manager journals: {e}") 212 | 213 | if cfg.generate_report: 214 | print("Generating final report from all stages...") 215 | ( 216 | draft_summary, 217 | baseline_summary, 218 | research_summary, 219 | ablation_summary, 220 | ) = overall_summarize(manager.journals.items()) 221 | draft_summary_path = cfg.log_dir / "draft_summary.json" 222 | baseline_summary_path = cfg.log_dir / "baseline_summary.json" 223 | research_summary_path = cfg.log_dir / "research_summary.json" 224 | ablation_summary_path = cfg.log_dir / "ablation_summary.json" 225 | 226 | with 
open(draft_summary_path, "w") as draft_file: 227 | json.dump(draft_summary, draft_file, indent=2) 228 | 229 | with open(baseline_summary_path, "w") as baseline_file: 230 | json.dump(baseline_summary, baseline_file, indent=2) 231 | 232 | with open(research_summary_path, "w") as research_file: 233 | json.dump(research_summary, research_file, indent=2) 234 | 235 | with open(ablation_summary_path, "w") as ablation_file: 236 | json.dump(ablation_summary, ablation_file, indent=2) 237 | 238 | print(f"Summary reports written to files:") 239 | print(f"- Draft summary: {draft_summary_path}") 240 | print(f"- Baseline summary: {baseline_summary_path}") 241 | print(f"- Research summary: {research_summary_path}") 242 | print(f"- Ablation summary: {ablation_summary_path}") 243 | 244 | 245 | if __name__ == "__main__": 246 | cfg_path = "treesearch/utils/config.yaml" 247 | cfg = load_cfg(cfg_path) 248 | perform_experiments_bfts(cfg_path) 249 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | import zipfile 4 | from pathlib import Path 5 | 6 | logger = logging.getLogger("ai-scientist") 7 | 8 | 9 | def copytree(src: Path, dst: Path, use_symlinks=True): 10 | """ 11 | Copy contents of `src` to `dst`. Unlike shutil.copytree, the dst dir can exist and will be merged. 12 | If src is a file, only that file will be copied. Optionally uses symlinks instead of copying. 13 | 14 | Args: 15 | src (Path): source directory 16 | dst (Path): destination directory 17 | """ 18 | assert dst.is_dir() 19 | 20 | if src.is_file(): 21 | dest_f = dst / src.name 22 | assert not dest_f.exists(), dest_f 23 | if use_symlinks: 24 | (dest_f).symlink_to(src) 25 | else: 26 | shutil.copyfile(src, dest_f) 27 | return 28 | 29 | for f in src.iterdir(): 30 | dest_f = dst / f.name 31 | assert not dest_f.exists(), dest_f 32 | if use_symlinks: 33 | (dest_f).symlink_to(f) 34 | elif f.is_dir(): 35 | shutil.copytree(f, dest_f) 36 | else: 37 | shutil.copyfile(f, dest_f) 38 | 39 | 40 | def clean_up_dataset(path: Path): 41 | for item in path.rglob("__MACOSX"): 42 | if item.is_dir(): 43 | shutil.rmtree(item) 44 | for item in path.rglob(".DS_Store"): 45 | if item.is_file(): 46 | item.unlink() 47 | 48 | 49 | def extract_archives(path: Path): 50 | """ 51 | unzips all .zip files within `path` and cleans up task dir 52 | 53 | [TODO] handle nested zips 54 | """ 55 | for zip_f in path.rglob("*.zip"): 56 | f_out_dir = zip_f.with_suffix("") 57 | 58 | # special case: the intended output path already exists (maybe data has already been extracted by user) 59 | if f_out_dir.exists(): 60 | logger.debug( 61 | f"Skipping {zip_f} as an item with the same name already exists." 
62 | ) 63 | # if it's a file, it's probably exactly the same as in the zip -> remove the zip 64 | # [TODO] maybe add an extra check to see if zip file content matches the colliding file 65 | if f_out_dir.is_file() and f_out_dir.suffix != "": 66 | zip_f.unlink() 67 | continue 68 | 69 | logger.debug(f"Extracting: {zip_f}") 70 | f_out_dir.mkdir(exist_ok=True) 71 | with zipfile.ZipFile(zip_f, "r") as zip_ref: 72 | zip_ref.extractall(f_out_dir) 73 | 74 | # remove any unwanted files 75 | clean_up_dataset(f_out_dir) 76 | 77 | contents = list(f_out_dir.iterdir()) 78 | 79 | # special case: the zip contains a single dir/file with the same name as the zip 80 | if len(contents) == 1 and contents[0].name == f_out_dir.name: 81 | sub_item = contents[0] 82 | # if it's a dir, move its contents to the parent and remove it 83 | if sub_item.is_dir(): 84 | logger.debug(f"Special handling (child is dir) enabled for: {zip_f}") 85 | for f in sub_item.rglob("*"): 86 | shutil.move(f, f_out_dir) 87 | sub_item.rmdir() 88 | # if it's a file, rename it to the parent and remove the parent 89 | elif sub_item.is_file(): 90 | logger.debug(f"Special handling (child is file) enabled for: {zip_f}") 91 | sub_item_tmp = sub_item.rename(f_out_dir.with_suffix(".__tmp_rename")) 92 | f_out_dir.rmdir() 93 | sub_item_tmp.rename(f_out_dir) 94 | 95 | zip_f.unlink() 96 | 97 | 98 | def preproc_data(path: Path): 99 | extract_archives(path) 100 | clean_up_dataset(path) 101 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/utils/config.py: -------------------------------------------------------------------------------- 1 | """configuration and setup utils""" 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Hashable, cast, Literal, Optional 6 | 7 | import coolname 8 | import rich 9 | from omegaconf import OmegaConf 10 | from rich.syntax import Syntax 11 | import shutup 12 | from rich.logging import RichHandler 13 | import logging 14 | 15 | from . import tree_export 16 | from . 
import copytree, preproc_data, serialize 17 | 18 | shutup.mute_warnings() 19 | logging.basicConfig( 20 | level="WARNING", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] 21 | ) 22 | logger = logging.getLogger("ai-scientist") 23 | logger.setLevel(logging.WARNING) 24 | 25 | 26 | """ these dataclasses are just for type hinting, the actual config is in config.yaml """ 27 | 28 | 29 | @dataclass 30 | class ThinkingConfig: 31 | type: str 32 | budget_tokens: Optional[int] = None 33 | 34 | 35 | @dataclass 36 | class StageConfig: 37 | model: str 38 | temp: float 39 | thinking: ThinkingConfig 40 | betas: str 41 | max_tokens: Optional[int] = None 42 | 43 | 44 | @dataclass 45 | class SearchConfig: 46 | max_debug_depth: int 47 | debug_prob: float 48 | num_drafts: int 49 | 50 | 51 | @dataclass 52 | class DebugConfig: 53 | stage4: bool 54 | 55 | 56 | @dataclass 57 | class AgentConfig: 58 | steps: int 59 | stages: dict[str, int] 60 | k_fold_validation: int 61 | expose_prediction: bool 62 | data_preview: bool 63 | 64 | code: StageConfig 65 | feedback: StageConfig 66 | vlm_feedback: StageConfig 67 | 68 | search: SearchConfig 69 | num_workers: int 70 | type: str 71 | multi_seed_eval: dict[str, int] 72 | 73 | 74 | @dataclass 75 | class ExecConfig: 76 | timeout: int 77 | agent_file_name: str 78 | format_tb_ipython: bool 79 | 80 | 81 | @dataclass 82 | class ExperimentConfig: 83 | num_syn_datasets: int 84 | 85 | 86 | @dataclass 87 | class Config(Hashable): 88 | data_dir: Path 89 | desc_file: Path | None 90 | 91 | goal: str | None 92 | eval: str | None 93 | 94 | log_dir: Path 95 | workspace_dir: Path 96 | 97 | preprocess_data: bool 98 | copy_data: bool 99 | 100 | exp_name: str 101 | 102 | exec: ExecConfig 103 | generate_report: bool 104 | report: StageConfig 105 | agent: AgentConfig 106 | experiment: ExperimentConfig 107 | debug: DebugConfig 108 | 109 | 110 | def _get_next_logindex(dir: Path) -> int: 111 | """Get the next available index for a log directory.""" 112 | max_index = -1 113 | for p in dir.iterdir(): 114 | try: 115 | if (current_index := int(p.name.split("-")[0])) > max_index: 116 | max_index = current_index 117 | except ValueError: 118 | pass 119 | print("max_index: ", max_index) 120 | return max_index + 1 121 | 122 | 123 | def _load_cfg( 124 | path: Path = Path(__file__).parent / "config.yaml", use_cli_args=False 125 | ) -> Config: 126 | cfg = OmegaConf.load(path) 127 | if use_cli_args: 128 | cfg = OmegaConf.merge(cfg, OmegaConf.from_cli()) 129 | return cfg 130 | 131 | 132 | def load_cfg(path: Path = Path(__file__).parent / "config.yaml") -> Config: 133 | """Load config from .yaml file and CLI args, and set up logging directory.""" 134 | return prep_cfg(_load_cfg(path)) 135 | 136 | 137 | def prep_cfg(cfg: Config): 138 | if cfg.data_dir is None: 139 | raise ValueError("`data_dir` must be provided.") 140 | 141 | if cfg.desc_file is None and cfg.goal is None: 142 | raise ValueError( 143 | "You must provide either a description of the task goal (`goal=...`) or a path to a plaintext file containing the description (`desc_file=...`)." 
144 | ) 145 | 146 | if cfg.data_dir.startswith("example_tasks/"): 147 | cfg.data_dir = Path(__file__).parent.parent / cfg.data_dir 148 | cfg.data_dir = Path(cfg.data_dir).resolve() 149 | 150 | if cfg.desc_file is not None: 151 | cfg.desc_file = Path(cfg.desc_file).resolve() 152 | 153 | top_log_dir = Path(cfg.log_dir).resolve() 154 | top_log_dir.mkdir(parents=True, exist_ok=True) 155 | 156 | top_workspace_dir = Path(cfg.workspace_dir).resolve() 157 | top_workspace_dir.mkdir(parents=True, exist_ok=True) 158 | 159 | # generate experiment name and prefix with consecutive index 160 | ind = max(_get_next_logindex(top_log_dir), _get_next_logindex(top_workspace_dir)) 161 | cfg.exp_name = cfg.exp_name or coolname.generate_slug(3) 162 | cfg.exp_name = f"{ind}-{cfg.exp_name}" 163 | 164 | cfg.log_dir = (top_log_dir / cfg.exp_name).resolve() 165 | cfg.workspace_dir = (top_workspace_dir / cfg.exp_name).resolve() 166 | 167 | # validate the config 168 | cfg_schema: Config = OmegaConf.structured(Config) 169 | cfg = OmegaConf.merge(cfg_schema, cfg) 170 | 171 | if cfg.agent.type not in ["parallel", "sequential"]: 172 | raise ValueError("agent.type must be either 'parallel' or 'sequential'") 173 | 174 | return cast(Config, cfg) 175 | 176 | 177 | def print_cfg(cfg: Config) -> None: 178 | rich.print(Syntax(OmegaConf.to_yaml(cfg), "yaml", theme="paraiso-dark")) 179 | 180 | 181 | def load_task_desc(cfg: Config): 182 | """Load task description from markdown file or config str.""" 183 | 184 | # either load the task description from a file 185 | if cfg.desc_file is not None: 186 | if not (cfg.goal is None and cfg.eval is None): 187 | logger.warning( 188 | "Ignoring goal and eval args because task description file is provided." 189 | ) 190 | 191 | with open(cfg.desc_file) as f: 192 | return f.read() 193 | 194 | # or generate it from the goal and eval args 195 | if cfg.goal is None: 196 | raise ValueError( 197 | "`goal` (and optionally `eval`) must be provided if a task description file is not provided." 
198 | ) 199 | 200 | task_desc = {"Task goal": cfg.goal} 201 | if cfg.eval is not None: 202 | task_desc["Task evaluation"] = cfg.eval 203 | print(task_desc) 204 | return task_desc 205 | 206 | 207 | def prep_agent_workspace(cfg: Config): 208 | """Setup the agent's workspace and preprocess data if necessary.""" 209 | (cfg.workspace_dir / "input").mkdir(parents=True, exist_ok=True) 210 | (cfg.workspace_dir / "working").mkdir(parents=True, exist_ok=True) 211 | 212 | copytree(cfg.data_dir, cfg.workspace_dir / "input", use_symlinks=not cfg.copy_data) 213 | if cfg.preprocess_data: 214 | preproc_data(cfg.workspace_dir / "input") 215 | 216 | 217 | def save_run(cfg: Config, journal, stage_name: str = None): 218 | if stage_name is None: 219 | stage_name = "NoStageRun" 220 | save_dir = cfg.log_dir / stage_name 221 | save_dir.mkdir(parents=True, exist_ok=True) 222 | 223 | # save journal 224 | try: 225 | serialize.dump_json(journal, save_dir / "journal.json") 226 | except Exception as e: 227 | print(f"Error saving journal: {e}") 228 | raise 229 | # save config 230 | try: 231 | OmegaConf.save(config=cfg, f=save_dir / "config.yaml") 232 | except Exception as e: 233 | print(f"Error saving config: {e}") 234 | raise 235 | # create the tree + code visualization 236 | try: 237 | tree_export.generate(cfg, journal, save_dir / "tree_plot.html") 238 | except Exception as e: 239 | print(f"Error generating tree: {e}") 240 | raise 241 | # save the best found solution 242 | try: 243 | best_node = journal.get_best_node(only_good=False) 244 | if best_node is not None: 245 | for existing_file in save_dir.glob("best_solution_*.py"): 246 | existing_file.unlink() 247 | # Create new best solution file 248 | filename = f"best_solution_{best_node.id}.py" 249 | with open(save_dir / filename, "w") as f: 250 | f.write(best_node.code) 251 | # save best_node.id to a text file 252 | with open(save_dir / "best_node_id.txt", "w") as f: 253 | f.write(str(best_node.id)) 254 | else: 255 | print("No best node found yet") 256 | except Exception as e: 257 | print(f"Error saving best solution: {e}") 258 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/utils/data_preview.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains functions to manually generate a textual preview of some common file types (.csv, .json,..) for the agent. 3 | """ 4 | 5 | import json 6 | from pathlib import Path 7 | 8 | import humanize 9 | import pandas as pd 10 | from genson import SchemaBuilder 11 | from pandas.api.types import is_numeric_dtype 12 | 13 | # these files are treated as code (e.g. markdown wrapped) 14 | code_files = {".py", ".sh", ".yaml", ".yml", ".md", ".html", ".xml", ".log", ".rst"} 15 | # we treat these files as text (rather than binary) files 16 | plaintext_files = {".txt", ".csv", ".json", ".tsv"} | code_files 17 | 18 | 19 | def get_file_len_size(f: Path) -> tuple[int, str]: 20 | """ 21 | Calculate the size of a file (#lines for plaintext files, otherwise #bytes) 22 | Also returns a human-readable string representation of the size. 
23 | """ 24 | if f.suffix in plaintext_files: 25 | num_lines = sum(1 for _ in open(f)) 26 | return num_lines, f"{num_lines} lines" 27 | else: 28 | s = f.stat().st_size 29 | return s, humanize.naturalsize(s) 30 | 31 | 32 | def file_tree(path: Path, depth=0) -> str: 33 | """Generate a tree structure of files in a directory""" 34 | result = [] 35 | files = [p for p in Path(path).iterdir() if not p.is_dir()] 36 | dirs = [p for p in Path(path).iterdir() if p.is_dir()] 37 | max_n = 4 if len(files) > 30 else 8 38 | for p in sorted(files)[:max_n]: 39 | result.append(f"{' '*depth*4}{p.name} ({get_file_len_size(p)[1]})") 40 | if len(files) > max_n: 41 | result.append(f"{' '*depth*4}... and {len(files)-max_n} other files") 42 | 43 | for p in sorted(dirs): 44 | result.append(f"{' '*depth*4}{p.name}/") 45 | result.append(file_tree(p, depth + 1)) 46 | 47 | return "\n".join(result) 48 | 49 | 50 | def _walk(path: Path): 51 | """Recursively walk a directory (analogous to os.walk but for pathlib.Path)""" 52 | for p in sorted(Path(path).iterdir()): 53 | if p.is_dir(): 54 | yield from _walk(p) 55 | continue 56 | yield p 57 | 58 | 59 | def preview_csv(p: Path, file_name: str, simple=True) -> str: 60 | """Generate a textual preview of a csv file 61 | 62 | Args: 63 | p (Path): the path to the csv file 64 | file_name (str): the file name to use in the preview 65 | simple (bool, optional): whether to use a simplified version of the preview. Defaults to True. 66 | 67 | Returns: 68 | str: the textual preview 69 | """ 70 | df = pd.read_csv(p) 71 | 72 | out = [] 73 | 74 | out.append(f"-> {file_name} has {df.shape[0]} rows and {df.shape[1]} columns.") 75 | 76 | if simple: 77 | cols = df.columns.tolist() 78 | sel_cols = 15 79 | cols_str = ", ".join(cols[:sel_cols]) 80 | res = f"The columns are: {cols_str}" 81 | if len(cols) > sel_cols: 82 | res += f"... and {len(cols)-sel_cols} more columns" 83 | out.append(res) 84 | else: 85 | out.append("Here is some information about the columns:") 86 | for col in sorted(df.columns): 87 | dtype = df[col].dtype 88 | name = f"{col} ({dtype})" 89 | 90 | nan_count = df[col].isnull().sum() 91 | 92 | if dtype == "bool": 93 | v = df[col][df[col].notnull()].mean() 94 | out.append(f"{name} is {v*100:.2f}% True, {100-v*100:.2f}% False") 95 | elif df[col].nunique() < 10: 96 | out.append( 97 | f"{name} has {df[col].nunique()} unique values: {df[col].unique().tolist()}" 98 | ) 99 | elif is_numeric_dtype(df[col]): 100 | out.append( 101 | f"{name} has range: {df[col].min():.2f} - {df[col].max():.2f}, {nan_count} nan values" 102 | ) 103 | elif dtype == "object": 104 | out.append( 105 | f"{name} has {df[col].nunique()} unique values. 
Some example values: {df[col].value_counts().head(4).index.tolist()}" 106 | ) 107 | 108 | return "\n".join(out) 109 | 110 | 111 | def preview_json(p: Path, file_name: str): 112 | """Generate a textual preview of a json file using a generated json schema""" 113 | builder = SchemaBuilder() 114 | with open(p) as f: 115 | builder.add_object(json.load(f)) 116 | return f"-> {file_name} has auto-generated json schema:\n" + builder.to_json( 117 | indent=2 118 | ) 119 | 120 | 121 | def generate(base_path, include_file_details=True, simple=False): 122 | """ 123 | Generate a textual preview of a directory, including an overview of the directory 124 | structure and previews of individual files 125 | """ 126 | tree = f"```\n{file_tree(base_path)}```" 127 | out = [tree] 128 | 129 | if include_file_details: 130 | for fn in _walk(base_path): 131 | file_name = str(fn.relative_to(base_path)) 132 | 133 | if fn.suffix == ".csv": 134 | out.append(preview_csv(fn, file_name, simple=simple)) 135 | elif fn.suffix == ".json": 136 | out.append(preview_json(fn, file_name)) 137 | elif fn.suffix in plaintext_files: 138 | if get_file_len_size(fn)[0] < 30: 139 | with open(fn) as f: 140 | content = f.read() 141 | if fn.suffix in code_files: 142 | content = f"```\n{content}\n```" 143 | out.append(f"-> {file_name} has content:\n\n{content}") 144 | 145 | result = "\n\n".join(out) 146 | 147 | # if the result is very long we generate a simpler version 148 | if len(result) > 6_000 and not simple: 149 | return generate( 150 | base_path, include_file_details=include_file_details, simple=True 151 | ) 152 | 153 | return result 154 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/utils/response.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | import black 5 | 6 | 7 | def wrap_code(code: str, lang="python") -> str: 8 | """Wraps code with three backticks.""" 9 | return f"```{lang}\n{code}\n```" 10 | 11 | 12 | def is_valid_python_script(script): 13 | """Check if a script is a valid Python script.""" 14 | try: 15 | compile(script, "", "exec") 16 | return True 17 | except SyntaxError: 18 | return False 19 | 20 | 21 | def extract_jsons(text): 22 | """Extract all JSON objects from the text. Caveat: This function cannot handle nested JSON objects.""" 23 | json_objects = [] 24 | matches = re.findall(r"\{.*?\}", text, re.DOTALL) 25 | for match in matches: 26 | try: 27 | json_obj = json.loads(match) 28 | json_objects.append(json_obj) 29 | except json.JSONDecodeError: 30 | pass 31 | 32 | # Sometimes chatgpt-turbo forget the last curly bracket, so we try to add it back when no json is found 33 | if len(json_objects) == 0 and not text.endswith("}"): 34 | json_objects = extract_jsons(text + "}") 35 | if len(json_objects) > 0: 36 | return json_objects 37 | 38 | return json_objects 39 | 40 | 41 | def trim_long_string(string, threshold=5100, k=2500): 42 | # Check if the length of the string is longer than the threshold 43 | if len(string) > threshold: 44 | # Output the first k and last k characters 45 | first_k_chars = string[:k] 46 | last_k_chars = string[-k:] 47 | 48 | truncated_len = len(string) - 2 * k 49 | 50 | return f"{first_k_chars}\n ... [{truncated_len} characters truncated] ... 
\n{last_k_chars}" 51 | else: 52 | return string 53 | 54 | 55 | def extract_code(text): 56 | """Extract python code blocks from the text.""" 57 | parsed_codes = [] 58 | 59 | # When code is in a text or python block 60 | matches = re.findall(r"```(python)?\n*(.*?)\n*```", text, re.DOTALL) 61 | for match in matches: 62 | code_block = match[1] 63 | parsed_codes.append(code_block) 64 | 65 | # When the entire text is code or backticks of the code block is missing 66 | if len(parsed_codes) == 0: 67 | matches = re.findall(r"^(```(python)?)?\n?(.*?)\n?(```)?$", text, re.DOTALL) 68 | if matches: 69 | code_block = matches[0][2] 70 | parsed_codes.append(code_block) 71 | 72 | # validate the parsed codes 73 | valid_code_blocks = [ 74 | format_code(c) for c in parsed_codes if is_valid_python_script(c) 75 | ] 76 | return format_code("\n\n".join(valid_code_blocks)) 77 | 78 | 79 | def extract_text_up_to_code(s): 80 | """Extract (presumed) natural language text up to the start of the first code block.""" 81 | if "```" not in s: 82 | return "" 83 | return s[: s.find("```")].strip() 84 | 85 | 86 | def format_code(code) -> str: 87 | """Format Python code using Black.""" 88 | try: 89 | return black.format_str(code, mode=black.FileMode()) 90 | except black.parsing.InvalidInput: # type: ignore 91 | return code 92 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/utils/serialize.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from pathlib import Path 4 | from typing import Type, TypeVar 5 | import re 6 | 7 | import dataclasses_json 8 | from ..journal import Journal, Node 9 | 10 | 11 | def dumps_json(obj: dataclasses_json.DataClassJsonMixin): 12 | """Serialize dataclasses (such as Journals) to JSON.""" 13 | if isinstance(obj, Journal): 14 | obj = copy.deepcopy(obj) 15 | node2parent = {} 16 | for n in obj.nodes: 17 | if n.parent is not None: 18 | # Handle both Node objects and string IDs 19 | parent_id = n.parent.id if isinstance(n.parent, Node) else n.parent 20 | node2parent[n.id] = parent_id 21 | for n in obj.nodes: 22 | n.parent = None 23 | n.children = set() 24 | 25 | obj_dict = obj.to_dict() 26 | 27 | if isinstance(obj, Journal): 28 | obj_dict["node2parent"] = node2parent 29 | obj_dict["__version"] = "2" 30 | 31 | return json.dumps(obj_dict, separators=(",", ":")) 32 | 33 | 34 | def dump_json(obj: dataclasses_json.DataClassJsonMixin, path: Path): 35 | with open(path, "w") as f: 36 | f.write(dumps_json(obj)) 37 | 38 | 39 | G = TypeVar("G", bound=dataclasses_json.DataClassJsonMixin) 40 | 41 | 42 | def loads_json(s: str, cls: Type[G]) -> G: 43 | """Deserialize JSON to AIDE dataclasses.""" 44 | obj_dict = json.loads(s) 45 | obj = cls.from_dict(obj_dict) 46 | 47 | if isinstance(obj, Journal): 48 | id2nodes = {n.id: n for n in obj.nodes} 49 | for child_id, parent_id in obj_dict["node2parent"].items(): 50 | id2nodes[child_id].parent = id2nodes[parent_id] 51 | id2nodes[child_id].__post_init__() 52 | return obj 53 | 54 | 55 | def load_json(path: Path, cls: Type[G]) -> G: 56 | with open(path, "r") as f: 57 | return loads_json(f.read(), cls) 58 | 59 | 60 | def parse_markdown_to_dict(content: str): 61 | """ 62 | Reads a file that contains lines of the form: 63 | 64 | "Key": "Value", 65 | "Another Key": "Another Value", 66 | ... 67 | 68 | including possible multi-line values, and returns a Python dictionary. 
69 | """ 70 | 71 | pattern = r'"([^"]+)"\s*:\s*"([^"]*?)"(?:,\s*|\s*$)' 72 | 73 | matches = re.findall(pattern, content, flags=re.DOTALL) 74 | 75 | data_dict = {} 76 | for key, value in matches: 77 | data_dict[key] = value 78 | 79 | return data_dict 80 | -------------------------------------------------------------------------------- /ai_scientist/treesearch/utils/viz_templates/template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 29 | 36 | 37 | 41 | 42 | 43 | 44 | 47 | AI Scientist-v2 Visualization 48 | 262 | 263 | 264 |
297 | 298 | 299 | -------------------------------------------------------------------------------- /ai_scientist/utils/token_tracker.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Dict, Optional, List 3 | import tiktoken 4 | from collections import defaultdict 5 | import asyncio 6 | from datetime import datetime 7 | import logging 8 | 9 | 10 | class TokenTracker: 11 | def __init__(self): 12 | """ 13 | Token counts for prompt, completion, reasoning, and cached. 14 | Reasoning tokens are included in completion tokens. 15 | Cached tokens are included in prompt tokens. 16 | Also tracks prompts, responses, and timestamps. 17 | We assume we get these from the LLM response, and we don't count 18 | the tokens by ourselves. 19 | """ 20 | self.token_counts = defaultdict( 21 | lambda: {"prompt": 0, "completion": 0, "reasoning": 0, "cached": 0} 22 | ) 23 | self.interactions = defaultdict(list) 24 | 25 | self.MODEL_PRICES = { 26 | "gpt-4o-2024-11-20": { 27 | "prompt": 2.5 / 1000000, # $2.50 per 1M tokens 28 | "cached": 1.25 / 1000000, # $1.25 per 1M tokens 29 | "completion": 10 / 1000000, # $10.00 per 1M tokens 30 | }, 31 | "gpt-4o-2024-08-06": { 32 | "prompt": 2.5 / 1000000, # $2.50 per 1M tokens 33 | "cached": 1.25 / 1000000, # $1.25 per 1M tokens 34 | "completion": 10 / 1000000, # $10.00 per 1M tokens 35 | }, 36 | "gpt-4o-2024-05-13": { # this ver does not support cached tokens 37 | "prompt": 5.0 / 1000000, # $5.00 per 1M tokens 38 | "completion": 15 / 1000000, # $15.00 per 1M tokens 39 | }, 40 | "gpt-4o-mini-2024-07-18": { 41 | "prompt": 0.15 / 1000000, # $0.15 per 1M tokens 42 | "cached": 0.075 / 1000000, # $0.075 per 1M tokens 43 | "completion": 0.6 / 1000000, # $0.60 per 1M tokens 44 | }, 45 | "o1-2024-12-17": { 46 | "prompt": 15 / 1000000, # $15.00 per 1M tokens 47 | "cached": 7.5 / 1000000, # $7.50 per 1M tokens 48 | "completion": 60 / 1000000, # $60.00 per 1M tokens 49 | }, 50 | "o1-preview-2024-09-12": { 51 | "prompt": 15 / 1000000, # $15.00 per 1M tokens 52 | "cached": 7.5 / 1000000, # $7.50 per 1M tokens 53 | "completion": 60 / 1000000, # $60.00 per 1M tokens 54 | }, 55 | "o3-mini-2025-01-31": { 56 | "prompt": 1.1 / 1000000, # $1.10 per 1M tokens 57 | "cached": 0.55 / 1000000, # $0.55 per 1M tokens 58 | "completion": 4.4 / 1000000, # $4.40 per 1M tokens 59 | }, 60 | } 61 | 62 | def add_tokens( 63 | self, 64 | model: str, 65 | prompt_tokens: int, 66 | completion_tokens: int, 67 | reasoning_tokens: int, 68 | cached_tokens: int, 69 | ): 70 | self.token_counts[model]["prompt"] += prompt_tokens 71 | self.token_counts[model]["completion"] += completion_tokens 72 | self.token_counts[model]["reasoning"] += reasoning_tokens 73 | self.token_counts[model]["cached"] += cached_tokens 74 | 75 | def add_interaction( 76 | self, 77 | model: str, 78 | system_message: str, 79 | prompt: str, 80 | response: str, 81 | timestamp: datetime, 82 | ): 83 | """Record a single interaction with the model.""" 84 | self.interactions[model].append( 85 | { 86 | "system_message": system_message, 87 | "prompt": prompt, 88 | "response": response, 89 | "timestamp": timestamp, 90 | } 91 | ) 92 | 93 | def get_interactions(self, model: Optional[str] = None) -> Dict[str, List[Dict]]: 94 | """Get all interactions, optionally filtered by model.""" 95 | if model: 96 | return {model: self.interactions[model]} 97 | return dict(self.interactions) 98 | 99 | def reset(self): 100 | """Reset all token counts and interactions.""" 101 | 
self.token_counts = defaultdict( 102 | lambda: {"prompt": 0, "completion": 0, "reasoning": 0, "cached": 0} 103 | ) 104 | self.interactions = defaultdict(list) 105 | # self._encoders = {} 106 | 107 | def calculate_cost(self, model: str) -> float: 108 | """Calculate the cost for a specific model based on token usage.""" 109 | if model not in self.MODEL_PRICES: 110 | logging.warning(f"Price information not available for model {model}") 111 | return 0.0 112 | 113 | prices = self.MODEL_PRICES[model] 114 | tokens = self.token_counts[model] 115 | 116 | # Calculate cost for prompt and completion tokens 117 | if "cached" in prices: 118 | prompt_cost = (tokens["prompt"] - tokens["cached"]) * prices["prompt"] 119 | cached_cost = tokens["cached"] * prices["cached"] 120 | else: 121 | prompt_cost = tokens["prompt"] * prices["prompt"] 122 | cached_cost = 0 123 | completion_cost = tokens["completion"] * prices["completion"] 124 | 125 | return prompt_cost + cached_cost + completion_cost 126 | 127 | def get_summary(self) -> Dict[str, Dict[str, int]]: 128 | # return dict(self.token_counts) 129 | """Get summary of token usage and costs for all models.""" 130 | summary = {} 131 | for model, tokens in self.token_counts.items(): 132 | summary[model] = { 133 | "tokens": tokens.copy(), 134 | "cost (USD)": self.calculate_cost(model), 135 | } 136 | return summary 137 | 138 | 139 | # Global token tracker instance 140 | token_tracker = TokenTracker() 141 | 142 | 143 | def track_token_usage(func): 144 | @wraps(func) 145 | async def async_wrapper(*args, **kwargs): 146 | prompt = kwargs.get("prompt") 147 | system_message = kwargs.get("system_message") 148 | if not prompt and not system_message: 149 | raise ValueError( 150 | "Either 'prompt' or 'system_message' must be provided for token tracking" 151 | ) 152 | 153 | logging.info("args: ", args) 154 | logging.info("kwargs: ", kwargs) 155 | 156 | result = await func(*args, **kwargs) 157 | model = result.model 158 | timestamp = result.created 159 | 160 | if hasattr(result, "usage"): 161 | token_tracker.add_tokens( 162 | model, 163 | result.usage.prompt_tokens, 164 | result.usage.completion_tokens, 165 | result.usage.completion_tokens_details.reasoning_tokens, 166 | ( 167 | result.usage.prompt_tokens_details.cached_tokens 168 | if hasattr(result.usage, "prompt_tokens_details") 169 | else 0 170 | ), 171 | ) 172 | # Add interaction details 173 | token_tracker.add_interaction( 174 | model, 175 | system_message, 176 | prompt, 177 | result.choices[ 178 | 0 179 | ].message.content, # Assumes response is in content field 180 | timestamp, 181 | ) 182 | return result 183 | 184 | @wraps(func) 185 | def sync_wrapper(*args, **kwargs): 186 | prompt = kwargs.get("prompt") 187 | system_message = kwargs.get("system_message") 188 | if not prompt and not system_message: 189 | raise ValueError( 190 | "Either 'prompt' or 'system_message' must be provided for token tracking" 191 | ) 192 | result = func(*args, **kwargs) 193 | model = result.model 194 | timestamp = result.created 195 | logging.info("args: ", args) 196 | logging.info("kwargs: ", kwargs) 197 | 198 | if hasattr(result, "usage"): 199 | token_tracker.add_tokens( 200 | model, 201 | result.usage.prompt_tokens, 202 | result.usage.completion_tokens, 203 | result.usage.completion_tokens_details.reasoning_tokens, 204 | ( 205 | result.usage.prompt_tokens_details.cached_tokens 206 | if hasattr(result.usage, "prompt_tokens_details") 207 | else 0 208 | ), 209 | ) 210 | # Add interaction details 211 | token_tracker.add_interaction( 212 | model, 
213 | system_message, 214 | prompt, 215 | result.choices[ 216 | 0 217 | ].message.content, # Assumes response is in content field 218 | timestamp, 219 | ) 220 | return result 221 | 222 | return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper 223 | -------------------------------------------------------------------------------- /ai_scientist/vlm.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from typing import Any 3 | import re 4 | import json 5 | import backoff 6 | import openai 7 | from PIL import Image 8 | from ai_scientist.utils.token_tracker import track_token_usage 9 | 10 | MAX_NUM_TOKENS = 4096 11 | 12 | AVAILABLE_VLMS = [ 13 | "gpt-4o-2024-05-13", 14 | "gpt-4o-2024-08-06", 15 | "gpt-4o-2024-11-20", 16 | "gpt-4o-mini-2024-07-18", 17 | "o3-mini", 18 | ] 19 | 20 | 21 | def encode_image_to_base64(image_path: str) -> str: 22 | """Convert an image to base64 string.""" 23 | with Image.open(image_path) as img: 24 | # Convert RGBA to RGB if necessary 25 | if img.mode == "RGBA": 26 | img = img.convert("RGB") 27 | 28 | # Save to bytes 29 | import io 30 | 31 | buffer = io.BytesIO() 32 | img.save(buffer, format="JPEG") 33 | image_bytes = buffer.getvalue() 34 | 35 | return base64.b64encode(image_bytes).decode("utf-8") 36 | 37 | 38 | @track_token_usage 39 | def make_llm_call(client, model, temperature, system_message, prompt): 40 | if "gpt" in model: 41 | return client.chat.completions.create( 42 | model=model, 43 | messages=[ 44 | {"role": "system", "content": system_message}, 45 | *prompt, 46 | ], 47 | temperature=temperature, 48 | max_tokens=MAX_NUM_TOKENS, 49 | n=1, 50 | stop=None, 51 | seed=0, 52 | ) 53 | elif "o1" in model or "o3" in model: 54 | return client.chat.completions.create( 55 | model=model, 56 | messages=[ 57 | {"role": "user", "content": system_message}, 58 | *prompt, 59 | ], 60 | temperature=1, 61 | n=1, 62 | seed=0, 63 | ) 64 | else: 65 | raise ValueError(f"Model {model} not supported.") 66 | 67 | 68 | @track_token_usage 69 | def make_vlm_call(client, model, temperature, system_message, prompt): 70 | if "gpt" in model: 71 | return client.chat.completions.create( 72 | model=model, 73 | messages=[ 74 | {"role": "system", "content": system_message}, 75 | *prompt, 76 | ], 77 | temperature=temperature, 78 | max_tokens=MAX_NUM_TOKENS, 79 | ) 80 | else: 81 | raise ValueError(f"Model {model} not supported.") 82 | 83 | 84 | def prepare_vlm_prompt(msg, image_paths, max_images): 85 | pass 86 | 87 | 88 | @backoff.on_exception( 89 | backoff.expo, 90 | ( 91 | openai.RateLimitError, 92 | openai.APITimeoutError, 93 | ), 94 | ) 95 | def get_response_from_vlm( 96 | msg: str, 97 | image_paths: str | list[str], 98 | client: Any, 99 | model: str, 100 | system_message: str, 101 | print_debug: bool = False, 102 | msg_history: list[dict[str, Any]] | None = None, 103 | temperature: float = 0.7, 104 | max_images: int = 25, 105 | ) -> tuple[str, list[dict[str, Any]]]: 106 | """Get response from vision-language model.""" 107 | if msg_history is None: 108 | msg_history = [] 109 | 110 | if model in AVAILABLE_VLMS: 111 | # Convert single image path to list for consistent handling 112 | if isinstance(image_paths, str): 113 | image_paths = [image_paths] 114 | 115 | # Create content list starting with the text message 116 | content = [{"type": "text", "text": msg}] 117 | 118 | # Add each image to the content list 119 | for image_path in image_paths[:max_images]: 120 | base64_image = encode_image_to_base64(image_path) 121 | 
content.append( 122 | { 123 | "type": "image_url", 124 | "image_url": { 125 | "url": f"data:image/jpeg;base64,{base64_image}", 126 | "detail": "low", 127 | }, 128 | } 129 | ) 130 | # Construct message with all images 131 | new_msg_history = msg_history + [{"role": "user", "content": content}] 132 | 133 | response = make_vlm_call( 134 | client, 135 | model, 136 | temperature, 137 | system_message=system_message, 138 | prompt=new_msg_history, 139 | ) 140 | 141 | content = response.choices[0].message.content 142 | new_msg_history = new_msg_history + [{"role": "assistant", "content": content}] 143 | else: 144 | raise ValueError(f"Model {model} not supported.") 145 | 146 | if print_debug: 147 | print() 148 | print("*" * 20 + " VLM START " + "*" * 20) 149 | for j, msg in enumerate(new_msg_history): 150 | print(f'{j}, {msg["role"]}: {msg["content"]}') 151 | print(content) 152 | print("*" * 21 + " VLM END " + "*" * 21) 153 | print() 154 | 155 | return content, new_msg_history 156 | 157 | 158 | def create_client(model: str) -> tuple[Any, str]: 159 | """Create client for vision-language model.""" 160 | if model in [ 161 | "gpt-4o-2024-05-13", 162 | "gpt-4o-2024-08-06", 163 | "gpt-4o-2024-11-20", 164 | "gpt-4o-mini-2024-07-18", 165 | "o3-mini", 166 | ]: 167 | print(f"Using OpenAI API with model {model}.") 168 | return openai.OpenAI(), model 169 | else: 170 | raise ValueError(f"Model {model} not supported.") 171 | 172 | 173 | def extract_json_between_markers(llm_output: str) -> dict | None: 174 | # Regular expression pattern to find JSON content between ```json and ``` 175 | json_pattern = r"```json(.*?)```" 176 | matches = re.findall(json_pattern, llm_output, re.DOTALL) 177 | 178 | if not matches: 179 | # Fallback: Try to find any JSON-like content in the output 180 | json_pattern = r"\{.*?\}" 181 | matches = re.findall(json_pattern, llm_output, re.DOTALL) 182 | 183 | for json_string in matches: 184 | json_string = json_string.strip() 185 | try: 186 | parsed_json = json.loads(json_string) 187 | return parsed_json 188 | except json.JSONDecodeError: 189 | # Attempt to fix common JSON issues 190 | try: 191 | # Remove invalid control characters 192 | json_string_clean = re.sub(r"[\x00-\x1F\x7F]", "", json_string) 193 | parsed_json = json.loads(json_string_clean) 194 | return parsed_json 195 | except json.JSONDecodeError: 196 | continue # Try next match 197 | 198 | return None # No valid JSON found 199 | 200 | 201 | @backoff.on_exception( 202 | backoff.expo, 203 | ( 204 | openai.RateLimitError, 205 | openai.APITimeoutError, 206 | ), 207 | ) 208 | def get_batch_responses_from_vlm( 209 | msg: str, 210 | image_paths: str | list[str], 211 | client: Any, 212 | model: str, 213 | system_message: str, 214 | print_debug: bool = False, 215 | msg_history: list[dict[str, Any]] | None = None, 216 | temperature: float = 0.7, 217 | n_responses: int = 1, 218 | max_images: int = 200, 219 | ) -> tuple[list[str], list[list[dict[str, Any]]]]: 220 | """Get multiple responses from vision-language model for the same input. 
221 | 222 | Args: 223 | msg: Text message to send 224 | image_paths: Path(s) to image file(s) 225 | client: OpenAI client instance 226 | model: Name of model to use 227 | system_message: System prompt 228 | print_debug: Whether to print debug info 229 | msg_history: Previous message history 230 | temperature: Sampling temperature 231 | n_responses: Number of responses to generate 232 | 233 | Returns: 234 | Tuple of (list of response strings, list of message histories) 235 | """ 236 | if msg_history is None: 237 | msg_history = [] 238 | 239 | if model in [ 240 | "gpt-4o-2024-05-13", 241 | "gpt-4o-2024-08-06", 242 | "gpt-4o-2024-11-20", 243 | "gpt-4o-mini-2024-07-18", 244 | "o3-mini", 245 | ]: 246 | # Convert single image path to list 247 | if isinstance(image_paths, str): 248 | image_paths = [image_paths] 249 | 250 | # Create content list with text and images 251 | content = [{"type": "text", "text": msg}] 252 | for image_path in image_paths[:max_images]: 253 | base64_image = encode_image_to_base64(image_path) 254 | content.append( 255 | { 256 | "type": "image_url", 257 | "image_url": { 258 | "url": f"data:image/jpeg;base64,{base64_image}", 259 | "detail": "low", 260 | }, 261 | } 262 | ) 263 | 264 | # Construct message with all images 265 | new_msg_history = msg_history + [{"role": "user", "content": content}] 266 | 267 | # Get multiple responses 268 | response = client.chat.completions.create( 269 | model=model, 270 | messages=[ 271 | {"role": "system", "content": system_message}, 272 | *new_msg_history, 273 | ], 274 | temperature=temperature, 275 | max_tokens=MAX_NUM_TOKENS, 276 | n=n_responses, 277 | seed=0, 278 | ) 279 | 280 | # Extract content from all responses 281 | contents = [r.message.content for r in response.choices] 282 | new_msg_histories = [ 283 | new_msg_history + [{"role": "assistant", "content": c}] for c in contents 284 | ] 285 | else: 286 | raise ValueError(f"Model {model} not supported.") 287 | 288 | if print_debug: 289 | # Just print the first response 290 | print() 291 | print("*" * 20 + " VLM START " + "*" * 20) 292 | for j, msg in enumerate(new_msg_histories[0]): 293 | print(f'{j}, {msg["role"]}: {msg["content"]}') 294 | print(contents[0]) 295 | print("*" * 21 + " VLM END " + "*" * 21) 296 | print() 297 | 298 | return contents, new_msg_histories 299 | -------------------------------------------------------------------------------- /bfts_config.yaml: -------------------------------------------------------------------------------- 1 | # path to the task data directory 2 | data_dir: "data" 3 | preprocess_data: False 4 | 5 | goal: null 6 | eval: null 7 | 8 | log_dir: logs 9 | workspace_dir: workspaces 10 | 11 | # whether to copy the data to the workspace directory (otherwise it will be symlinked) 12 | # copying is recommended to prevent the agent from accidentally modifying the original data 13 | copy_data: True 14 | 15 | exp_name: run # a random experiment name will be generated if not provided 16 | 17 | # settings for code execution 18 | exec: 19 | timeout: 3600 20 | agent_file_name: runfile.py 21 | format_tb_ipython: False 22 | 23 | generate_report: True 24 | # LLM settings for final report from journal 25 | report: 26 | model: gpt-4o-2024-11-20 27 | temp: 1.0 28 | 29 | experiment: 30 | num_syn_datasets: 1 31 | 32 | debug: 33 | stage4: False 34 | 35 | # agent hyperparams 36 | agent: 37 | type: parallel 38 | num_workers: 4 39 | stages: 40 | stage1_max_iters: 20 41 | stage2_max_iters: 12 42 | stage3_max_iters: 12 43 | stage4_max_iters: 18 44 | # how many improvement 
iterations to run 45 | steps: 5 # if stage-specific max_iters are not provided, the agent will use this value for all stages 46 | # whether to instruct the agent to use CV (set to 1 to disable) 47 | k_fold_validation: 1 48 | multi_seed_eval: 49 | num_seeds: 3 # should be the same as num_workers if num_workers < 3. Otherwise, set it to be 3. 50 | # whether to instruct the agent to generate a prediction function 51 | expose_prediction: False 52 | # whether to provide the agent with a preview of the data 53 | data_preview: False 54 | 55 | # LLM settings for coding 56 | code: 57 | model: anthropic.claude-3-5-sonnet-20241022-v2:0 58 | temp: 1.0 59 | max_tokens: 12000 60 | 61 | # LLM settings for evaluating program output / tracebacks 62 | feedback: 63 | model: gpt-4o-2024-11-20 64 | # gpt-4o 65 | temp: 0.5 66 | max_tokens: 8192 67 | 68 | vlm_feedback: 69 | model: gpt-4o-2024-11-20 70 | temp: 0.5 71 | max_tokens: null 72 | 73 | search: 74 | max_debug_depth: 3 75 | debug_prob: 0.5 76 | num_drafts: 3 77 | -------------------------------------------------------------------------------- /docs/logo_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/AI-Scientist-v2/031126fa19df316e048d01f7e1c1f268e1b3206a/docs/logo_v1.png -------------------------------------------------------------------------------- /launch_scientist_bfts.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import json 3 | import argparse 4 | import shutil 5 | import torch 6 | import os 7 | import re 8 | import sys 9 | from datetime import datetime 10 | from ai_scientist.llm import create_client 11 | 12 | from contextlib import contextmanager 13 | from ai_scientist.treesearch.perform_experiments_bfts_with_agentmanager import ( 14 | perform_experiments_bfts, 15 | ) 16 | from ai_scientist.treesearch.bfts_utils import ( 17 | idea_to_markdown, 18 | edit_bfts_config_file, 19 | ) 20 | from ai_scientist.perform_plotting import aggregate_plots 21 | from ai_scientist.perform_writeup import perform_writeup 22 | from ai_scientist.perform_icbinb_writeup import ( 23 | perform_writeup as perform_icbinb_writeup, 24 | gather_citations, 25 | ) 26 | from ai_scientist.perform_llm_review import perform_review, load_paper 27 | from ai_scientist.perform_vlm_review import perform_imgs_cap_ref_review 28 | from ai_scientist.utils.token_tracker import token_tracker 29 | 30 | 31 | def print_time(): 32 | print(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) 33 | 34 | 35 | def save_token_tracker(idea_dir): 36 | with open(osp.join(idea_dir, "token_tracker.json"), "w") as f: 37 | json.dump(token_tracker.get_summary(), f) 38 | with open(osp.join(idea_dir, "token_tracker_interactions.json"), "w") as f: 39 | json.dump(token_tracker.get_interactions(), f) 40 | 41 | 42 | def parse_arguments(): 43 | parser = argparse.ArgumentParser(description="Run AI scientist experiments") 44 | parser.add_argument( 45 | "--writeup-type", 46 | type=str, 47 | default="icbinb", 48 | choices=["normal", "icbinb"], 49 | help="Type of writeup to generate (normal=8 page, icbinb=4 page)", 50 | ) 51 | parser.add_argument( 52 | "--load_ideas", 53 | type=str, 54 | default="ideas/i_cant_believe_its_not_better.json", 55 | help="Path to a JSON file containing pregenerated ideas", 56 | ) 57 | parser.add_argument( 58 | "--load_code", 59 | action="store_true", 60 | help="If set, load a Python file with same name as ideas file but .py extension", 61 | ) 62 | 
parser.add_argument( 63 | "--idea_idx", 64 | type=int, 65 | default=0, 66 | help="Index of the idea to run", 67 | ) 68 | parser.add_argument( 69 | "--add_dataset_ref", 70 | action="store_true", 71 | help="If set, add a HF dataset reference to the idea", 72 | ) 73 | parser.add_argument( 74 | "--writeup-retries", 75 | type=int, 76 | default=3, 77 | help="Number of writeup attempts to try", 78 | ) 79 | parser.add_argument( 80 | "--attempt_id", 81 | type=int, 82 | default=0, 83 | help="Attempt ID, used to distinguish same idea in different attempts in parallel runs", 84 | ) 85 | parser.add_argument( 86 | "--model_agg_plots", 87 | type=str, 88 | default="o3-mini-2025-01-31", 89 | help="Model to use for plot aggregation", 90 | ) 91 | parser.add_argument( 92 | "--model_writeup", 93 | type=str, 94 | default="o1-preview-2024-09-12", 95 | help="Model to use for writeup", 96 | ) 97 | parser.add_argument( 98 | "--model_citation", 99 | type=str, 100 | default="gpt-4o-2024-11-20", 101 | help="Model to use for citation gathering", 102 | ) 103 | parser.add_argument( 104 | "--num_cite_rounds", 105 | type=int, 106 | default=20, 107 | help="Number of citation rounds to perform", 108 | ) 109 | parser.add_argument( 110 | "--model_review", 111 | type=str, 112 | default="gpt-4o-2024-11-20", 113 | help="Model to use for review main text and captions", 114 | ) 115 | parser.add_argument( 116 | "--skip_writeup", 117 | action="store_true", 118 | help="If set, skip the writeup process", 119 | ) 120 | parser.add_argument( 121 | "--skip_review", 122 | action="store_true", 123 | help="If set, skip the review process", 124 | ) 125 | return parser.parse_args() 126 | 127 | 128 | def get_available_gpus(gpu_ids=None): 129 | if gpu_ids is not None: 130 | return [int(gpu_id) for gpu_id in gpu_ids.split(",")] 131 | return list(range(torch.cuda.device_count())) 132 | 133 | 134 | def find_pdf_path_for_review(idea_dir): 135 | pdf_files = [f for f in os.listdir(idea_dir) if f.endswith(".pdf")] 136 | reflection_pdfs = [f for f in pdf_files if "reflection" in f] 137 | if reflection_pdfs: 138 | # First check if there's a final version 139 | final_pdfs = [f for f in reflection_pdfs if "final" in f.lower()] 140 | if final_pdfs: 141 | # Use the final version if available 142 | pdf_path = osp.join(idea_dir, final_pdfs[0]) 143 | else: 144 | # Try to find numbered reflections 145 | reflection_nums = [] 146 | for f in reflection_pdfs: 147 | match = re.search(r"reflection[_.]?(\d+)", f) 148 | if match: 149 | reflection_nums.append((int(match.group(1)), f)) 150 | 151 | if reflection_nums: 152 | # Get the file with the highest reflection number 153 | highest_reflection = max(reflection_nums, key=lambda x: x[0]) 154 | pdf_path = osp.join(idea_dir, highest_reflection[1]) 155 | else: 156 | # Fall back to the first reflection PDF if no numbers found 157 | pdf_path = osp.join(idea_dir, reflection_pdfs[0]) 158 | return pdf_path 159 | 160 | 161 | @contextmanager 162 | def redirect_stdout_stderr_to_file(log_file_path): 163 | original_stdout = sys.stdout 164 | original_stderr = sys.stderr 165 | log = open(log_file_path, "a") 166 | sys.stdout = log 167 | sys.stderr = log 168 | try: 169 | yield 170 | finally: 171 | sys.stdout = original_stdout 172 | sys.stderr = original_stderr 173 | log.close() 174 | 175 | 176 | if __name__ == "__main__": 177 | args = parse_arguments() 178 | os.environ["AI_SCIENTIST_ROOT"] = os.path.dirname(os.path.abspath(__file__)) 179 | print(f"Set AI_SCIENTIST_ROOT to {os.environ['AI_SCIENTIST_ROOT']}") 180 | 181 | # Check available 
GPUs and adjust parallel processes if necessary 182 | available_gpus = get_available_gpus() 183 | print(f"Using GPUs: {available_gpus}") 184 | 185 | with open(args.load_ideas, "r") as f: 186 | ideas = json.load(f) 187 | print(f"Loaded {len(ideas)} pregenerated ideas from {args.load_ideas}") 188 | 189 | idea = ideas[args.idea_idx] 190 | 191 | date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 192 | idea_dir = f"experiments/{date}_{idea['Name']}_attempt_{args.attempt_id}" 193 | print(f"Results will be saved in {idea_dir}") 194 | os.makedirs(idea_dir, exist_ok=True) 195 | 196 | # Convert idea json to markdown file 197 | idea_path_md = osp.join(idea_dir, "idea.md") 198 | 199 | # If load_code is True, get the Python file with same name as JSON 200 | code = None 201 | if args.load_code: 202 | code_path = args.load_ideas.rsplit(".", 1)[0] + ".py" 203 | if os.path.exists(code_path): 204 | with open(code_path, "r") as f: 205 | code = f.read() 206 | else: 207 | print(f"Warning: Code file {code_path} not found") 208 | else: 209 | code_path = None 210 | 211 | idea_to_markdown(ideas[args.idea_idx], idea_path_md, code_path) 212 | 213 | dataset_ref_code = None 214 | if args.add_dataset_ref: 215 | dataset_ref_path = "hf_dataset_reference.py" 216 | if os.path.exists(dataset_ref_path): 217 | with open(dataset_ref_path, "r") as f: 218 | dataset_ref_code = f.read() 219 | else: 220 | print(f"Warning: Dataset reference file {dataset_ref_path} not found") 221 | dataset_ref_code = None 222 | 223 | if dataset_ref_code is not None and code is not None: 224 | added_code = dataset_ref_code + "\n" + code 225 | elif dataset_ref_code is not None and code is None: 226 | added_code = dataset_ref_code 227 | elif dataset_ref_code is None and code is not None: 228 | added_code = code 229 | else: 230 | added_code = None 231 | 232 | print(added_code) 233 | 234 | # Add code to idea json if it was loaded 235 | if added_code is not None: 236 | ideas[args.idea_idx]["Code"] = added_code 237 | 238 | # Store raw idea json 239 | idea_path_json = osp.join(idea_dir, "idea.json") 240 | with open(idea_path_json, "w") as f: 241 | json.dump(ideas[args.idea_idx], f, indent=4) 242 | 243 | config_path = "bfts_config.yaml" 244 | idea_config_path = edit_bfts_config_file( 245 | config_path, 246 | idea_dir, 247 | idea_path_json, 248 | ) 249 | 250 | perform_experiments_bfts(idea_config_path) 251 | experiment_results_dir = osp.join(idea_dir, "logs/0-run/experiment_results") 252 | if os.path.exists(experiment_results_dir): 253 | shutil.copytree( 254 | experiment_results_dir, 255 | osp.join(idea_dir, "experiment_results"), 256 | dirs_exist_ok=True, 257 | ) 258 | 259 | aggregate_plots(base_folder=idea_dir, model=args.model_agg_plots) 260 | 261 | shutil.rmtree(osp.join(idea_dir, "experiment_results")) 262 | 263 | save_token_tracker(idea_dir) 264 | 265 | if not args.skip_writeup: 266 | writeup_success = False 267 | citations_text = gather_citations( 268 | idea_dir, 269 | num_cite_rounds=args.num_cite_rounds, 270 | small_model=args.model_citation, 271 | ) 272 | for attempt in range(args.writeup_retries): 273 | print(f"Writeup attempt {attempt+1} of {args.writeup_retries}") 274 | if args.writeup_type == "normal": 275 | writeup_success = perform_writeup( 276 | base_folder=idea_dir, 277 | big_model=args.model_writeup, 278 | page_limit=8, 279 | citations_text=citations_text, 280 | ) 281 | else: 282 | writeup_success = perform_icbinb_writeup( 283 | base_folder=idea_dir, 284 | big_model=args.model_writeup, 285 | page_limit=4, 286 | 
citations_text=citations_text, 287 | ) 288 | if writeup_success: 289 | break 290 | 291 | if not writeup_success: 292 | print("Writeup process did not complete successfully after all retries.") 293 | 294 | save_token_tracker(idea_dir) 295 | 296 | if not args.skip_review and not args.skip_writeup: 297 | # Perform paper review if the paper exists 298 | pdf_path = find_pdf_path_for_review(idea_dir) 299 | if os.path.exists(pdf_path): 300 | print("Paper found at: ", pdf_path) 301 | paper_content = load_paper(pdf_path) 302 | client, client_model = create_client(args.model_review) 303 | review_text = perform_review(paper_content, client_model, client) 304 | review_img_cap_ref = perform_imgs_cap_ref_review( 305 | client, client_model, pdf_path 306 | ) 307 | with open(osp.join(idea_dir, "review_text.txt"), "w") as f: 308 | f.write(json.dumps(review_text, indent=4)) 309 | with open(osp.join(idea_dir, "review_img_cap_ref.json"), "w") as f: 310 | json.dump(review_img_cap_ref, f, indent=4) 311 | print("Paper review completed.") 312 | 313 | print("Start cleaning up processes") 314 | # Kill all mp and torch processes associated with this experiment 315 | import psutil 316 | import signal 317 | 318 | # Get the current process and all its children 319 | current_process = psutil.Process() 320 | children = current_process.children(recursive=True) 321 | 322 | # First try graceful termination 323 | for child in children: 324 | try: 325 | child.send_signal(signal.SIGTERM) 326 | except (psutil.NoSuchProcess, psutil.AccessDenied): 327 | continue 328 | 329 | # Wait briefly for processes to terminate 330 | gone, alive = psutil.wait_procs(children, timeout=3) 331 | 332 | # If any processes remain, force kill them 333 | for process in alive: 334 | try: 335 | process.kill() 336 | except (psutil.NoSuchProcess, psutil.AccessDenied): 337 | continue 338 | 339 | # Additional cleanup: find any orphaned processes containing specific keywords 340 | keywords = ["python", "torch", "mp", "bfts", "experiment"] 341 | for proc in psutil.process_iter(["name", "cmdline"]): 342 | try: 343 | # Check both process name and command line arguments 344 | cmdline = " ".join(proc.cmdline()).lower() 345 | if any(keyword in cmdline for keyword in keywords): 346 | proc.send_signal(signal.SIGTERM) 347 | proc.wait(timeout=3) 348 | if proc.is_running(): 349 | proc.kill() 350 | except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.TimeoutExpired): 351 | continue 352 | 353 | # Finally, terminate the current process 354 | # current_process.send_signal(signal.SIGTERM) 355 | # try: 356 | # current_process.wait(timeout=3) 357 | # except psutil.TimeoutExpired: 358 | # current_process.kill() 359 | 360 | # exit the program 361 | sys.exit(0) 362 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # LLM APIs 2 | anthropic 3 | backoff 4 | openai 5 | # Viz 6 | matplotlib 7 | pypdf 8 | pymupdf4llm 9 | seaborn 10 | # Common Requirements 11 | numpy 12 | transformers 13 | datasets 14 | tiktoken 15 | wandb 16 | tqdm 17 | rich 18 | humanize 19 | dataclasses-json 20 | funcy 21 | black 22 | genson 23 | shutup 24 | python-igraph 25 | coolname 26 | jsonschema 27 | omegaconf 28 | botocore 29 | boto3 30 | --------------------------------------------------------------------------------
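Taken together, the files above suggest how a run fits together: bfts_config.yaml configures the tree search, launch_scientist_bfts.py drives experiments, write-up, and review, and ai_scientist/vlm.py plus ai_scientist/utils/token_tracker.py handle vision-language model calls and cost accounting. The snippet below is a hedged sketch, not a documented entry point of the repository: the image path and prompts are invented, and it assumes an OPENAI_API_KEY is set in the environment.

```python
# Illustrative sketch of calling the VLM helper directly (hypothetical inputs).
from ai_scientist.vlm import create_client, get_response_from_vlm
from ai_scientist.utils.token_tracker import token_tracker

# create_client returns an OpenAI client plus the validated model name.
client, model = create_client("gpt-4o-2024-11-20")

# Ask the VLM about a plot; the path and prompts here are made up for illustration.
answer, history = get_response_from_vlm(
    msg="Summarize the main trend shown in this figure.",
    image_paths="experiments/example/plot_0.png",  # hypothetical file
    client=client,
    model=model,
    system_message="You are reviewing figures from a machine learning paper.",
)
print(answer)

# The @track_token_usage decorator on the underlying call records usage globally.
print(token_tracker.get_summary())
```

In the full pipeline these calls are made on your behalf, and launch_scientist_bfts.py persists the same summary via save_token_tracker.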