├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── codebase_reorg.yml │ ├── exp_record.yml │ ├── feature_request.yml │ └── writing_task.yml ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── codespell.yml │ ├── isort.yml │ ├── mypy.yml │ ├── openhands-resolver.yml │ ├── pytest.yml │ └── ruff.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── .gitkeep ├── architecture.png └── tiny_scientist.png ├── backend └── app.py ├── config.template.toml ├── data └── .gitkeep ├── docs └── .gitkeep ├── experiments ├── baseline_results.txt ├── experiment.py ├── experiment_results.txt └── idea.json ├── frontend ├── package-lock.json ├── package.json ├── public │ ├── favicon.ico │ ├── index.html │ ├── logo192.png │ ├── logo512.png │ ├── manifest.json │ └── robots.txt └── src │ ├── App.css │ ├── App.js │ ├── App.test.js │ ├── components │ ├── FactorBlock.jsx │ ├── HypothesisCard.jsx │ ├── HypothesisFactorsAndScoresCard.jsx │ ├── TopNav.jsx │ └── TreePlotVisualization.jsx │ ├── images │ ├── evaluation.svg │ ├── exploration.svg │ ├── green.svg │ ├── grey.svg │ ├── logo.svg │ └── red.svg │ ├── index.css │ ├── index.js │ ├── reportWebVitals.js │ └── setupTests.js ├── poetry.lock ├── pyproject.toml ├── scripts ├── code.py ├── code.sh ├── demo.py ├── demo.sh ├── demo_deepseek.sh ├── drawer.py ├── drawer.sh ├── review.py ├── review.sh ├── search_code.py ├── search_code.sh ├── search_paper.py ├── search_paper.sh ├── think.py ├── think.sh ├── write.py └── write.sh ├── tests └── test_scientist.py └── tiny_scientist ├── __init__.py ├── coder.py ├── configs.py ├── data.py ├── prompts ├── coder_prompt.yaml ├── diagram_prompt.yaml ├── drawer_prompt.yaml ├── reviewer_prompt.yaml ├── thinker_prompt.yaml └── writer_prompt.yaml ├── reviewer.py ├── scientist.py ├── thinker.py ├── tool.py ├── utils ├── __init__.py ├── bib_manager.py ├── error_handler.py ├── input_formatter.py ├── llm.py ├── output_formatter.py ├── pricing.py └── water_marker.py └── writer.py /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛Bug Report 2 | description: File a bug report here 3 | title: "[BUG]: " 4 | labels: ["bug"] 5 | assignees: [""] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for taking the time to fill out this bug report 🤗 11 | Make sure there aren't any open/closed issues for this topic 😃 12 | 13 | - type: textarea 14 | id: bug-description 15 | attributes: 16 | label: Description of the bug 17 | description: Give us a brief description of what happened and what should have happened 18 | validations: 19 | required: true 20 | 21 | - type: textarea 22 | id: steps-to-reproduce 23 | attributes: 24 | label: Steps To Reproduce 25 | description: Steps to reproduce the behavior. 26 | placeholder: | 27 | 1. Go to '...' 28 | 2. Click on '...' 29 | 3. Scroll down to '...' 30 | 4. See error 31 | validations: 32 | required: true 33 | 34 | - type: textarea 35 | id: additional-information 36 | attributes: 37 | label: Additional Information 38 | description: | 39 | Provide any additional information such as logs, screenshots, links, or scenarios in which the bug occurs, to make resolving the issue easier. 
40 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/codebase_reorg.yml: -------------------------------------------------------------------------------- 1 | name: 🧹Codebase Refactor 2 | description: Refactor, clean, format the codebase 3 | title: "[ORG]: " 4 | labels: ["refactor"] 5 | assignees: [""] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Please make sure this codebase refactor request hasn't already been submitted by looking through other open/closed issues 11 | 12 | - type: textarea 13 | id: description 14 | attributes: 15 | label: Description 16 | description: Give us a brief description of the codebase refactor task you would like 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: additional-information 22 | attributes: 23 | label: Additional Information 24 | description: Give us additional reasoning on why this codebase refactor is necessary 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/exp_record.yml: -------------------------------------------------------------------------------- 1 | name: 🧪Experiment Record 2 | description: Describe experiment setting and results here 3 | title: "[EXP]: " 4 | labels: ["experiment"] 5 | assignees: [""] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Please make sure this experiment request hasn't already been submitted by looking through other open/closed issues 11 | 12 | - type: textarea 13 | id: description 14 | attributes: 15 | label: Description 16 | description: Give us a brief description of the experimental setting and results you would like 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: additional-information 22 | attributes: 23 | label: Additional Information 24 | description: Give us some additional information on the experimental setting and results, like learning rate, data selection, etc. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: ✨Feature Request 2 | description: Request a new feature or enhancement 3 | labels: ["enhancement"] 4 | title: "[FEAT]: " 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Please make sure this feature request hasn't already been submitted by looking through other open/closed issues 10 | 11 | - type: textarea 12 | id: description 13 | attributes: 14 | label: Description 15 | description: Give us a brief description of the feature or enhancement you would like 16 | validations: 17 | required: true 18 | 19 | - type: textarea 20 | id: additional-information 21 | attributes: 22 | label: Additional Information 23 | description: Give us some additional information on the feature request, like proposed solutions, links, screenshots, etc. 
24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/writing_task.yml: -------------------------------------------------------------------------------- 1 | name: 🖊️Writing Task 2 | description: Describe writing task here 3 | title: "[WRT]: " 4 | labels: ["writing"] 5 | assignees: [""] 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Please make sure this writing task request hasn't already been submitted by looking through other open/closed issues 11 | 12 | - type: textarea 13 | id: description 14 | attributes: 15 | label: Description 16 | description: Give us a brief description of the writing task you would like 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: additional-information 22 | attributes: 23 | label: Additional Information 24 | description: Give us some additional information on the writing task, like expected length, main content, etc. 25 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | 7 | Closes # 8 | 9 | ## 📑 Description 10 | 11 | 12 | 16 | 17 | ## ✅ Checks 18 | 19 | - [ ] My pull request adheres to the code style of this project 20 | - [ ] My code requires changes to the documentation 21 | - [ ] I have updated the documentation as required 22 | - [ ] All the tests have passed 23 | - [ ] Branch name follows `type/description` (e.g. 
`feature/add-llm-agents`) 24 | - [ ] Ready for code review 25 | 26 | ## ℹ Additional Information 27 | 28 | -------------------------------------------------------------------------------- /.github/workflows/codespell.yml: -------------------------------------------------------------------------------- 1 | name: codespell 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 6 | 7 | on: 8 | push: 9 | branches: [main] 10 | pull_request: 11 | branches: [main] 12 | 13 | jobs: 14 | codespell: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.10"] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install codespell==2.2.6 tomli==2.0.1 29 | - name: Spelling check with codespell 30 | run: | 31 | codespell -c pyproject.toml 32 | -------------------------------------------------------------------------------- /.github/workflows/isort.yml: -------------------------------------------------------------------------------- 1 | name: isort 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 6 | 7 | on: 8 | push: 9 | branches: [main] 10 | pull_request: 11 | branches: [main] 12 | 13 | jobs: 14 | isort: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.10"] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install isort==5.13.2 29 | - name: Run isort 30 | run: | 31 | isort . --check-only 32 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: Mypy 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 6 | 7 | on: 8 | push: 9 | branches: [main] 10 | pull_request: 11 | branches: [main] 12 | 13 | jobs: 14 | Static-Type-Checking: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | max-parallel: 5 18 | matrix: 19 | python-version: ["3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | curl -sSL https://install.python-poetry.org | python3 30 | poetry install --all-extras 31 | - name: Type-checking package with mypy 32 | run: | 33 | # Run this mypy instance against our main package. 34 | poetry run pip install types-protobuf==4.24.0.4 35 | poetry run mypy --config-file pyproject.toml . 
36 | -------------------------------------------------------------------------------- /.github/workflows/openhands-resolver.yml: -------------------------------------------------------------------------------- 1 | name: Resolve Issue with OpenHands 2 | 3 | on: 4 | issues: 5 | types: [labeled] 6 | pull_request: 7 | types: [labeled] 8 | issue_comment: 9 | types: [created] 10 | pull_request_review_comment: 11 | types: [created] 12 | pull_request_review: 13 | types: [submitted] 14 | 15 | permissions: 16 | contents: write 17 | pull-requests: write 18 | issues: write 19 | 20 | jobs: 21 | call-openhands-resolver: 22 | uses: All-Hands-AI/OpenHands/.github/workflows/openhands-resolver.yml@main 23 | with: 24 | macro: ${{ vars.OPENHANDS_MACRO || '@openhands-agent' }} 25 | max_iterations: ${{ fromJson(vars.OPENHANDS_MAX_ITER || 50) }} 26 | base_container_image: ${{ vars.OPENHANDS_BASE_CONTAINER_IMAGE || '' }} 27 | LLM_MODEL: ${{ vars.LLM_MODEL || 'openai/gpt-4o' }} 28 | target_branch: ${{ vars.TARGET_BRANCH || 'main' }} 29 | secrets: 30 | PAT_TOKEN: ${{ secrets.PAT_TOKEN }} 31 | PAT_USERNAME: ${{ secrets.PAT_USERNAME }} 32 | LLM_API_KEY: ${{ secrets.LLM_API_KEY }} 33 | LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }} 34 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Pytest 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 6 | 7 | on: 8 | push: 9 | branches: [main] 10 | pull_request: 11 | branches: [main] 12 | 13 | jobs: 14 | Pytest: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | max-parallel: 5 18 | matrix: 19 | python-version: ["3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | curl -sSL https://install.python-poetry.org | python3 30 | poetry install --all-extras 31 | poetry run python -m spacy download en_core_web_sm 32 | - name: Test with pytest 33 | env: 34 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 35 | TOGETHERAI_API_KEY: ${{ secrets.TOGETHERAI_API_KEY }} 36 | run: | 37 | poetry run pytest 38 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: ruff 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 6 | 7 | on: 8 | push: 9 | branches: [main] 10 | pull_request: 11 | branches: [main] 12 | 13 | jobs: 14 | ruff: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.10"] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install ruff==0.5.1 29 | - name: Analysing the code with ruff 30 | run: | 31 | ruff check . 32 | - name: Format the code with ruff 33 | run: | 34 | ruff format . 
35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | **/experiments/* 155 | **/config.toml 156 | !**/.gitkeep 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | 10 | - repo: https://github.com/pre-commit/mirrors-prettier 11 | rev: v3.0.1 # Use the sha / tag you want to point at 12 | hooks: 13 | - id: prettier 14 | types_or: [html] 15 | 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: v0.3.5 # Ruff version 18 | hooks: 19 | - id: ruff 20 | types_or: [python, pyi, jupyter] 21 | args: [--fix, --config, pyproject.toml] 22 | 23 | - repo: https://github.com/pre-commit/mirrors-isort 24 | rev: v5.10.1 # Use the latest isort version 25 | hooks: 26 | - id: isort # This will sort imports automatically 27 | 28 | - repo: https://github.com/kynan/nbstripout 29 | rev: 0.6.0 30 | hooks: 31 | - id: nbstripout 32 | 33 | - repo: https://github.com/psf/black 34 | rev: 23.3.0 # or the latest version you prefer 35 | hooks: 36 | - id: black 37 | language_version: python3 38 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. 
Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | <https://www.contributor-covenant.org/version/2/0/code_of_conduct.html>. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | <https://www.contributor-covenant.org/faq>. Translations are available at 128 | <https://www.contributor-covenant.org/translations>. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Haofei Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | <div align="center">
2 | <img src="assets/tiny_scientist.png" alt="tiny_scientist"> 3 | 4 | <h1> 5 | TinyScientist: A Lightweight Framework for Building Research Agents 6 | </h1> 7 | 8 | 9 | [![PyPI version](https://img.shields.io/pypi/v/tiny-scientist)](https://pypi.org/project/tiny-scientist/) 10 | [![Python 3.10](https://img.shields.io/badge/python-%E2%89%A53.10-blue)](https://www.python.org/downloads/release/python-3109/) 11 | [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-red)](https://github.com/ulab-uiuc/tiny-scientist/pulls) 12 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/) 13 | [![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.readthedocs.io) 14 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 15 | 16 | </div>
17 | 18 | # Introduction 19 | 20 | **Tiny-Scientist** is a lightweight, user-friendly framework for automating the entire lifecycle of scientific research—**from ideation to implementation, writing, and review**. Designed for flexibility, it integrates smoothly with your favorite LLMs and search tools. 21 | 22 | #### Core Features 23 | 24 | - 🧠 **Think**: Generate structured research ideas from an intent string. 25 | - 💻 **Code**: Automatically generate and run experiments based on the idea. 26 | - ✍️ **Write**: Convert your results and ideas into a conference-style paper. 27 | - 📝 **Review**: Review any form of paper and output structured feedback in JSON. 28 | 29 | #### Software Architecture 30 | 31 | Our codebase is structured around three core components to support an extensible framework: **core**, **tools**, and **formatters**. The **core** module provides essential functionalities, **tools** enhance and extend these core capabilities, and **formatters** handle input/output tasks such as LaTeX template rendering. 32 | 33 |

34 | <img src="assets/architecture.png" alt="architecture"> 35 | 
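For orientation, the sketch below shows one way to drive the core modules under `tiny_scientist/` directly. The `Thinker` constructor arguments and the `run`/`rank` methods mirror the calls in `backend/app.py`; treat this as an illustrative sketch under those assumptions rather than the definitive API.

```python
# Illustrative sketch only: the Thinker arguments and the run/rank
# method signatures below mirror the calls made in backend/app.py.
from tiny_scientist.thinker import Thinker

thinker = Thinker(
    model="gpt-4o",          # any supported LLM identifier
    tools=[],                # optional helper tools
    iter_num=0,
    output_dir="./",
    search_papers=False,     # skip paper search for a quick dry run
    generate_exp_plan=False,
)

# Generate a few candidate ideas from an intent, then rank them.
ideas = thinker.run(intent="adaptive learning rates for transformers", num_ideas=3)
ranked = thinker.rank(ideas=ideas, intent="adaptive learning rates for transformers")
```

If you only need the end-to-end pipeline, the `TinyScientist` facade shown under Get started below is the simpler entry point.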

36 | 37 | 38 | # Installation 39 | 40 | #### Option 1: Install via pip (recommended) 41 | 42 | ```bash 43 | pip install tiny-scientist 44 | ``` 45 | 46 | #### Option 2: Install from source 47 | 48 | ```bash 49 | # create conda environment 50 | conda create -n tiny-scientist python=3.10 51 | conda activate tiny-scientist 52 | 53 | # Install Poetry 54 | curl -sSL https://install.python-poetry.org | python3 55 | export PATH="$HOME/.local/bin:$PATH" 56 | 57 | # Install dependencies 58 | poetry install 59 | ``` 60 | 61 | # Get started 62 | 63 | Before running any code, set your API key: 64 | 65 | ```bash 66 | export OPENAI_API_KEY=your-key-here 67 | # or use DEEPSEEK_API_KEY, ANTHROPIC_API_KEY, or OPENROUTER_API_KEY 68 | ``` 69 | 70 | If you want to use local Ollama models, set the API base: 71 | 72 | ```bash 73 | export OLLAMA_API_BASE=http://192.168.23.11:11434 74 | ``` 75 | 76 | You can then specify Ollama models such as `ollama/llama3.2:latest`. 77 | 78 | The setup for LM Studio is similar: 79 | 80 | ```bash 81 | export LM_STUDIO_API_BASE=http://localhost:1234/v1 82 | ``` 83 | 84 | but you do need to specify an API key, even if it is a dummy value: 85 | 86 | ```bash 87 | export LM_STUDIO_API_KEY=dummy-api-key 88 | ``` 89 | 90 | Models are then specified like so: `lm_studio/qwen2.5-coder-32b-instruct-mlx` 91 | 92 | For other OpenAI-compatible backend providers, set the following variables: 93 | 94 | ```bash 95 | export OPENAI_API_BASE=http://192.168.9.14/v1 96 | export OPENAI_API_KEY=your-key-here 97 | ``` 98 | 99 | and specify your model like so: `openai/qwen3-30b-a3b` 100 | 101 | Now you can use Tiny-Scientist in Python with only a few lines of code: 102 | 103 | ```python 104 | from tiny_scientist import TinyScientist 105 | 106 | scientist = TinyScientist(model="gpt-4o") 107 | 108 | # Step 1: Generate a JSON-format research idea 109 | idea = scientist.think(intent="Benchmarking adaptive step size strategies using a convex quadratic optimization function") 110 | 111 | # Step 2: Run experiments (you can provide baseline_results if available) 112 | status, experiment_dir = scientist.code(idea=idea) 113 | 114 | # If the experiments run successfully 115 | if status is True: 116 | # Step 3: Write a paper 117 | pdf_path = scientist.write(idea=idea, experiment_dir=experiment_dir) 118 | 119 | # Step 4: Review the paper 120 | review = scientist.review(pdf_path=pdf_path) 121 | ``` 122 | 123 | # Managing API Keys (Optional) 124 | 125 | Beyond exporting environment variables, you can configure keys in a `.toml` file for convenience. 126 | 127 | #### Step 1: Copy the template 128 | 129 | ```bash 130 | cp config.template.toml config.toml 131 | ``` 132 | 133 | #### Step 2: Fill in your API credentials 134 | 135 | Edit `config.toml` to include your keys, such as: 136 | 137 | ```toml 138 | [core] 139 | llm_api_key = "xxxx" 140 | ``` 141 | 142 | No need to export environment variables manually; just set this once. 143 | 144 | # Developing 145 | 146 | #### Develop Demo 147 | To develop a demo (both frontend and backend): 148 | ```bash 149 | python backend/app.py 150 | ``` 151 | ```bash 152 | cd frontend 153 | npm install 154 | npm start 155 | ``` 156 | # Q&A 157 | 158 | If you face "cairo"-related errors: cairo is a system-level dependency, so please run `conda install -c conda-forge cairo` or `brew install cairo`. 159 | 160 | If you face errors related to pdflatex: it is also a system-level dependency (for LaTeX rendering), so please run `brew install --cask mactex`. 
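As a quick smoke test for the demo described under Developing, you can also call the Flask backend directly. The port, routes, and payload fields below are taken from `backend/app.py`; the model name and API key are placeholders, and the response contents will vary with your configured model.

```python
# Smoke test for the demo backend (run `python backend/app.py` first).
# Routes and fields come from backend/app.py; "sk-..." is a placeholder key.
import json
from urllib.request import Request, urlopen


def post(path: str, payload: dict) -> dict:
    req = Request(
        f"http://localhost:8080{path}",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urlopen(req) as resp:
        return json.load(resp)


post("/api/configure", {"model": "gpt-4o", "api_key": "sk-..."})
print(post("/api/generate-initial", {"intent": "adaptive learning rates", "num_ideas": 3}))
```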
161 | 162 | # Contribution 163 | 164 | We’re working on extending support for more tools, models, and paper formats. Contributions welcome! 165 | 166 | # Citation 167 | 168 | ``` 169 | @misc{tinyscientist, 170 | author = {Haofei Yu and Keyang Xuan and Fenghai Li and Zijie Lei and Jiaxuan You}, 171 | title = {TinyScientist: A Lightweight Framework for Building Research Agents}, 172 | howpublished = {https://github.com/ulab-uiuc/tiny-scientist}, 173 | note = {Accessed: 2025-04-14}, 174 | year = {2025} 175 | } 176 | ``` 177 | -------------------------------------------------------------------------------- /assets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/assets/.gitkeep -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/assets/architecture.png -------------------------------------------------------------------------------- /assets/tiny_scientist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/assets/tiny_scientist.png -------------------------------------------------------------------------------- /backend/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, Optional, Union 3 | 4 | from flask import Flask, Response, jsonify, request, session 5 | from flask_cors import CORS 6 | 7 | from tiny_scientist.thinker import Thinker 8 | 9 | app = Flask(__name__) 10 | app.secret_key = "your-secret-key-here" 11 | CORS(app, supports_credentials=True) 12 | 13 | 14 | thinker: Optional[Thinker] = None 15 | 16 | 17 | # Initialize the Thinker 18 | @app.route("/api/configure", methods=["POST"]) 19 | def configure() -> Union[Response, tuple[Response, int]]: 20 | """Configure model and API key""" 21 | data = request.json 22 | if data is None: 23 | return jsonify({"error": "No JSON data provided"}), 400 24 | 25 | model = data.get("model") 26 | api_key = data.get("api_key") 27 | 28 | if not model or not api_key: 29 | return jsonify({"error": "Model and API key are required"}), 400 30 | 31 | # Map models to their environment variables 32 | env_var_map = { 33 | "deepseek-chat": "DEEPSEEK_API_KEY", 34 | "deepseek-reasoner": "DEEPSEEK_API_KEY", 35 | "gpt-4o": "OPENAI_API_KEY", 36 | "gpt-o1": "OPENAI_API_KEY", 37 | "claude-3.5-sonnet": "ANTHROPIC_API_KEY", 38 | } 39 | 40 | # Set the appropriate environment variable 41 | env_var = env_var_map.get(model) 42 | if env_var: 43 | os.environ[env_var] = api_key 44 | 45 | # Store in session 46 | session["model"] = model 47 | session["api_key"] = api_key 48 | session["configured"] = True 49 | 50 | # Initialize thinker with new model 51 | global thinker 52 | thinker = Thinker( 53 | model=model, 54 | tools=[], 55 | iter_num=0, 56 | output_dir="./", 57 | search_papers=False, 58 | generate_exp_plan=False, 59 | ) 60 | 61 | return jsonify({"status": "configured", "model": model}) 62 | 63 | 64 | @app.route("/api/generate-initial", methods=["POST"]) 65 | def generate_initial() -> Union[Response, tuple[Response, int]]: 66 | """Generate initial ideas from an intent 
(handleAnalysisIntentSubmit)""" 67 | if thinker is None: 68 | return jsonify({"error": "Thinker not configured"}), 400 69 | 70 | data = request.json 71 | if data is None: 72 | return jsonify({"error": "No JSON data provided"}), 400 73 | 74 | intent = data.get("intent") 75 | num_ideas = data.get("num_ideas", 3) 76 | 77 | # Generate ideas 78 | ideas = thinker.run(intent=intent, num_ideas=num_ideas) 79 | 80 | # Return in the format expected by TreePlot 81 | response = { 82 | "ideas": [ 83 | { 84 | "title": ( 85 | idea.get("Title", idea.get("Name", "Untitled")) 86 | if isinstance(idea, dict) 87 | else "Untitled" 88 | ), 89 | "content": format_idea_content(idea), 90 | } 91 | for idea in ideas 92 | ] 93 | } 94 | 95 | return jsonify(response) 96 | 97 | 98 | @app.route("/api/generate-children", methods=["POST"]) 99 | def generate_children() -> Union[Response, tuple[Response, int]]: 100 | """Generate child ideas (generateChildNodes)""" 101 | if thinker is None: 102 | return jsonify({"error": "Thinker not configured"}), 400 103 | 104 | data = request.json 105 | if data is None: 106 | return jsonify({"error": "No JSON data provided"}), 400 107 | 108 | parent_content = data.get("parent_content") 109 | context = data.get("context", "") 110 | 111 | # Combine parent content and context as the intent 112 | combined_intent = f"{parent_content}\nAdditional Context: {context}" 113 | ideas = thinker.run(intent=combined_intent, num_ideas=3) 114 | 115 | # Return in the format expected by TreePlot 116 | response = { 117 | "ideas": [ 118 | { 119 | "title": ( 120 | idea.get("Title", idea.get("Name", "Untitled")) 121 | if isinstance(idea, dict) 122 | else "Untitled" 123 | ), 124 | "content": format_idea_content(idea), 125 | } 126 | for idea in ideas 127 | ] 128 | } 129 | 130 | return jsonify(response) 131 | 132 | 133 | @app.route("/api/modify", methods=["POST"]) 134 | def modify_idea() -> Union[Response, tuple[Response, int]]: 135 | """Modify an idea (modifyHypothesisBasedOnModifications)""" 136 | if thinker is None: 137 | return jsonify({"error": "Thinker not configured"}), 400 138 | 139 | data = request.json 140 | if data is None: 141 | return jsonify({"error": "No JSON data provided"}), 400 142 | 143 | original_idea = data.get("original_idea") 144 | modifications = data.get("modifications") 145 | behind_idea = data.get("behind_idea") 146 | 147 | # Convert TreePlot format to Thinker format 148 | thinker_original = convert_to_thinker_format(original_idea) 149 | thinker_behind = convert_to_thinker_format(behind_idea) if behind_idea else None 150 | # Convert modifications to Thinker format 151 | thinker_mods = [] 152 | for mod in modifications: 153 | thinker_mods.append( 154 | {"metric": mod.get("metric"), "direction": mod.get("direction")} 155 | ) 156 | 157 | # Modify the idea 158 | modified_idea = thinker.modify_idea( 159 | original_idea=thinker_original, 160 | modifications=thinker_mods, 161 | behind_idea=thinker_behind, 162 | ) 163 | # Return in the format expected by TreePlot 164 | response = { 165 | "title": ( 166 | modified_idea.get("Title", modified_idea.get("Name", "Untitled")) 167 | if modified_idea 168 | else "Untitled" 169 | ), 170 | "content": format_idea_content(modified_idea), 171 | } 172 | return jsonify(response) 173 | 174 | 175 | @app.route("/api/merge", methods=["POST"]) 176 | def merge_ideas() -> Union[Response, tuple[Response, int]]: 177 | """Merge two ideas (mergeHypotheses)""" 178 | if thinker is None: 179 | return jsonify({"error": "Thinker not configured"}), 400 180 | 181 | data = 
request.json 182 | if data is None: 183 | return jsonify({"error": "No JSON data provided"}), 400 184 | 185 | idea_a = data.get("idea_a") 186 | idea_b = data.get("idea_b") 187 | # Convert TreePlot format to Thinker format 188 | thinker_idea_a = convert_to_thinker_format(idea_a) 189 | thinker_idea_b = convert_to_thinker_format(idea_b) 190 | # Merge ideas 191 | merged_idea = thinker.merge_ideas(idea_a=thinker_idea_a, idea_b=thinker_idea_b) 192 | 193 | # Return in the format expected by TreePlot 194 | response = { 195 | "title": ( 196 | merged_idea.get("Title", merged_idea.get("Name", "Untitled")) 197 | if merged_idea 198 | else "Untitled" 199 | ), 200 | "content": format_idea_content(merged_idea), 201 | } 202 | 203 | return jsonify(response) 204 | 205 | 206 | @app.route("/api/evaluate", methods=["POST"]) 207 | def evaluate_ideas() -> Union[Response, tuple[Response, int]]: 208 | """Evaluate ideas (evaluateHypotheses)""" 209 | if thinker is None: 210 | return jsonify({"error": "Thinker not configured"}), 400 211 | 212 | data = request.json 213 | if data is None: 214 | return jsonify({"error": "No JSON data provided"}), 400 215 | 216 | ideas = data.get("ideas") 217 | intent = data.get("intent") 218 | 219 | # Convert TreePlot format to Thinker format 220 | thinker_ideas = [convert_to_thinker_format(idea) for idea in ideas] 221 | 222 | # Rank ideas 223 | ranked_ideas = thinker.rank(ideas=thinker_ideas, intent=intent) 224 | 225 | # Return in the format expected by TreePlot 226 | # Include rankings that TreePlot will convert to scores 227 | response = [] 228 | for idea in ranked_ideas: 229 | response.append( 230 | { 231 | "id": idea.get("id"), 232 | "novelty_rank": idea.get("NoveltyRanking"), 233 | "novelty_rank_reason": idea.get("NoveltyReason", ""), 234 | "feasibility_rank": idea.get("FeasibilityRanking"), 235 | "feasibility_rank_reason": idea.get("FeasibilityReason", ""), 236 | "impact_rank": idea.get("ImpactRanking"), 237 | "impact_rank_reason": idea.get("ImpactReason", ""), 238 | } 239 | ) 240 | 241 | return jsonify(response) 242 | 243 | 244 | def format_idea_content(idea: Any) -> str: 245 | """Format Thinker idea into content for TreePlot - with standardized section headers""" 246 | if not isinstance(idea, dict): 247 | return "No content available" 248 | 249 | # Get content and ensure no trailing ** in any of the content sections 250 | problem = idea.get("Problem", "").strip().rstrip("*") 251 | importance = idea.get("Importance", "").strip().rstrip("*") 252 | feasibility = idea.get("Difficulty", "").strip().rstrip("*") 253 | novelty = idea.get("NoveltyComparison", "").strip().rstrip("*") 254 | 255 | return "\n\n".join( 256 | [ 257 | f"Problem: {problem}", 258 | f"Impact: {importance}", 259 | f"Feasibility: {feasibility}", 260 | f"Novelty: {novelty}", 261 | ] 262 | ) 263 | 264 | 265 | def convert_to_thinker_format(treeplot_idea: Any) -> Dict[str, Any]: 266 | """Convert TreePlot idea format to Thinker format""" 267 | if not isinstance(treeplot_idea, dict): 268 | return {} 269 | 270 | # Extract sections from content if possible 271 | content = treeplot_idea.get("content", "") 272 | 273 | problem = "" 274 | importance = "" 275 | difficulty = "" # Maps to Feasibility in frontend 276 | novelty_comparison = "" # Maps to Novelty in frontend 277 | approach = "" 278 | 279 | # Try to extract sections with more flexible pattern matching 280 | if content: 281 | sections = content.split("\n\n") 282 | for section in sections: 283 | # Remove all formatting variations and normalize 284 | section_lower = 
section.lower() 285 | 286 | if "problem" in section_lower: 287 | # Extract content after any form of "Problem:" heading 288 | problem = extract_section_content(section) 289 | if "impact" in section_lower: 290 | importance = extract_section_content(section) 291 | if "feasibility" in section_lower: 292 | # In frontend it's called Feasibility, in backend it's Difficulty 293 | difficulty = extract_section_content(section) 294 | if "novelty" in section_lower: 295 | # In frontend it can be Novelty or Novelty Comparison 296 | novelty_comparison = extract_section_content(section) 297 | if "approach" in section_lower: 298 | approach = extract_section_content(section) 299 | 300 | # Create Thinker format 301 | thinker_idea = { 302 | "id": treeplot_idea.get("id"), 303 | "Name": treeplot_idea.get("title"), 304 | "Title": treeplot_idea.get("title"), 305 | "Problem": problem, 306 | "Importance": importance, 307 | "Difficulty": difficulty, # Maps to Feasibility in frontend 308 | "NoveltyComparison": novelty_comparison, # Maps to Novelty in frontend 309 | "Approach": approach, 310 | } 311 | 312 | return thinker_idea 313 | 314 | 315 | def extract_section_content(section: str) -> str: 316 | """Helper function to extract content after section heading regardless of format""" 317 | # Check if section contains a colon (indicating a header) 318 | if ":" in section: 319 | # Split at the first colon to separate header from content 320 | parts = section.split(":", 1) 321 | if len(parts) > 1: 322 | # Return just the content part, removing any asterisks 323 | return parts[1].replace("**", "").strip() 324 | 325 | # If there's no colon or we couldn't extract properly, 326 | # just clean any formatting and return the whole section 327 | return section.replace("**", "").strip() 328 | 329 | 330 | if __name__ == "__main__": 331 | app.run(debug=True, port=8080, host="0.0.0.0") 332 | -------------------------------------------------------------------------------- /config.template.toml: -------------------------------------------------------------------------------- 1 | ###################### TinyScientist Configuration ###################### 2 | 3 | #################################### Core #################################### 4 | [core] 5 | # Base directory for storing experiments and results 6 | #workspace_base = "./experiments" 7 | 8 | # Name of the AI Scientist 9 | #name = "AI Scientist" 10 | 11 | # LLM Model 12 | model = "gpt-4o" 13 | 14 | # Temperature for controlling randomness in responses 15 | temperature = 0.75 16 | 17 | # API Key for general LLM API access 18 | llm_api_key = "" 19 | 20 | # S2 API Key for accessing scientific research data 21 | #s2_api_key = "" 22 | 23 | # GITHUB TOKEN for accessing GitHub repositories 24 | #github_token = "" 25 | 26 | # paper searching engine 27 | #engine = "semanticscholar" 28 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/data/.gitkeep -------------------------------------------------------------------------------- /docs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/docs/.gitkeep -------------------------------------------------------------------------------- /experiments/baseline_results.txt: 
-------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "baseline_transformer", 3 | "model": "Transformer", 4 | "dataset": "Wikitext-103", 5 | "optimizer": "Adam", 6 | "learning_rate": 0.001, 7 | "batch_size": 64, 8 | "epochs": 10, 9 | "metrics": { 10 | "validation_loss": 3.12, 11 | "train_loss": 2.85, 12 | "perplexity": 22.5, 13 | "accuracy": 74.6, 14 | "f1_score": 0.71 15 | }, 16 | "notes": "This is the baseline performance of a standard Transformer model using Adam optimizer." 17 | } 18 | -------------------------------------------------------------------------------- /experiments/experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Tuple 4 | 5 | import torch 6 | import torch.optim as optim 7 | from datasets import Dataset, load_dataset 8 | from torch.nn import CrossEntropyLoss, Module 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | from torch.utils.data import DataLoader 11 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 12 | from transformers.modeling_outputs import BaseModelOutput 13 | 14 | # Define model and dataset 15 | MODEL_NAME = "bert-base-uncased" 16 | DATASET_NAME = "glue" 17 | TASK_NAME = "sst2" 18 | 19 | 20 | def load_data() -> Tuple[Dataset, Dataset]: 21 | """Loads the dataset and prepares train/test splits.""" 22 | dataset = load_dataset(DATASET_NAME, TASK_NAME) 23 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 24 | 25 | def tokenize_function(examples: Dataset) -> Dataset: 26 | return tokenizer(examples["sentence"], truncation=True, padding="max_length") 27 | 28 | dataset = dataset.map(tokenize_function, batched=True) 29 | dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"]) 30 | return dataset["train"], dataset["validation"] 31 | 32 | 33 | class AdaptiveLRModel(Module): # type: ignore[misc] 34 | """Custom model wrapper for adaptive learning rate experiments.""" 35 | 36 | def __init__(self) -> None: 37 | super().__init__() 38 | self.model = AutoModelForSequenceClassification.from_pretrained( 39 | MODEL_NAME, num_labels=2 40 | ) 41 | 42 | def forward( 43 | self, input_ids: torch.Tensor, attention_mask: torch.Tensor 44 | ) -> BaseModelOutput: 45 | return self.model(input_ids=input_ids, attention_mask=attention_mask) 46 | 47 | 48 | def train_and_evaluate( 49 | output_dir: str, initial_lr: float = 5e-5, adapt_lr: bool = True 50 | ) -> None: 51 | """Trains the model with adaptive learning rates and evaluates performance.""" 52 | train_data, val_data = load_data() 53 | model = AdaptiveLRModel() 54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | model.to(device) 56 | 57 | optimizer = optim.AdamW(model.parameters(), lr=initial_lr) 58 | scheduler = ( 59 | ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1) 60 | if adapt_lr 61 | else None 62 | ) 63 | 64 | loss_fn = CrossEntropyLoss() 65 | train_loader: DataLoader[Dataset] = DataLoader( 66 | train_data, batch_size=16, shuffle=True 67 | ) 68 | val_loader: DataLoader[Dataset] = DataLoader(val_data, batch_size=16) 69 | 70 | best_val_loss = float("inf") 71 | 72 | for epoch in range(3): 73 | model.train() 74 | running_loss = 0.0 75 | for batch in train_loader: 76 | optimizer.zero_grad() 77 | outputs = model( 78 | batch["input_ids"].to(device), batch["attention_mask"].to(device) 79 | ) 80 | loss = loss_fn(outputs.logits, batch["label"].to(device)) 81 | 
loss.backward() 82 | optimizer.step() 83 | running_loss += loss.item() 84 | 85 | model.eval() 86 | val_loss = 0.0 87 | with torch.no_grad(): 88 | for batch in val_loader: 89 | outputs = model( 90 | batch["input_ids"].to(device), batch["attention_mask"].to(device) 91 | ) 92 | loss = loss_fn(outputs.logits, batch["label"].to(device)) 93 | val_loss += loss.item() 94 | 95 | if scheduler: 96 | scheduler.step(val_loss) 97 | 98 | print( 99 | f"Epoch {epoch + 1} - Training Loss: {running_loss / len(train_loader):.4f}, " 100 | f"Validation Loss: {val_loss / len(val_loader):.4f}" 101 | ) 102 | 103 | if val_loss < best_val_loss: 104 | best_val_loss = val_loss 105 | torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth")) 106 | 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser( 110 | description="Experiment with Adaptive Learning Rate" 111 | ) 112 | parser.add_argument( 113 | "--out_dir", 114 | type=str, 115 | required=True, 116 | help="Output directory for model checkpoints", 117 | ) 118 | parser.add_argument("--lr", type=float, default=5e-5, help="Initial learning rate") 119 | parser.add_argument( 120 | "--adaptive", action="store_true", help="Use adaptive learning rates" 121 | ) 122 | args = parser.parse_args() 123 | 124 | os.makedirs(args.out_dir, exist_ok=True) 125 | train_and_evaluate(args.out_dir, initial_lr=args.lr, adapt_lr=args.adaptive) 126 | -------------------------------------------------------------------------------- /experiments/experiment_results.txt: -------------------------------------------------------------------------------- 1 | { 2 | "experiment_name": "adaptive_learning_rates", 3 | "model": "Transformer", 4 | "dataset": "Wikitext-103", 5 | "optimizer": "Adam with adaptive learning rates", 6 | "learning_rate_strategy": "Dynamic adjustment based on gradient variance", 7 | "batch_size": 64, 8 | "epochs": 10, 9 | "metrics": { 10 | "validation_loss": 2.85, 11 | "train_loss": 2.62, 12 | "perplexity": 18.9, 13 | "accuracy": 77.3, 14 | "f1_score": 0.75 15 | }, 16 | "comparison_with_baseline": { 17 | "validation_loss_reduction": 8.65, 18 | "accuracy_improvement": 2.7 19 | }, 20 | "notes": "The adaptive learning rate strategy improved convergence speed and final accuracy." 21 | } 22 | -------------------------------------------------------------------------------- /experiments/idea.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "adaptive_learning_rates", 3 | "Title": "Exploring Adaptive Learning Rate Strategies for Transformer Models", 4 | "Problem": "Transformer-based models are highly sensitive to learning rate schedules. Fixed or manually tuned learning rates often lead to suboptimal convergence and require extensive trial-and-error tuning to generalize across tasks and datasets. While optimizers like Adam and its variants are widely used, there remains a limited understanding of how real-time gradient-based adaptations could improve stability and efficiency during training.", 5 | "Importance": "Improving the adaptability of learning rates could substantially reduce instability during training and lessen the need for extensive hyperparameter tuning. 
This would democratize fine-tuning Transformer models across compute-constrained environments and improve reproducibility across diverse NLP tasks.", 6 | "Difficulty": "The main challenge lies in designing learning rate control mechanisms that dynamically respond to optimization signals like gradient variance without inducing instability or large computational overhead. It also requires careful integration into existing optimization frameworks and benchmarking across multiple tasks.", 7 | "NoveltyComparison": "Most previous work on learning rate schedules relies on fixed heuristics like cosine decay, warm-up steps, or cyclical patterns. In contrast, our approach dynamically adjusts learning rates based on statistical analysis of the gradient magnitude variance. This design draws from adaptive control theory and second-order optimization insights to provide more fine-grained and context-aware adaptation strategies.", 8 | "Approach": "This work proposes a method that integrates gradient variance monitoring into the optimizer. If gradient variance exceeds a predefined threshold, the learning rate is reduced to enhance stability. If variance is low, a more aggressive learning rate is applied to accelerate convergence. The method is implemented in a PyTorch-based optimizer and benchmarked on standard Transformer architectures.", 9 | "Experiment": { 10 | "Model": "We use a BERT-base Transformer model and implement a custom optimizer subclass based on AdamW. The optimizer tracks moving averages of gradient variance to adjust the learning rate dynamically. Baselines include standard AdamW with cosine decay, linear warm-up, and constant schedules. The implementation uses HuggingFace Transformers and PyTorch Lightning for training pipelines.", 11 | "Dataset": "We evaluate the approach using the GLUE benchmark suite. Key tasks include SST-2 for sentiment classification and MRPC for paraphrase detection. These tasks are selected due to their sensitivity to learning rate choices during fine-tuning. Datasets are loaded via HuggingFace Datasets with standard preprocessing and evaluation metrics.", 12 | "Metric": "We report task-specific accuracy and F1 scores. Additional metrics include convergence speed (epochs to 90% of peak performance), training stability (standard deviation across runs with different seeds), and robustness across different batch sizes. Each experiment is repeated 3 times to compute variance." 
13 |   },
14 |   "Interestingness": 8,
15 |   "Feasibility": 7,
16 |   "Novelty": 9,
17 |   "IntentAlignment": 9,
18 |   "Score": 8
19 | }
20 |
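The "Approach" and "Experiment" fields above describe the variance-gated learning-rate mechanism only in prose. A minimal sketch of what that mechanism could look like as a PyTorch wrapper follows; the class name, thresholds, and scaling factors are illustrative assumptions, not values taken from the recorded experiment.

```python
# Hypothetical sketch of the idea above; names and thresholds are assumptions.
import torch


class VarianceGatedLR:
    """Rescales an optimizer's learning rate from a running estimate of
    gradient variance: damp the step when gradients are noisy, raise it
    gently when they are stable."""

    def __init__(self, optimizer, base_lr, beta=0.9, high=1e-2, low=1e-4):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.beta = beta    # smoothing factor for the running variance
        self.high = high    # variance above this threshold -> shrink the LR
        self.low = low      # variance below this threshold -> grow the LR
        self.running_var = None

    @torch.no_grad()
    def step(self):
        # Pool all current gradient entries and measure their variance.
        grads = [
            p.grad.flatten()
            for group in self.optimizer.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not grads:
            return
        var = torch.cat(grads).var().item()
        self.running_var = (
            var
            if self.running_var is None
            else self.beta * self.running_var + (1 - self.beta) * var
        )
        if self.running_var > self.high:
            scale = 0.5  # high variance: damp the step for stability
        elif self.running_var < self.low:
            scale = 1.5  # low variance: accelerate convergence
        else:
            scale = 1.0
        for group in self.optimizer.param_groups:
            group["lr"] = self.base_lr * scale


# Usage: call once per iteration, after loss.backward() and before
# optimizer.step(), e.g.  lr_ctrl.step(); optimizer.step()
```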
--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "hypo-eval",
3 |   "version": "0.1.0",
4 |   "private": true,
5 |   "proxy": "http://localhost:8080",
6 |   "dependencies": {
7 |     "@testing-library/jest-dom": "^5.17.0",
8 |     "@testing-library/react": "^13.4.0",
9 |     "@testing-library/user-event": "^13.5.0",
10 |     "d3": "^7.9.0",
11 |     "lucide-react": "^0.460.0",
12 |     "openai": "^4.73.1",
13 |     "react": "^18.3.1",
14 |     "react-dom": "^18.3.1",
15 |     "react-scripts": "5.0.1",
16 |     "recharts": "^2.13.3",
17 |     "web-vitals": "^2.1.4"
18 |   },
19 |   "scripts": {
20 |     "start": "react-scripts start",
21 |     "build": "react-scripts build",
22 |     "test": "react-scripts test",
23 |     "eject": "react-scripts eject",
24 |     "shadcn-ui": "shadcn-ui"
25 |   },
26 |   "eslintConfig": {
27 |     "extends": [
28 |       "react-app",
29 |       "react-app/jest"
30 |     ]
31 |   },
32 |   "browserslist": {
33 |     "production": [
34 |       ">0.2%",
35 |       "not dead",
36 |       "not op_mini all"
37 |     ],
38 |     "development": [
39 |       "last 1 chrome version",
40 |       "last 1 firefox version",
41 |       "last 1 safari version"
42 |     ]
43 |   },
44 |   "devDependencies": {
45 |     "@shadcn/ui": "^0.0.4"
46 |   }
47 | }
48 |
--------------------------------------------------------------------------------
/frontend/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/frontend/public/favicon.ico
--------------------------------------------------------------------------------
/frontend/public/index.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html>
2 | <html lang="en">
3 |   <head>
4 |     <meta charset="utf-8" />
5 |     <link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
6 |     <meta name="viewport" content="width=device-width, initial-scale=1" />
7 |     <meta name="theme-color" content="#000000" />
8 |     <meta
9 |       name="description"
10 |       content="Web site created using create-react-app"
11 |     />
12 |     <link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
13 |     <!--
14 |       manifest.json provides metadata used when your web app is installed on a
15 |       user's mobile device or desktop. See https://developers.google.com/web/fundamentals/web-app-manifest/
16 |     -->
17 |     <link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
18 |     <!--
19 |       Notice the use of %PUBLIC_URL% in the tags above.
20 |       It will be replaced with the URL of the `public` folder during the build.
21 |       Only files inside the `public` folder can be referenced from the HTML.
22 |
23 |       Unlike "/favicon.ico" or "favicon.ico", "%PUBLIC_URL%/favicon.ico" will
24 |       work correctly both with client-side routing and a non-root public URL.
25 |       Learn how to configure a non-root public URL by running `npm run build`.
26 |     -->
27 |     <title>React App</title>
28 |   </head>
29 |   <body>
30 |     <noscript>You need to enable JavaScript to run this app.</noscript>
31 |     <div id="root"></div>
32 |     <!--
33 |       This HTML file is a template.
34 |       If you open it directly in the browser, you will see an empty page.
35 |
36 |       You can add webfonts, meta tags, or analytics to this file.
37 |       The build step will place the bundled scripts into the <body> tag.
38 |
39 |       To begin the development, run `npm start` or `yarn start`.
40 |       To create a production build, use `npm run build` or `yarn build`.
41 |     -->
42 |   </body>
43 | </html>
44 |
--------------------------------------------------------------------------------
/frontend/public/logo192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/frontend/public/logo192.png
--------------------------------------------------------------------------------
/frontend/public/logo512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/frontend/public/logo512.png
--------------------------------------------------------------------------------
/frontend/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 |   "short_name": "React App",
3 |   "name": "Create React App Sample",
4 |   "icons": [
5 |     {
6 |       "src": "favicon.ico",
7 |       "sizes": "64x64 32x32 24x24 16x16",
8 |       "type": "image/x-icon"
9 |     },
10 |     {
11 |       "src": "logo192.png",
12 |       "type": "image/png",
13 |       "sizes": "192x192"
14 |     },
15 |     {
16 |       "src": "logo512.png",
17 |       "type": "image/png",
18 |       "sizes": "512x512"
19 |     }
20 |   ],
21 |   "start_url": ".",
22 |   "display": "standalone",
23 |   "theme_color": "#000000",
24 |   "background_color": "#ffffff"
25 | }
26 |
--------------------------------------------------------------------------------
/frontend/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/frontend/src/App.css:
--------------------------------------------------------------------------------
1 | .App {
2 |   text-align: center;
3 | }
4 |
5 | .App-logo {
6 |   height: 40vmin;
7 |   pointer-events: none;
8 | }
9 |
10 | @media (prefers-reduced-motion: no-preference) {
11 |   .App-logo {
12 |     animation: App-logo-spin infinite 20s linear;
13 |   }
14 | }
15 |
16 | .App-header {
17 |   background-color: #282c34;
18 |   min-height: 100vh;
19 |   display: flex;
20 |   flex-direction: column;
21 |   align-items: center;
22 |   justify-content: center;
23 |   font-size: calc(10px + 2vmin);
24 |   color: white;
25 | }
26 |
27 | .App-link {
28 |   color: #61dafb;
29 | }
30 |
31 | @keyframes App-logo-spin {
32 |   from {
33 |     transform: rotate(0deg);
34 |   }
35 |   to {
36 |     transform: rotate(360deg);
37 |   }
38 | }
39 |
--------------------------------------------------------------------------------
/frontend/src/App.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import TreePlotVisualization from './components/TreePlotVisualization';
3 |
4 | function App() {
5 |   return (
6 |     <div className="App">
7 |       <TreePlotVisualization />
8 |     </div>
9 |   );
10 | }
11 |
12 | export default App;
13 |
--------------------------------------------------------------------------------
/frontend/src/App.test.js:
--------------------------------------------------------------------------------
1 | import { render, screen } from '@testing-library/react';
2 | import App from './App';
3 |
4 | test('renders learn react link', () => {
5 |   render(<App />);
6 |   const linkElement = screen.getByText(/learn react/i);
7 |   expect(linkElement).toBeInTheDocument();
8 | });
9 |
--------------------------------------------------------------------------------
/frontend/src/components/FactorBlock.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 |
3 | /**
4 |  * Score and rationale for a single factor (Novelty / Feasibility / Impact)
5 |  */
6 | const FactorBlock = ({
7 |   label,
8 |   color,
9 |   score,
10 |   reason,
11 |   beforeVal,
12 |   afterVal,
13 |   showAfter,
14 | }) => {
15 |   const hasBefore = beforeVal !== null && beforeVal !== undefined;
16 |   const round = (v) => Math.round(parseFloat(v) || 0);
17 |
18 |   /* Arrow */
19 |   const Arrow = () => {
20 |     if (!hasBefore) return null;
21 |     const b = round(beforeVal);
22 |     const a = round(afterVal);
23 |     if (a > b) return ;
24 |     if (a < b) return ;
25 |     return ;
26 |   };
27 |
28 |   return (
29 |
38 |
39 | 48 | 51 | {label} 52 | 53 | {hasBefore ? ( 54 | <> 55 | {round(beforeVal)} 56 | 57 | {round(afterVal)} 58 | 59 | ) : ( 60 | Score: {round(score)} 61 | )} 62 |
63 | 64 |
72 | {reason || '(No reason provided)'} 73 |
74 |
75 |   );
76 | };
77 |
78 | export default FactorBlock;
79 |
--------------------------------------------------------------------------------
/frontend/src/components/HypothesisCard.jsx:
--------------------------------------------------------------------------------
1 | import React, { useState } from 'react';
2 |
3 | /**
4 |  * Displays a single hypothesis and can toggle between Before / After states.
5 |  * The Problem section is shown directly below the title; the other three sections (Impact/Feasibility/Novelty) are switched via tabs.
6 |  */
7 | const HypothesisCard = ({ node, showAfter, setShowAfter }) => {
8 |   // The Impact tab is selected by default
9 |   const [activeSection, setActiveSection] = useState('Impact');
10 |   // Use a transition effect for the active content
11 |   const [fadeState, setFadeState] = useState('visible');
12 |
13 |   const hasPrevious = !!node.previousState;
14 |   const content =
15 |     hasPrevious && !showAfter ? node.previousState.content : node.content;
16 |
17 |   // Parse the individual sections out of the content
18 |   const parseContent = (content) => {
19 |     // Section headings to recognize (several formats are supported)
20 |     const sections = [
21 |       { title: "Problem", regex: /\*\*Problem:\*\*|Problem:|Problem\s*\*\*/ },
22 |       { title: "Impact", regex: /\*\*Impact:\*\*|Impact:|Impact\s*\*\*/ },
23 |       { title: "Feasibility", regex: /\*\*Feasibility:\*\*|Feasibility:|Feasibility\s*\*\*/ },
24 |       { title: "Novelty", regex: /\*\*Novelty Comparison:\*\*|Novelty Comparison:|Novelty:|Novelty\s*\*\*/ }
25 |     ];
26 |
27 |     // Return an empty array if there is no content
28 |     if (!content) return [];
29 |
30 |     // Extract the individual sections
31 |     const parsedSections = [];
32 |
33 |     // Find the positions of all matching section headings
34 |     const allMatches = [];
35 |     sections.forEach(section => {
36 |       const displayTitle = section.displayTitle || section.title;
37 |       let match;
38 |       let tempContent = content;
39 |       let offset = 0;
40 |
41 |       // Find all occurrences of the section
42 |       while ((match = tempContent.match(section.regex)) !== null) {
43 |         const startPos = match.index + offset;
44 |         const matchLength = match[0].length;
45 |
46 |         allMatches.push({
47 |           title: displayTitle,
48 |           position: startPos,
49 |           length: matchLength
50 |         });
51 |
52 |         // Move past this match for the next iteration
53 |         offset += match.index + matchLength;
54 |         tempContent = content.substring(offset);
55 |       }
56 |     });
57 |
58 |     // Sort matches by their position in the content
59 |     allMatches.sort((a, b) => a.position - b.position);
60 |
61 |     // Extract content between section headings
62 |     for (let i = 0; i < allMatches.length; i++) {
63 |       const currentMatch = allMatches[i];
64 |       const startPos = currentMatch.position + currentMatch.length;
65 |
66 |       // Find the end position (either the next section or the end of content)
67 |       let endPos = content.length;
68 |       if (i < allMatches.length - 1) {
69 |         endPos = allMatches[i + 1].position;
70 |       }
71 |
72 |       // Extract the content
73 |       const sectionContent = content.substring(startPos, endPos).trim();
74 |
75 |       // Add to parsed sections if there's actual content
76 |       if (sectionContent) {
77 |         parsedSections.push({
78 |           title: currentMatch.title,
79 |           content: sectionContent
80 |         });
81 |       }
82 |     }
83 |
84 |     return parsedSections;
85 |   };
86 |
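  // Illustration (not part of the original source): given a content string like
  //   "**Problem:** Models overfit. **Impact:** High. **Feasibility:** Medium. **Novelty:** Low."
  // parseContent returns one entry per recognized heading, e.g.
  //   [ { title: 'Problem', content: 'Models overfit.' },
  //     { title: 'Impact', content: 'High.' },
  //     { title: 'Feasibility', content: 'Medium.' },
  //     { title: 'Novelty', content: 'Low.' } ]
  // where each entry holds the text between its heading and the next one.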
87 |   // Get all sections from the content
88 |   const sections = parseContent(content);
89 |
90 |   // Find the Problem section
91 |   const problemSection = sections.find(section => section.title === 'Problem');
92 |
93 |   // Filter down to the three main scored sections
94 |   const scoreSections = sections.filter(section =>
95 |     ['Impact', 'Feasibility', 'Novelty'].includes(section.title)
96 |   );
97 |
98 |   // Get the content of the currently active section
99 |   const activeContent = scoreSections.find(section => section.title === activeSection)?.content || '';
100 |
101 |   // Get the score (from the node or its previous state)
102 |   const getScore = (sectionName) => {
103 |     const scoreMap = {
104 |       'Impact': 'impactScore',
105 |       'Feasibility': 'feasibilityScore',
106 |       'Novelty': 'noveltyScore'
107 |     };
108 |
109 |     const scoreField = scoreMap[sectionName];
110 |     if (!scoreField) return null;
111 |
112 |     return hasPrevious && !showAfter
113 |       ? node.previousState[scoreField]
114 |       : node[scoreField];
115 |   };
116 |
117 |   // Colors for each section
118 |   const sectionColors = {
119 |     "Impact": "#4040a1",      // Blue
120 |     "Feasibility": "#50394c", // Purple
121 |     "Novelty": "#618685",     // Green
122 |   };
123 |
124 |   // Handle tab switching
125 |   const handleSectionChange = (newSection) => {
126 |     if (newSection === activeSection) return;
127 |
128 |     // Update the tab immediately so the color change is visible right away
129 |     setActiveSection(newSection);
130 |
131 |     // Set the transition state
132 |     setFadeState('visible');
133 |   };
134 |
135 |   return (
136 |
145 |       {/* Toggle button */}
146 |       {hasPrevious && (
147 |
164 |       )}
165 |
166 |       {/* Title */}
167 |

168 | {node.title || 'Untitled'} 169 |

170 |
171 |       {/* Problem section - shown directly below the title */}
172 |       {problemSection && (
173 |
181 | {problemSection.content} 182 |
183 |       )}
184 |
185 |       {/* Section selector tabs */}
186 |
191 | {scoreSections.map(section => { 192 | const isActive = activeSection === section.title; 193 | const score = getScore(section.title); 194 | const color = sectionColors[section.title]; 195 | 196 | return ( 197 |
handleSectionChange(section.title)} 200 | style={{ 201 | padding: '8px 12px', 202 | marginRight: '10px', 203 | cursor: 'pointer', 204 | position: 'relative', 205 | fontWeight: isActive ? 600 : 400, 206 | color: isActive ? color : '#6b7280', 207 | borderBottom: isActive ? `2px solid ${color}` : 'none', 208 | display: 'flex', 209 | alignItems: 'center', 210 | transition: 'color 0.1s ease, border-bottom 0.1s ease', 211 | opacity: activeSection === section.title ? 1 : 0.8, 212 | }} 213 | > 214 | 224 | {section.title} 225 | {score !== null && ( 226 | 234 | {Math.round(score)} 235 | 236 | )} 237 |
238 | ); 239 | })} 240 |
241 |
242 |       {/* Content of the selected section */}
243 |
254 | {activeContent} 255 |
256 |
257 |   );
258 | };
259 |
260 | export default HypothesisCard;
261 |
--------------------------------------------------------------------------------
/frontend/src/components/HypothesisFactorsAndScoresCard.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import FactorBlock from './FactorBlock';
3 |
4 | /**
5 |  * Card combining the three factor blocks
6 |  */
7 | const HypothesisFactorsAndScoresCard = ({ node, isEvaluating, showAfter }) => {
8 |   const hasPrev = !!node.previousState;
9 |
10 |   if (isEvaluating) {
11 |     return (
12 |

13 | Scores will be displayed once evaluation is complete. 14 |

15 |     );
16 |   }
17 |
18 |   const pick = (field) =>
19 |     showAfter || !hasPrev ? node[field] : node.previousState[field];
20 |
21 |   return (
22 |     <>
23 |
32 |
33 |
42 |
43 |
52 |
53 |     );
54 | };
55 |
56 | export default HypothesisFactorsAndScoresCard;
57 |
--------------------------------------------------------------------------------
/frontend/src/components/TopNav.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 |
3 | /**
4 |  * Top navigation bar: switches between the Exploration / Evaluation views.
5 |  * @param {string} currentView
6 |  * @param {Function} setCurrentView
7 |  */
8 | const TopNav = ({ currentView, setCurrentView }) => {
9 |   /* ---------- SVG icons ---------- */
10 |   const overviewIcon = (
11 |
19 |
26 |
27 |   );
28 |
29 |   const explorationIcon = (
30 |
38 |
39 |
45 |
52 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |   );
67 |
68 |   const evaluationIcon = (
69 |
77 |
84 |
85 |
86 |
87 |
88 |
89 |   );
90 |
91 |   /* ---------- Styles ---------- */
92 |   const tabStyle = {
93 |     display: 'flex',
94 |     alignItems: 'center',
95 |     padding: '10px 16px',
96 |     cursor: 'pointer',
97 |     borderRadius: '9999px',
98 |     transition: 'all 0.2s ease',
99 |     fontSize: '0.875rem',
100 |     fontWeight: 500,
101 |   };
102 |
103 |   const getActiveStyle = (viewName) =>
104 |     currentView === viewName
105 |       ? { backgroundColor: '#fff', color: '#141414' }
106 |       : { backgroundColor: 'transparent', color: '#A0AEC0' };
107 |
108 |   return (
109 |
117 | {/* Overview */} 118 |
setCurrentView('overview')} 121 | > 122 | {overviewIcon} 123 | Home View 124 |
125 | {/* Exploration View */} 126 |
setCurrentView('exploration')} 129 | > 130 | {explorationIcon} 131 | Exploration View 132 |
133 | 134 | {/* Evaluation View */} 135 |
setCurrentView('evaluation')} 138 | > 139 | {evaluationIcon} 140 | Evaluation View 141 |
142 |
143 | ); 144 | }; 145 | 146 | export default TopNav; 147 | -------------------------------------------------------------------------------- /frontend/src/images/evaluation.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /frontend/src/images/exploration.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /frontend/src/images/green.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /frontend/src/images/grey.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /frontend/src/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /frontend/src/images/red.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /frontend/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 12 | monospace; 13 | } 14 | -------------------------------------------------------------------------------- /frontend/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | import reportWebVitals from './reportWebVitals'; 6 | 7 | const root = ReactDOM.createRoot(document.getElementById('root')); 8 | root.render( 9 | 10 | 11 | 12 | ); 13 | 14 | // If you want to start measuring performance in your app, pass a function 15 | // to log results (for example: reportWebVitals(console.log)) 16 | // or send to an analytics endpoint. 
Learn more: https://bit.ly/CRA-vitals 17 | reportWebVitals(); 18 | -------------------------------------------------------------------------------- /frontend/src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const reportWebVitals = onPerfEntry => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 | -------------------------------------------------------------------------------- /frontend/src/setupTests.js: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import '@testing-library/jest-dom'; 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tiny-scientist" 3 | version = "0.0.4" 4 | description = "A lightweight framework for building research agents" 5 | authors = ["Haofei Yu "] 6 | license = "Apache 2.0 License" 7 | readme = "README.md" 8 | 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.10, <3.12" 12 | mypy = "^1.8.0" 13 | beartype = "*" 14 | pydantic = "^2.8.2" 15 | requests = "^2.28.0" 16 | pyyaml = "^6.0" 17 | backoff = "*" 18 | pyalex = "*" 19 | pymupdf = "^1.22.3" 20 | pymupdf4llm = "*" 21 | pypdf = "^5.3.1" 22 | anthropic = "*" 23 | google-generativeai = "*" 24 | openai = "*" 25 | aider-chat = "0.83.1" 26 | toml = "*" 27 | spacy = "^3.0.0" 28 | reportlab = "*" 29 | litellm = "*" 30 | rich = "*" 31 | cairosvg = "^2.7.1" 32 | together = "*" 33 | flask = "^3.0.0" 34 | flask-cors = "^4.0.0" 35 | 36 | [tool.poetry.group.dev.dependencies] 37 | pre-commit = "*" 38 | nbmake = "*" 39 | types-setuptools = "*" 40 | types-pyyaml = "^6.0.12.20250402" 41 | types-requests = "^2.31" 42 | types-toml = "^0.10" 43 | 44 | [tool.poetry.group.test.dependencies] 45 | pytest = "*" 46 | pytest-asyncio = "*" 47 | 48 | 49 | [build-system] 50 | requires = ["poetry-core"] 51 | build-backend = "poetry.core.masonry.api" 52 | 53 | [tool.mypy] 54 | ignore_missing_imports = true 55 | check_untyped_defs = true 56 | follow_imports = "normal" 57 | strict = true 58 | plugins = ["pydantic.mypy"] 59 | 60 | [tool.pytest.ini_options] 61 | testpaths = ["tests"] 62 | python_files = "test_*.py" 63 | 64 | [tool.codespell] 65 | ignore-words-list = "dout, te, indicies, astroid" 66 | skip = ["data"] 67 | 68 | [tool.isort] 69 | profile = "black" 70 | use_parentheses = true 71 | skip_gitignore = true 72 | multi_line_output = 3 73 | include_trailing_comma = true 74 | force_grid_wrap = 0 75 | line_length = 88 76 | 77 | [tool.black] 78 | line-length = 88 79 | target-version = ['py37', 'py38', 'py39', 'py310'] 80 | 81 | [tool.ruff] 82 | line-length = 88 83 | fix = true 84 | target-version = "py310" 85 | 86 | [tool.ruff.format] 87 | indent-style = "space" 88 | -------------------------------------------------------------------------------- /scripts/code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 
| import os 4 | from typing import Any, Dict 5 | 6 | # Import the Coder class - assuming it's in a module called "coder" 7 | # You may need to adjust this import based on your actual project structure 8 | from tiny_scientist.coder import Coder 9 | 10 | 11 | def create_sample_idea() -> Dict[str, Any]: 12 | """Create a sample experiment idea for testing.""" 13 | return { 14 | "Title": "Learning Rate Impact on Model Convergence", 15 | "Experiment": "Investigate how different learning rates affect the convergence speed and final performance of a simple neural network on MNIST dataset.", 16 | } 17 | 18 | 19 | def create_baseline_results() -> Dict[str, Any]: 20 | """Create sample baseline results for comparison.""" 21 | return { 22 | "accuracy": {"means": 0.92, "std": 0.015}, 23 | "training_time": {"means": 125.3, "std": 12.7}, 24 | "convergence_epoch": {"means": 8.5, "std": 1.2}, 25 | } 26 | 27 | 28 | def setup_experiment_directory(output_dir: str) -> None: 29 | """Set up the experiment directory with necessary files.""" 30 | os.makedirs(output_dir, exist_ok=True) 31 | 32 | # Create an empty experiment.py file 33 | with open(os.path.join(output_dir, "experiment.py"), "w") as f: 34 | f.write( 35 | """ 36 | # This is a placeholder experiment file that will be modified by the Coder 37 | import argparse 38 | import json 39 | import os 40 | 41 | def main() -> int: 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--out_dir", type=str, required=True) 44 | args = parser.parse_args() 45 | 46 | # Create output directory 47 | os.makedirs(os.path.join(os.path.dirname(__file__), args.out_dir), exist_ok=True) 48 | 49 | # Just return dummy results for testing 50 | results = { 51 | "accuracy": {"means": 0.94, "std": 0.01}, 52 | "training_time": {"means": 120.5, "std": 10.2}, 53 | "convergence_epoch": {"means": 7.8, "std": 0.9} 54 | } 55 | 56 | # Save results 57 | with open(os.path.join(os.path.dirname(__file__), args.out_dir, "final_info.json"), "w") as f: 58 | json.dump(results, f, indent=2) 59 | 60 | return 0 61 | 62 | if __name__ == "__main__": 63 | exit(main()) 64 | """ 65 | ) 66 | 67 | # Create an empty notes.txt file 68 | with open(os.path.join(output_dir, "notes.txt"), "w") as f: 69 | f.write( 70 | "# Experiment Notes\n\nThis file will contain notes about the experiment.\n" 71 | ) 72 | 73 | 74 | def main() -> None: 75 | parser = argparse.ArgumentParser(description="Run a trial of the Coder class") 76 | parser.add_argument( 77 | "--output_dir", 78 | type=str, 79 | default="./experiment_trial", 80 | help="Base directory for the experiment", 81 | ) 82 | parser.add_argument( 83 | "--model", 84 | type=str, 85 | default="gpt-4o", 86 | help="Model to use (e.g., llama3.1-405b, deepseek-coder-v2-0724)", 87 | ) 88 | parser.add_argument( 89 | "--max_iters", type=int, default=2, help="Maximum iterations per experiment" 90 | ) 91 | parser.add_argument( 92 | "--max_runs", type=int, default=2, help="Maximum experiment runs" 93 | ) 94 | parser.add_argument( 95 | "--prompt_template_dir", type=str, default="./configs", help="Config directory" 96 | ) 97 | args = parser.parse_args() 98 | 99 | # Set up the experiment directory 100 | setup_experiment_directory(args.output_dir) 101 | 102 | print(f"Setting up Coder with model: {args.model}") 103 | print(f"Base directory: {args.output_dir}") 104 | print(f"Max iterations: {args.max_iters}") 105 | print(f"Max runs: {args.max_runs}") 106 | 107 | # Create the Coder instance 108 | coder = Coder( 109 | output_dir=args.output_dir, 110 | model=args.model, 111 | 
max_iters=args.max_iters,
112 |         max_runs=args.max_runs,
113 |         prompt_template_dir=args.prompt_template_dir,
114 |     )
115 |
116 |     # Create a sample idea and baseline results
117 |     idea = create_sample_idea()
118 |     baseline_results = create_baseline_results()
119 |
120 |     print("\nStarting experiment...")
121 |     print(f"Idea: {idea['Title']}")
122 |
123 |     # Run the experiment (Coder.run returns a (success, output_dir) tuple)
124 |     success, _ = coder.run(idea, baseline_results)
125 |
126 |     if success:
127 |         print("\nExperiment completed successfully!")
128 |         print(f"Results and plots can be found in: {args.output_dir}")
129 |     else:
130 |         print("\nExperiment did not complete successfully.")
131 |         print("Check the logs for more information.")
132 |
133 |
134 | if __name__ == "__main__":
135 |     main()
136 |
--------------------------------------------------------------------------------
/scripts/code.sh:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 |
3 | python code.py --model gpt-4o --output_dir ../experiments
4 |
--------------------------------------------------------------------------------
/scripts/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from tiny_scientist.scientist import TinyScientist
6 |
7 |
8 | def main() -> None:
9 |     parser = argparse.ArgumentParser(description="Run TinyScientist pipeline.")
10 |     parser.add_argument(
11 |         "--output_dir",
12 |         type=str,
13 |         default="experiments/demo",
14 |         help="Base output directory",
15 |     )
16 |     parser.add_argument(
17 |         "--prompt_template_dir",
18 |         type=str,
19 |         default=None,
20 |         help="Configuration directory with prompt YAML files",
21 |     )
22 |     parser.add_argument("--model", type=str, default="gpt-4o", help="LLM model to use")
23 |     parser.add_argument(
24 |         "--template",
25 |         type=str,
26 |         default="acl",
27 |         help="Paper format template (e.g. acl, iclr)",
28 |     )
29 |     args = parser.parse_args()
30 |
31 |     if os.path.exists(args.output_dir):
32 |         import shutil
33 |
34 |         shutil.rmtree(args.output_dir)
35 |         print(f"🧹 Cleared existing directory: {args.output_dir}")
36 |     os.makedirs(args.output_dir, exist_ok=True)
37 |
38 |     # Construct experiment intent and baseline result
39 |     baseline_results = {
40 |         "experiment_name": "baseline_quadratic_optimization",
41 |         "function": "f(x, y) = x^2 + y^2",
42 |         "optimizer": "Gradient Descent",
43 |         "step_size": 0.1,
44 |         "iterations": 100,
45 |         "metrics": {"final_function_value": 0.001, "steps_to_convergence": 85},
46 |         "notes": "This baseline uses fixed step-size gradient descent on a quadratic bowl. 
Adaptive step-size methods aim to converge faster.", 47 | } 48 | 49 | with open(os.path.join(args.output_dir, "baseline_results.txt"), "w") as f: 50 | json.dump(baseline_results, f, indent=2) 51 | 52 | # Instantiate TinyScientist and run pipeline 53 | scientist = TinyScientist( 54 | model=args.model, 55 | output_dir=args.output_dir, 56 | prompt_template_dir=args.prompt_template_dir, 57 | template=args.template, 58 | ) 59 | 60 | idea = scientist.think( 61 | intent="Evaluating Adaptive Step Sizes in Numerical Optimization" 62 | ) 63 | 64 | if isinstance(idea, list): 65 | idea = idea[0] 66 | 67 | status, experiment_dir = scientist.code( 68 | idea=idea, baseline_results=baseline_results 69 | ) 70 | if status is False: 71 | return 72 | pdf_path = scientist.write(idea=idea, experiment_dir=experiment_dir) 73 | scientist.review(pdf_path=pdf_path) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /scripts/demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | python demo.py --output_dir ../demo --model gpt-4o-2024-08-06 --template acl 4 | -------------------------------------------------------------------------------- /scripts/demo_deepseek.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | python demo.py --output_dir ../demo --prompt_template_dir ../configs --model deepseek-chat --template acl 4 | -------------------------------------------------------------------------------- /scripts/drawer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import os 5 | 6 | from tiny_scientist.tool import DrawerTool 7 | 8 | 9 | def parse_args() -> argparse.Namespace: 10 | parser = argparse.ArgumentParser( 11 | description="Generate diagrams from text using an LLM." 
12 | ) 13 | 14 | parser.add_argument( 15 | "--text", 16 | type=str, 17 | help="Text content to generate a diagram from", 18 | ) 19 | parser.add_argument( 20 | "--input-file", 21 | type=str, 22 | help="Path to a text file to generate a diagram from", 23 | ) 24 | parser.add_argument( 25 | "--model", 26 | type=str, 27 | default="claude-3-5-sonnet-20241022", 28 | help="LLM model to use (default: claude-3-5-sonnet-20241022)", 29 | ) 30 | parser.add_argument( 31 | "--output", 32 | type=str, 33 | default="diagram_output.json", 34 | help="Path to save the generated diagram as JSON", 35 | ) 36 | parser.add_argument( 37 | "--temperature", 38 | type=float, 39 | default=0.75, 40 | help="Temperature for LLM generation (default: 0.75)", 41 | ) 42 | parser.add_argument( 43 | "--example", 44 | type=str, 45 | help="Path to a file containing an example diagram for few-shot learning", 46 | ) 47 | return parser.parse_args() 48 | 49 | 50 | def main() -> int: 51 | args: argparse.Namespace = parse_args() 52 | 53 | if not args.text and not args.input_file: 54 | print("Error: Either --text or --input-file must be provided") 55 | return 1 56 | 57 | try: 58 | # Get text content from file or command line argument 59 | if args.input_file: 60 | with open(args.input_file, "r") as f: 61 | text = f.read() 62 | else: 63 | text = args.text 64 | 65 | # Get example if provided 66 | example = None 67 | if args.example and os.path.exists(args.example): 68 | with open(args.example, "r") as f: 69 | example = f.read() 70 | 71 | # Get prompt templates directory 72 | current_dir = os.path.dirname(os.path.realpath(__file__)) 73 | prompt_template_dir = os.path.join( 74 | os.path.dirname(current_dir), "tiny_scientist", "prompts" 75 | ) 76 | 77 | # Initialize DrawerTool 78 | drawer = DrawerTool( 79 | model=args.model, 80 | prompt_template_dir=prompt_template_dir, 81 | temperature=args.temperature, 82 | ) 83 | 84 | print(f"Generating diagram using {args.model}...") 85 | diagram = drawer.draw_diagram(text=text, example=example) 86 | 87 | if not diagram or not diagram.get("svg"): 88 | print("Failed to generate diagram.") 89 | return 1 90 | 91 | # Display summary 92 | if diagram.get("summary"): 93 | print("\nDiagram Summary:") 94 | print(diagram["summary"]) 95 | 96 | # Save results 97 | output_data = { 98 | "summary": diagram.get("summary", ""), 99 | "svg": diagram.get("svg", ""), 100 | } 101 | 102 | with open(args.output, "w") as f: 103 | json.dump(output_data, f, indent=4) 104 | print(f"\nDiagram saved to {args.output}") 105 | 106 | # Also save SVG directly if available 107 | if diagram.get("svg"): 108 | svg_path = args.output.replace(".json", ".svg") 109 | with open(svg_path, "w") as f: 110 | f.write(diagram["svg"]) 111 | print(f"SVG file saved to {svg_path}") 112 | 113 | except Exception as e: 114 | print(f"Error: {e}") 115 | return 1 116 | 117 | return 0 118 | 119 | 120 | if __name__ == "__main__": 121 | exit(main()) 122 | -------------------------------------------------------------------------------- /scripts/drawer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define default parameters 4 | MODEL="gpt-4o-2024-08-06" 5 | OUTPUT="diagram_output.json" 6 | TEMPERATURE=0.75 7 | 8 | # Check if input text or file is provided 9 | if [ "$1" == "" ]; then 10 | echo "Usage: ./drawer.sh " 11 | echo "Example: ./drawer.sh \"Design a system for text to image generation using LLMs\"" 12 | echo "Example: ./drawer.sh --input-file path/to/text_file.txt" 13 | exit 1 14 | fi 15 | 16 | # 
Determine if the input is a file or text
17 | if [[ "$1" == "--input-file" ]]; then
18 |     # Run the DrawerTool script with an input file
19 |     python3 drawer.py --input-file "$2" --model "$MODEL" --output "$OUTPUT" --temperature "$TEMPERATURE"
20 | else
21 |     # Run the DrawerTool script with text input
22 |     python3 drawer.py --text "$*" --model "$MODEL" --output "$OUTPUT" --temperature "$TEMPERATURE"
23 | fi
24 |
--------------------------------------------------------------------------------
/scripts/review.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 | import os.path as osp
5 | from typing import List
6 |
7 | from tiny_scientist.reviewer import Reviewer
8 | from tiny_scientist.tool import BaseTool
9 | from tiny_scientist.utils.llm import AVAILABLE_LLMS
10 |
11 |
12 | def parse_args() -> argparse.Namespace:
13 |     parser = argparse.ArgumentParser(
14 |         description="Perform a paper review using the specified model."
15 |     )
16 |     parser.add_argument(
17 |         "--paper",
18 |         type=str,
19 |         default="../example/attention.pdf",
20 |         help="Path to the paper text/PDF to be reviewed, or raw text directly.",
21 |     )
22 |     parser.add_argument(
23 |         "--model",
24 |         type=str,
25 |         default="gpt-4o",
26 |         choices=AVAILABLE_LLMS,
27 |         help="Model to use for reviewing.",
28 |     )
29 |     parser.add_argument(
30 |         "--reviews-num",
31 |         type=int,
32 |         default=3,
33 |         help="Number of independent reviews to generate (default: 3).",
34 |     )
35 |     parser.add_argument(
36 |         "--reflection-num",
37 |         type=int,
38 |         default=2,
39 |         help="Number of re_review (reflection) iterations per review (default: 2).",
40 |     )
41 |     parser.add_argument(
42 |         "--temperature", type=float, default=0.75, help="Temperature for the LLM."
43 |     )
44 |     parser.add_argument(
45 |         "--output",
46 |         type=str,
47 |         default="review.json",
48 |         help="Path to save the final review JSON.",
49 |     )
50 |     parser.add_argument(
51 |         "--config-dir",
52 |         type=str,
53 |         default="../configs",
54 |         help="Path to directory containing model configurations.",
55 |     )
56 |     return parser.parse_args()
57 |
58 |
59 | def main() -> int:
60 |     args = parse_args()
61 |
62 |     dummy_tools: List[BaseTool] = []
63 |     reviewer = Reviewer(
64 |         tools=dummy_tools,
65 |         num_reviews=args.reviews_num,
66 |         num_reflections=args.reflection_num,
67 |         model=args.model,
68 |         temperature=args.temperature,
69 |         prompt_template_dir=args.config_dir,
70 |     )
71 |
72 |     final_review = reviewer.run(args.paper)
73 |
74 |     # Print and save the final meta-review.
75 | print("\nFinal Review JSON:") 76 | print(json.dumps(final_review, indent=4)) 77 | 78 | output_path = osp.abspath(args.output) 79 | with open(output_path, "w", encoding="utf-8") as f: 80 | json.dump(final_review, f, indent=4) 81 | print(f"\nReview saved to {output_path}") 82 | 83 | return 0 84 | 85 | 86 | if __name__ == "__main__": 87 | exit(main()) 88 | -------------------------------------------------------------------------------- /scripts/review.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | python review.py --model gpt-4o-2024-08-06 4 | -------------------------------------------------------------------------------- /scripts/search_code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | 5 | from tiny_scientist.tool import CodeSearchTool 6 | 7 | 8 | def parse_args() -> argparse.Namespace: 9 | parser = argparse.ArgumentParser( 10 | description="Search for GitHub repositories or code snippets." 11 | ) 12 | 13 | parser.add_argument( 14 | "--query", 15 | type=str, 16 | required=True, 17 | help="Search query for GitHub repositories or code", 18 | ) 19 | parser.add_argument( 20 | "--result-limit", 21 | type=int, 22 | default=10, 23 | help="Number of results to retrieve (default: 10)", 24 | ) 25 | parser.add_argument( 26 | "--search-type", 27 | type=str, 28 | choices=["repositories", "code"], 29 | default="repositories", 30 | help="Type of GitHub search: repositories or code", 31 | ) 32 | parser.add_argument( 33 | "--output", type=str, help="Path to save retrieved search results as JSON" 34 | ) 35 | return parser.parse_args() 36 | 37 | 38 | def main() -> int: 39 | args: argparse.Namespace = parse_args() 40 | 41 | try: 42 | # Initialize CodeSearchTool instance 43 | searcher = CodeSearchTool() 44 | print(f"Searching for {args.search_type} on GitHub...") 45 | 46 | results = searcher.run(query=args.query, search_type=args.search_type) 47 | 48 | if not results: 49 | print("No results found.") 50 | return 1 51 | 52 | # Display results 53 | print(json.dumps(results, indent=4)) 54 | 55 | # Save results if output path is provided 56 | if args.output: 57 | with open(args.output, "w") as f: 58 | json.dump(results, f, indent=4) 59 | print(f"Results saved to {args.output}") 60 | 61 | except Exception as e: 62 | print(f"Error: {e}") 63 | return 1 64 | 65 | return 0 66 | 67 | 68 | if __name__ == "__main__": 69 | exit(main()) 70 | -------------------------------------------------------------------------------- /scripts/search_code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define default parameters 4 | QUERY="machine learning" 5 | RESULT_LIMIT=10 6 | SEARCH_TYPE="repositories" # Change to "code" for searching code snippets 7 | OUTPUT="github_results.json" 8 | 9 | # Run the CodeSearchTool script 10 | python3 search_code.py --query "$QUERY" --result-limit $RESULT_LIMIT --search-type "$SEARCH_TYPE" --output "$OUTPUT" 11 | -------------------------------------------------------------------------------- /scripts/search_paper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | 5 | from tiny_scientist.tool import PaperSearchTool 6 | 7 | 8 | def parse_args() -> argparse.Namespace: 9 | parser = argparse.ArgumentParser(description="Search for academic papers.") 10 | 11 | 
parser.add_argument(
12 |         "--query",
13 |         type=str,
14 |         required=True,
15 |         help="Search query for retrieving academic papers",
16 |     )
17 |     parser.add_argument(
18 |         "--result-limit",
19 |         type=int,
20 |         default=10,
21 |         help="Number of results to retrieve (default: 10)",
22 |     )
23 |     parser.add_argument(
24 |         "--engine",
25 |         type=str,
26 |         choices=["semanticscholar", "openalex"],
27 |         default="semanticscholar",
28 |         help="Search engine for retrieving papers",
29 |     )
30 |     parser.add_argument(
31 |         "--output", type=str, help="Path to save retrieved papers as JSON"
32 |     )
33 |     return parser.parse_args()
34 |
35 |
36 | def main() -> int:
37 |     args: argparse.Namespace = parse_args()
38 |
39 |     try:
40 |         # Initialize PaperSearchTool instance
41 |         searcher = PaperSearchTool()
42 |         print(f"Searching for papers using {args.engine} engine...")
43 |
44 |         results = searcher.run(query=args.query)
45 |
46 |         if not results:
47 |             print("No papers found.")
48 |             return 1
49 |
50 |         # Display results
51 |         print(json.dumps(results, indent=4))
52 |
53 |         # Save results if output path is provided
54 |         if args.output:
55 |             with open(args.output, "w") as f:
56 |                 json.dump(results, f, indent=4)
57 |             print(f"Results saved to {args.output}")
58 |
59 |     except Exception as e:
60 |         print(f"Error: {e}")
61 |         return 1
62 |
63 |     return 0
64 |
65 |
66 | if __name__ == "__main__":
67 |     exit(main())
68 |
--------------------------------------------------------------------------------
/scripts/search_paper.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define default parameters
4 | QUERY="deep learning in healthcare"
5 | RESULT_LIMIT=10
6 | ENGINE="semanticscholar"
7 | OUTPUT="papers.json"
8 |
9 | # Run the PaperSearchTool script
10 | python3 search_paper.py --query "$QUERY" --result-limit $RESULT_LIMIT --engine "$ENGINE" --output "$OUTPUT"
11 |
--------------------------------------------------------------------------------
/scripts/think.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 | import os
5 | from typing import Any, Dict, cast
6 |
7 | from tiny_scientist.thinker import Thinker
8 | from tiny_scientist.utils.input_formatter import InputFormatter
9 | from tiny_scientist.utils.llm import AVAILABLE_LLMS
10 |
11 |
12 | def parse_args() -> argparse.Namespace:
13 |     parser = argparse.ArgumentParser(description="Generate and evaluate research ideas")
14 |     parser.add_argument(
15 |         "--base-dir", type=str, default="../ideas", help="Path to base directory"
16 |     )
17 |     parser.add_argument(
18 |         "--model",
19 |         type=str,
20 |         default="gpt-4o",
21 |         choices=AVAILABLE_LLMS,
22 |         help="Model to use for generating ideas",
23 |     )
24 |     parser.add_argument(
25 |         "--load-existing", action="store_true", help="Load existing ideas from file"
26 |     )
27 |     parser.add_argument(
28 |         "--num-ideas", type=int, default=1, help="Number of new ideas to generate"
29 |     )
30 |     parser.add_argument(
31 |         "--num-reflections",
32 |         type=int,
33 |         default=5,
34 |         help="Number of reflection iterations per idea",
35 |     )
36 |     parser.add_argument(
37 |         "--check-novelty", action="store_true", help="Check novelty of generated ideas"
38 |     )
39 |     parser.add_argument(
40 |         "--engine",
41 |         type=str,
42 |         choices=["semanticscholar", "openalex"],
43 |         default="semanticscholar",
44 |         help="Search engine for checking novelty",
45 |     )
46 |     parser.add_argument(
47 |         "--temperature",
48 |         type=float,
49 |         default=0.75,
50 |         help="Temperature for idea generation",
51 |     )
52 |     parser.add_argument(
53 |         "--output",
54 |         type=str,
55 |         help="Path to save ideas JSON (defaults to ideas.json in experiment directory)",
56 |     )
57 |     parser.add_argument(
58 |         "--config-dir",
59 |         type=str,
60 |         default="../configs",
61 |         help="Path to directory containing model configurations",
62 |     )
63 |     parser.add_argument(
64 |         "--initial-idea", type=str, help="Path to JSON file containing initial idea(s)"
65 |     )
66 |     parser.add_argument(
67 |         "--pdf", type=str, help="Path to the PDF paper for idea generation"
68 |     )
69 |     return parser.parse_args()
70 |
71 |
72 | def load_initial_idea(filepath: str) -> Dict[str, Any]:
73 |     """Load initial idea from a JSON file."""
74 |     try:
75 |         with open(filepath, "r") as f:
76 |             idea = json.load(f)
77 |         print(f"Loaded initial idea from {filepath}")
78 |         return cast(Dict[str, Any], idea)
79 |     except (FileNotFoundError, json.JSONDecodeError) as e:
80 |         print(f"Error loading initial idea: {e}")
81 |         raise ValueError("Valid initial idea must be provided")
82 |
83 |
84 | def create_default_idea() -> Dict[str, Any]:
85 |     """Create a default initial idea."""
86 |     default_idea = {
87 |         "Name": "baseline",
88 |         "Title": "Baseline Implementation",
89 |         "Experiment": "Implement baseline model with standard parameters",
90 |         "Interestingness": 5,
91 |         "Feasibility": 9,
92 |         "Novelty": 3,
93 |         "Score": 6,
94 |     }
95 |     return default_idea
96 |
97 |
98 | def main() -> int:
99 |     args = parse_args()
100 |     formatter = InputFormatter()
101 |
102 |     pdf_content: str = ""
103 |     if args.pdf:
104 |         try:
105 |             pdf_dict = formatter.parse_paper_pdf_to_json(args.pdf)
106 |             pdf_content = json.dumps(pdf_dict)  # Convert to string immediately
107 |             print("Loaded PDF content for idea generation.")
108 |         except Exception as e:
109 |             print(f"Error loading PDF: {e}")
110 |
111 |     try:
112 |         thinker = Thinker(
113 |             model=args.model,
114 |             output_dir=args.base_dir,
115 |             prompt_template_dir=args.config_dir,
116 |             temperature=args.temperature,
117 |             iter_num=args.num_reflections,
118 |             tools=[],
119 |         )
120 |
121 |         # Get initial idea
122 |         if args.load_existing:
123 |             try:
124 |                 ideas_path = os.path.join(args.base_dir, "ideas.json")
125 |                 with open(ideas_path, "r") as f:
126 |                     loaded_ideas = json.load(f)
127 |                 if loaded_ideas:
128 |                     initial_idea = loaded_ideas[0]  # Take the first idea
129 |                     print(f"Loaded existing idea from {ideas_path}")
130 |                 else:
131 |                     print("No valid existing ideas found. Using default idea.")
132 |                     initial_idea = create_default_idea()
133 |             except (FileNotFoundError, json.JSONDecodeError):
134 |                 print("No valid existing ideas found. Using default idea.")
135 |                 initial_idea = create_default_idea()
136 |         elif args.initial_idea:
137 |             initial_idea = load_initial_idea(args.initial_idea)
138 |         else:
139 |             print("No initial idea provided. Using default idea.")
140 |             initial_idea = create_default_idea()
141 |
142 |         # Generate ideas and refine them by calling run()
143 |         final_result = thinker.run(
144 |             initial_idea,
145 |             num_ideas=args.num_ideas,
146 |             check_novelty=args.check_novelty,
147 |             pdf_content=pdf_content,  # Already a string
148 |         )
149 |
150 |         output_path = args.output or os.path.join(args.base_dir, "refined_ideas.json")
151 |         with open(output_path, "w") as f:
152 |             json.dump(final_result, f, indent=4)
153 |         print(f"\nRefined ideas saved to {output_path}")
154 |
155 |     except Exception as e:
156 |         print(f"Error: {e}")
157 |         return 1
158 |
159 |     return 0
160 |
161 |
162 | if __name__ == "__main__":
163 |     exit(main())
164 |
--------------------------------------------------------------------------------
/scripts/think.sh:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 |
3 | python think.py --model gpt-4o-2024-08-06 --num-ideas 2
4 |
--------------------------------------------------------------------------------
/scripts/write.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 | import os
5 |
6 | from tiny_scientist.writer import Writer
7 |
8 |
9 | def parse_args() -> argparse.Namespace:
10 |     parser = argparse.ArgumentParser(description="Write paper.")
11 |
12 |     parser.add_argument(
13 |         "--experiment",
14 |         type=str,
15 |         required=True,
16 |         help="Path to the experiment directory containing experiment details",
17 |     )
18 |     parser.add_argument(
19 |         "--model",
20 |         type=str,
21 |         default="gpt-4o",
22 |         help="Model to use for writing and refinement",
23 |     )
24 |     parser.add_argument(
25 |         "--num-cite-rounds",
26 |         type=int,
27 |         default=2,
28 |         help="Number of citation addition rounds",
29 |     )
30 |     parser.add_argument("--template", type=str, help="Template of the output paper")
31 |     parser.add_argument(
32 |         "--engine",
33 |         type=str,
34 |         choices=["semanticscholar", "openalex"],
35 |         default="semanticscholar",
36 |         help="Search engine for citation retrieval",
37 |     )
38 |     parser.add_argument(
39 |         "--output",
40 |         type=str,
41 |         help="Path to save final paper PDF (defaults to experiment directory)",
42 |     )
43 |
44 |     return parser.parse_args()
45 |
46 |
47 | def main() -> int:
48 |     args: argparse.Namespace = parse_args()
49 |
50 |     try:
51 |         # Initialize Writer
52 |         writer = Writer(
53 |             model=args.model,
54 |             output_dir=args.experiment,
55 |             template=args.template,
56 |         )
57 |
58 |         # Load the idea from idea.json inside args.experiment
59 |         with open(os.path.join(args.experiment, "idea.json"), "r") as f:
60 |             idea = json.load(f)
61 |
62 |         # Perform paper writing
63 |         print("\nStarting paper write-up...")
64 |         writer.run(
65 |             idea=idea,
66 |             experiment_dir=args.experiment,
67 |         )
68 |
69 |     except Exception as e:
70 |         print(f"Error: {e}")
71 |         return 1
72 |
73 |     return 0
74 |
75 |
76 | if __name__ == "__main__":
77 |     exit(main())
78 |
--------------------------------------------------------------------------------
/scripts/write.sh:
--------------------------------------------------------------------------------
1 | python write.py --template acl --experiment ../experiments
2 |
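The wrapper scripts above each drive a single stage of the pipeline. Chaining them end to end looks roughly like the following sketch; the artifact paths (notably the idea JSON and the reviewed PDF name) are assumptions about where each stage reads and writes, not a documented contract.

```bash
# Hypothetical end-to-end pass over the scripts in this directory;
# output paths are assumptions for illustration.
python think.py --model gpt-4o --num-ideas 1 --output ../experiments/idea.json
python code.py --model gpt-4o --output_dir ../experiments
python write.py --template acl --experiment ../experiments
python review.py --model gpt-4o --paper ../experiments/paper.pdf --output ../experiments/review.json
```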
7 | 8 | @pytest.fixture 9 | def mock_client() -> Mock: 10 | return Mock() 11 | 12 | 13 | @pytest.fixture 14 | def mock_model() -> str: 15 | return "gpt-4-test" 16 | 17 | 18 | @pytest.fixture 19 | def test_output_dir(tmp_path: Path) -> Path: 20 | # Create a subdirectory under tmp_path 21 | output_dir = tmp_path / "test_scientist" 22 | output_dir.mkdir() 23 | 24 | # Create required files using Path methods 25 | experiment_py = output_dir / "experiment.py" 26 | experiment_py.write_text("print('Test experiment')") 27 | 28 | prompt_json = output_dir / "prompt.json" 29 | prompt_json.write_text( 30 | json.dumps({"task_description": "Test task", "system": "Test system prompt"}) 31 | ) 32 | 33 | seed_ideas_json = output_dir / "seed_ideas.json" 34 | seed_ideas_json.write_text( 35 | json.dumps( 36 | [ 37 | { 38 | "Name": "test_idea", 39 | "Title": "Test Idea", 40 | "Experiment": "Test experiment description", 41 | } 42 | ] 43 | ) 44 | ) 45 | 46 | # Return a Path object, not a string 47 | return output_dir 48 | 49 | 50 | def test_mock() -> bool: 51 | return True 52 | -------------------------------------------------------------------------------- /tiny_scientist/__init__.py: -------------------------------------------------------------------------------- 1 | from .coder import Coder 2 | from .reviewer import Reviewer 3 | from .scientist import TinyScientist 4 | from .thinker import Thinker 5 | from .writer import Writer 6 | 7 | __all__ = ["Coder", "Reviewer", "Thinker", "Writer", "TinyScientist"] 8 | -------------------------------------------------------------------------------- /tiny_scientist/coder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | import shutil 4 | import subprocess 5 | import sys 6 | from subprocess import TimeoutExpired 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | from aider.coders import Coder as AiderCoder 10 | from aider.io import InputOutput 11 | from aider.models import Model 12 | from rich import print 13 | 14 | from .configs import Config 15 | from .utils.llm import create_client, get_response_from_llm 16 | 17 | 18 | class Coder: 19 | def __init__( 20 | self, 21 | model: str, 22 | output_dir: str, 23 | max_iters: int = 4, 24 | max_runs: int = 5, 25 | max_stderr_output: int = 1500, 26 | prompt_template_dir: Optional[str] = None, 27 | chat_history: Optional[str] = None, 28 | ): 29 | """Initialize the ExperimentCoder with configuration and Aider setup.""" 30 | self.client, self.model = create_client(model) 31 | self.output_dir = osp.abspath(output_dir) 32 | self.max_iters = max_iters 33 | self.max_runs = max_runs 34 | self.max_stderr_output = max_stderr_output 35 | self.config = Config() 36 | 37 | # Load prompts 38 | self.prompts = self.config.prompt_template.coder_prompt 39 | 40 | def setup_aider( 41 | self, model: str, fnames: List[str], chat_history: Optional[str] = None 42 | ) -> None: 43 | """Setup Aider coder with the specified model.""" 44 | io = InputOutput( 45 | yes=True, chat_history_file=chat_history or f"{self.output_dir}/aider.txt" 46 | ) 47 | 48 | if model == "deepseek-coder-v2-0724": 49 | main_model = Model("deepseek/deepseek-coder") 50 | elif model == "deepseek-chat": 51 | main_model = Model("deepseek/deepseek-chat") 52 | elif model == "llama3.1-405b": 53 | main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct") 54 | else: 55 | main_model = Model(model) 56 | 57 | self.coder = AiderCoder.create( 58 | main_model=main_model, 59 | fnames=fnames, # Will 
be set per operation 60 | io=io, 61 | stream=False, 62 | use_git=False, 63 | edit_format="diff", 64 | ) 65 | 66 | def run( 67 | self, idea: Dict[str, Any], baseline_results: Optional[Dict[str, Any]] = {} 68 | ) -> Tuple[bool, str]: 69 | fnames = [ 70 | osp.join(self.output_dir, "experiment.py"), 71 | osp.join(self.output_dir, "notes.txt"), 72 | ] 73 | 74 | self.setup_aider(self.model, fnames) 75 | 76 | # Run experiments 77 | success = self._run_experiment_loop(idea, baseline_results) 78 | 79 | if not success: 80 | # Even if failed, save an empty result file to avoid breaking writer 81 | save_path = osp.join(self.output_dir, "experiment_results.txt") 82 | with open(save_path, "w") as f: 83 | json.dump({}, f, indent=2) 84 | print( 85 | f"[System] No experiments succeeded, but wrote empty result to {save_path}" 86 | ) 87 | return False, self.output_dir 88 | 89 | self._update_notes() 90 | 91 | result_summary = {} 92 | for run_num in range(1, self.max_runs + 1): 93 | run_dir = osp.join(self.output_dir, f"run_{run_num}") 94 | result_path = osp.join(run_dir, "final_info.json") 95 | if osp.exists(result_path): 96 | with open(result_path, "r") as f: 97 | result_summary[f"run_{run_num}"] = json.load(f) 98 | 99 | # Save combined results 100 | save_path = osp.join(self.output_dir, "experiment_results.txt") 101 | with open(save_path, "w") as f: 102 | json.dump(result_summary, f, indent=2) 103 | 104 | print(f"[System] All experiment results saved to {save_path}") 105 | 106 | return True, self.output_dir 107 | 108 | def _format_experiment_for_prompt( 109 | self, exp: Dict[str, str] 110 | ) -> Tuple[str, str, str]: 111 | llm_prompt = self.prompts.experiment_keyword_prompt.format( 112 | model=exp["Model"], dataset=exp["Dataset"], metric=exp["Metric"] 113 | ) 114 | 115 | llm_output, _ = get_response_from_llm( 116 | msg=llm_prompt, 117 | client=self.client, 118 | model=self.model, 119 | system_message="You are helping an AI agent extract implementation-relevant key information from an experiment description.", 120 | ) 121 | 122 | try: 123 | # Clean and parse JSON block 124 | llm_output_clean = llm_output.strip().strip("`").strip("json").strip() 125 | keyword_info = json.loads(llm_output_clean) 126 | except json.JSONDecodeError: 127 | print("[System] Failed to parse LLM keyword JSON.") 128 | keyword_info = { 129 | "model": [], 130 | "dataset": [], 131 | "metric": [], 132 | } 133 | 134 | model_kw = ", ".join(keyword_info.get("model", [])) 135 | dataset_kw = ", ".join(keyword_info.get("dataset", [])) 136 | metric_kw = ", ".join(keyword_info.get("metric", [])) 137 | 138 | return model_kw, dataset_kw, metric_kw 139 | 140 | def _summarize_to_bullets(self, paragraph: str) -> str: 141 | # Simple sentence-splitting bullet conversion 142 | lines = paragraph.strip().split(". 
") 143 | return "\n".join(f"- {line.strip().rstrip('.')}" for line in lines if line) 144 | 145 | def _run_experiment_loop( 146 | self, idea: Dict[str, Any], baseline_results: Optional[Dict[str, Any]] = {} 147 | ) -> bool: 148 | """Run the experiment loop with multiple iterations if needed.""" 149 | current_iter = 0 150 | run_time = 1 151 | 152 | # Initial prompt 153 | model, dataset, metric = self._format_experiment_for_prompt(idea["Experiment"]) 154 | 155 | next_prompt = self.prompts.experiment_prompt.format( 156 | title=idea["Title"], 157 | problem=idea["Problem"], 158 | novelty=idea["NoveltyComparison"], 159 | approach=idea["Approach"], 160 | model=model, 161 | dataset=dataset, 162 | metric=metric, 163 | max_runs=self.max_runs, 164 | baseline_results=baseline_results, 165 | ) 166 | 167 | while run_time < self.max_runs + 1: 168 | if current_iter >= self.max_iters: 169 | print("Max iterations reached") 170 | return False 171 | 172 | coder_out = self.coder.run(next_prompt) 173 | exp_path = osp.join(self.output_dir, "experiment.py") 174 | 175 | if "ALL_COMPLETED" in coder_out: 176 | return True 177 | 178 | if osp.exists(exp_path): 179 | with open(exp_path) as f: 180 | content = f.read() 181 | if "..." in content: 182 | print("[System] Placeholder '...' detected. Attempting fix.") 183 | self.coder.run( 184 | "Please replace all placeholders (`...`) in experiment.py with complete runnable code." 185 | ) 186 | 187 | return_code, message = self._run_single_experiment(run_time) 188 | 189 | if return_code == 0: 190 | run_time += 1 191 | current_iter = 0 192 | next_prompt = message 193 | else: 194 | print("[System] Experiment run failed. Attempting fix with Aider...") 195 | next_prompt = self.prompts.experiment_error_prompt.format( 196 | message=message, 197 | Title=idea["Title"], 198 | Experiment=idea["Experiment"], 199 | run_time=run_time, 200 | max_runs=self.max_runs, 201 | ) 202 | 203 | current_iter += 1 204 | 205 | return current_iter < self.max_iters 206 | 207 | def _run_single_experiment( 208 | self, run_num: int, timeout: int = 7200 209 | ) -> Tuple[int, str]: 210 | """Run a single experiment iteration.""" 211 | 212 | shutil.copy( 213 | osp.join(self.output_dir, "experiment.py"), 214 | osp.join(self.output_dir, f"run_{run_num}.py"), 215 | ) 216 | 217 | # Run experiment 218 | command = ["python", "experiment.py", f"--out_dir=run_{run_num}"] 219 | 220 | try: 221 | result = subprocess.run( 222 | command, 223 | cwd=self.output_dir, 224 | stderr=subprocess.PIPE, 225 | text=True, 226 | timeout=timeout, 227 | ) 228 | 229 | if result.stderr: 230 | print(result.stderr, file=sys.stderr) 231 | 232 | if result.returncode != 0: 233 | print(f"Run {run_num} failed with return code {result.returncode}") 234 | if "ModuleNotFoundError" in result.stderr and getattr( 235 | self, "auto_install", True 236 | ): 237 | missing_pkg = self._extract_missing_package(result.stderr) 238 | print( 239 | f"[System] Missing package detected: {missing_pkg}. Attempting to install..." 240 | ) 241 | subprocess.run( 242 | [sys.executable, "-m", "pip", "install", missing_pkg] 243 | ) 244 | print("[System] Re-running after installing dependency...") 245 | return self._run_single_experiment(run_num, timeout=timeout) 246 | 247 | self._cleanup_failed_run(run_num) 248 | 249 | stderr_output = result.stderr 250 | if len(stderr_output) > self.max_stderr_output: 251 | stderr_output = "..." 
+ stderr_output[-self.max_stderr_output :] 252 | 253 | return 1, stderr_output 254 | 255 | # Load and format results 256 | with open( 257 | osp.join(self.output_dir, f"run_{run_num}", "final_info.json"), "r" 258 | ) as f: 259 | results = json.load(f) 260 | 261 | if isinstance(results, dict): 262 | results = { 263 | k: v["means"] if isinstance(v, dict) and "means" in v else v 264 | for k, v in results.items() 265 | } 266 | elif isinstance(results, list): 267 | results = {f"entry_{i+1}": entry for i, entry in enumerate(results)} 268 | 269 | results = { 270 | k: v["means"] if isinstance(v, dict) and "means" in v else v 271 | for k, v in results.items() 272 | } 273 | 274 | return 0, self.prompts.experiment_success_prompt.format( 275 | run_num=run_num, results=results, next_run=run_num + 1 276 | ) 277 | 278 | except TimeoutExpired: 279 | print(f"Run {run_num} timed out after {timeout} seconds") 280 | self._cleanup_failed_run(run_num) 281 | return 1, self.prompts.experiment_timeout_prompt.format(timeout=timeout) 282 | 283 | def _update_notes(self) -> None: 284 | """Update notes.txt with plot descriptions.""" 285 | # Set files for this operation 286 | self.coder.fnames = [osp.join(self.output_dir, "notes.txt")] 287 | self.coder.run(self.prompts.notes_prompt) 288 | 289 | def _cleanup_failed_run(self, run_num: int) -> None: 290 | """Clean up files from a failed run.""" 291 | run_dir = osp.join(self.output_dir, f"run_{run_num}") 292 | if osp.exists(run_dir): 293 | shutil.rmtree(run_dir) 294 | 295 | def _extract_missing_package(self, stderr: str) -> str: 296 | for line in stderr.splitlines(): 297 | if "ModuleNotFoundError" in line: 298 | parts = line.split("'") 299 | if len(parts) >= 2: 300 | return parts[1] 301 | return "unknown-package" 302 | -------------------------------------------------------------------------------- /tiny_scientist/configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Optional, Type, TypeVar 3 | 4 | import yaml 5 | from pydantic import BaseModel 6 | 7 | from .data import CoderPrompt, DrawerPrompt, ReviewerPrompt, ThinkerPrompt, WriterPrompt 8 | 9 | T = TypeVar("T", bound=BaseModel) 10 | 11 | 12 | class PromptTemplate(BaseModel): 13 | """Configuration for prompts.""" 14 | 15 | coder_prompt: CoderPrompt 16 | thinker_prompt: ThinkerPrompt 17 | reviewer_prompt: ReviewerPrompt 18 | writer_prompt: WriterPrompt 19 | drawer_prompt: DrawerPrompt 20 | 21 | 22 | class Config(BaseModel): 23 | prompt_template: PromptTemplate 24 | 25 | def __init__(self, prompt_path: Optional[str] = None, **kwargs: Any) -> None: 26 | if not prompt_path: 27 | prompt_path = self._default_config_path() 28 | 29 | yaml_data = {"prompt_template": self._load_from_yaml(prompt_path)} 30 | kwargs.update(yaml_data) 31 | super().__init__(**kwargs) 32 | 33 | def _default_config_path(self) -> str: 34 | this_dir = os.path.dirname(__file__) 35 | return os.path.abspath(os.path.join(this_dir, "./", "prompts")) 36 | 37 | def _load_from_yaml(self, prompt_path: str) -> PromptTemplate: 38 | return PromptTemplate( 39 | thinker_prompt=self._load_yaml_file( 40 | os.path.join(prompt_path, "thinker_prompt.yaml"), ThinkerPrompt 41 | ), 42 | coder_prompt=self._load_yaml_file( 43 | os.path.join(prompt_path, "coder_prompt.yaml"), CoderPrompt 44 | ), 45 | writer_prompt=self._load_yaml_file( 46 | os.path.join(prompt_path, "writer_prompt.yaml"), WriterPrompt 47 | ), 48 | reviewer_prompt=self._load_yaml_file( 49 | os.path.join(prompt_path, 
"reviewer_prompt.yaml"), 50 | ReviewerPrompt, 51 | ), 52 | drawer_prompt=self._load_yaml_file( 53 | os.path.join(prompt_path, "drawer_prompt.yaml"), 54 | DrawerPrompt, 55 | ), 56 | ) 57 | 58 | def _load_yaml_file(self, file_path: str, model_class: Type[T]) -> T: 59 | if not os.path.exists(file_path): 60 | raise FileNotFoundError(f"YAML file '{file_path}' does not exist.") 61 | with open(file_path, "r") as f: 62 | return model_class(**yaml.safe_load(f)) 63 | -------------------------------------------------------------------------------- /tiny_scientist/data.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class ReviewerPrompt(BaseModel): 7 | reviewer_system_prompt_base: str 8 | reviewer_system_prompt_neg: str 9 | reviewer_system_prompt_pos: str 10 | query_prompt: str 11 | template_instructions: str 12 | neurips_form: str 13 | meta_reviewer_system_prompt: str 14 | reviewer_reflection_prompt: str 15 | 16 | 17 | class WriterPrompt(BaseModel): 18 | write_system_prompt: str 19 | write_system_prompt_related_work: str 20 | section_tips: Dict[str, str] 21 | error_list: str 22 | refinement_prompt: str 23 | second_refinement_prompt: str 24 | citation_system_prompt: str 25 | abstract_prompt: str 26 | section_prompt: Dict[str, str] 27 | citation_related_work_prompt: str 28 | add_citation_prompt: str 29 | embed_citation_prompt: str 30 | related_work_prompt: str 31 | title_refinement_prompt: str 32 | citation_aider_format: str 33 | 34 | 35 | class CoderPrompt(BaseModel): 36 | experiment_keyword_prompt: str 37 | experiment_prompt: str 38 | experiment_success_prompt: str 39 | experiment_error_prompt: str 40 | experiment_timeout_prompt: str 41 | plot_initial_prompt: str 42 | plot_error_prompt: str 43 | plot_timeout_prompt: str 44 | notes_prompt: str 45 | 46 | 47 | class ThinkerPrompt(BaseModel): 48 | idea_system_prompt: str 49 | evaluation_system_prompt: str 50 | idea_evaluation_prompt: str 51 | modify_idea_prompt: str 52 | merge_ideas_prompt: str 53 | query_prompt: str 54 | rethink_query_prompt: str 55 | novelty_query_prompt: str 56 | novelty_system_prompt: str 57 | idea_first_prompt: str 58 | idea_reflection_prompt: str 59 | novelty_prompt: str 60 | experiment_plan_prompt: str 61 | 62 | 63 | class DrawerPrompt(BaseModel): 64 | diagram_system_prompt_base: str 65 | template_instructions: str 66 | few_shot_instructions: str 67 | error_list: str 68 | refinement_prompt: str 69 | -------------------------------------------------------------------------------- /tiny_scientist/prompts/coder_prompt.yaml: -------------------------------------------------------------------------------- 1 | experiment_keyword_prompt: | 2 | The experiment is organized into three sections: 3 | ## Model Section: 4 | {model} 5 | 6 | ## Dataset Section: 7 | {dataset} 8 | 9 | ## Metric Section: 10 | {metric} 11 | 12 | Your job is to extract the essential names of models, datasets, and evaluation metrics that are directly useful for coding and experimentation. 13 | 14 | ### Output Format: 15 | Return a JSON object with the following structure: 16 | ```json 17 | {{ 18 | "model": ["Model1", "Model2", ...], 19 | "dataset": ["Dataset1", "Dataset2", ...], 20 | "metric": ["Metric1", "Metric2", ...] 21 | }} 22 | 23 | experiment_prompt: | 24 | You are writing a Python script named `experiment.py` that must be runnable. 
25 | 
 26 | ## Research Context
 27 | Title: {title}
 28 | Problem: {problem}
 29 | Novelty: {novelty}
 30 | Proposed Approach: {approach}
 31 | 
 32 | ## Experimental Setup
 33 | The following describes the experiment setup. You must base your implementation strictly on this structure:
 34 | 
 35 | Models/Algorithms to use: {model}
 36 | Datasets involved: {dataset}
 37 | Evaluation metrics: {metric}
 38 | 
 39 | ## Execution Command (DO NOT MODIFY):
 40 | You have {max_runs} runs to complete this experiment. For each run, the script will be executed using:
 41 | `python experiment.py --out_dir=run_i`
 42 | where `i` is the run number (`run_1`, `run_2`, etc.).
 43 | 
 44 | ## YOU MUST ENSURE experiment.py:
 45 | 1. Parses the `--out_dir` argument.
 46 | 2. Creates the output directory using `os.makedirs(out_dir, exist_ok=True)`.
 47 | 3. Performs actual model training and evaluation — DO NOT simulate results with random numbers or hardcode experiment results; every result must come from actual execution.
 48 | 4. Implements evaluation metrics with real logic.
 49 | 5. **Saves results as a dictionary in a file named `final_info.json` placed directly inside `out_dir`** — do **not** save into nested folders like `out_dir/variant_name/final_info.json`.
 50 | 
 51 | ## Computational Constraints
 52 | - Ensure the code is computationally affordable to run on a single GPU or CPU machine.
 53 | - Avoid using large models like GPT, T5, BERT-large, or full ImageNet training.
 54 | - Prefer small-scale tasks, toy models, or low-cost benchmarks (e.g., MNIST, UCI datasets, small MLPs or ResNet18).
 55 | - Do not use complex distributed training or multi-GPU setups.
 56 | 
 57 | Do not add extra command-line arguments.
 58 | If your current experiment.py contains placeholders like `...`, replace them with runnable implementations.
 59 | If any external functions like `compute_loss`, `evaluate_model`, or `log_results` are used, implement them too.
 60 | 
 61 | ## Baseline Results
 62 | You do not need to re-run the baseline. If available, the results are provided below:
 63 | {baseline_results}
 64 | 
 65 | ---
 66 | Please begin writing the `experiment.py` file now.
 67 | 
 68 | experiment_success_prompt: |
 69 | Run {run_num} completed. Here are the results:
 70 | {results}
 71 | 
 72 | Decide if you need to re-plan your experiments given the results (you often will not need to).
 73 | 
 74 | Someone else will be using `notes.txt` to perform a writeup on this in the future.
 75 | Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
 76 | 
 77 | Then, implement the next thing on your list.
 78 | We will then run the command `python experiment.py --out_dir=run_{next_run}`.
 79 | YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
 80 | If you are finished with experiments, respond with 'ALL_COMPLETED'.
 81 | 
 82 | experiment_error_prompt: |
 83 | There was an error running the experiment script:
 84 | {message}
 85 | Your goal is still to implement this experiment: {Title}.
 86 | The purpose is: {Experiment}.
 87 | You have {max_runs} runs total. We're currently on run {run_time}.
 88 | Please fix `experiment.py` so that it runs successfully with:
 89 | `python experiment.py --out_dir=run_{run_time}`.
 90 | Make sure to implement any missing parts like model definition, loss function, data loading, and final_info.json saving.
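The experiment prompts above impose a strict I/O contract on the generated script: parse `--out_dir`, create the directory, perform real computation, and write a `final_info.json` dictionary directly inside `out_dir`. As a concrete reference, here is a minimal sketch of a conforming `experiment.py`; the toy task (closed-form linear regression on synthetic data) and the `mse` metric name are illustrative assumptions, not something the prompts or the repository prescribe:

```python
# Hypothetical minimal experiment.py satisfying the contract in
# experiment_prompt above. The task and metric names are illustrative only.
import argparse
import json
import os


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, required=True)
    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Deterministic toy dataset: y = 2x + 1 with a small alternating offset,
    # so the fit below is real computation rather than simulated results.
    xs = [float(i) for i in range(100)]
    ys = [2.0 * x + 1.0 + ((-1) ** i) * 0.1 for i, x in enumerate(xs)]

    # Closed-form least-squares fit (slope = cov(x, y) / var(x)).
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    intercept = mean_y - slope * mean_x
    mse = sum((y - (slope * x + intercept)) ** 2 for x, y in zip(xs, ys)) / n

    # Results must be a dictionary saved directly inside out_dir.
    with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
        json.dump({"slope": slope, "intercept": intercept, "mse": mse}, f, indent=2)


if __name__ == "__main__":
    main()
```

Each run is then invoked exactly as the prompts require, e.g. `python experiment.py --out_dir=run_1`, and `Coder._run_single_experiment` picks up the `final_info.json` that lands in `run_1/`.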
91 | 
 92 | 
 93 | experiment_timeout_prompt: |
 94 | Run timed out after {timeout} seconds
 95 | 
 96 | plot_initial_prompt: |
 97 | Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
 98 | 
 99 | In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
 100 | 
 101 | Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
 102 | 
 103 | We will be running the command `python plot.py` to generate the plots.
 104 | 
 105 | plot_error_prompt: |
 106 | Plotting failed with the following error {error}
 107 | 
 108 | plot_timeout_prompt: |
 109 | Plotting timed out after {timeout} seconds
 110 | 
 111 | notes_prompt: |
 112 | Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
 113 | 
 114 | Somebody else will be using `notes.txt` to write a report on this in the future.
 115 | 
-------------------------------------------------------------------------------- /tiny_scientist/prompts/diagram_prompt.yaml: --------------------------------------------------------------------------------
 1 | diagram_system_prompt_base: >
 2 | You are a master of creating demonstration diagrams for academic papers.
 3 | You first extract the components in the pipeline of 1) a specific architecture in the paper or 2) the whole paper, and then
 4 | draw a diagram in SVG format accordingly.
 5 | Your generated diagram should follow the components you extract.
 6 | Be careful to avoid overlap in the layout: different items should not overlap with each other.
 7 | You need to ensure that the layout is reasonable. If the diagram is too tight, separate the main flow diagram into several sub-diagrams with sub-titles.
 8 | template_instructions: |
 9 | Respond in the following format:
 10 | 
 11 | SUMMARY:
 12 | <SUMMARY>
 13 | 
 14 | DIAGRAM SVG:
 15 | ```svg
 16 | <SVG>
 17 | ```
 18 | 
 19 | In <SUMMARY>, first demonstrate the pipeline of {topic} by listing all related components mentioned in the paper.
 20 | You should especially focus on the inputs and outputs of each component and make sure they can be organized as flows.
 21 | 
 22 | In <SVG>, design the diagram in SVG format to visualize the above pipeline.
 23 | few_shot_instructions: |
 24 | Here is a layout example for reference. You need to first generate a description of the diagram layout in this example.
 25 | Then, you need to refer to the description when generating the diagram for the given paper.
 26 | """
 27 | {example}
 28 | """
 29 | 
-------------------------------------------------------------------------------- /tiny_scientist/prompts/drawer_prompt.yaml: --------------------------------------------------------------------------------
 1 | diagram_system_prompt_base: |
 2 | You are an expert in creating clear, professional diagrams for academic papers.
 3 | Your task is to generate an SVG diagram that visually represents the key concepts, methods, or workflows described in the text.
 4 | 
 5 | Guidelines:
 6 | 1. Create diagrams that are clean, professional, and suitable for academic publication
 7 | 2. Use standard diagramming conventions and symbols
 8 | 3. Ensure the diagram is self-contained and can be understood without additional context
 9 | 4. Include appropriate labels, arrows, and annotations
 10 | 5. Make the diagram scalable and maintainable
 11 | 6. Focus on clarity and readability over artistic complexity
 12 | 7.
Ensure the SVG code is compatible with LaTeX figure environment 13 | 8. Use LaTeX-compatible fonts and symbols in labels 14 | 9. Keep the diagram size reasonable for a single-column academic paper 15 | 16 | Your response must include: 17 | 1. A SUMMARY section describing the key elements of the diagram (this will be used as the figure caption) 18 | 2. A DIAGRAM SVG section containing the actual SVG code (this will be embedded in a LaTeX figure) 19 | 3. The SVG must be valid, self-contained, and LaTeX-compatible 20 | 21 | template_instructions: | 22 | Please create a professional academic diagram based on the provided text. 23 | The diagram should: 24 | - Clearly illustrate the key concepts and relationships 25 | - Use appropriate visual elements (boxes, arrows, etc.) 26 | - Include clear labels and annotations 27 | - Be scalable and maintainable 28 | - Follow academic diagramming conventions 29 | - Be compatible with LaTeX figure environment 30 | - Use LaTeX-compatible fonts and symbols 31 | 32 | Your response must include: 33 | 1. A SUMMARY section describing the key elements of the diagram (will be used as figure caption) 34 | 2. A DIAGRAM SVG section containing the actual SVG code (will be embedded in LaTeX) 35 | 36 | few_shot_instructions: | 37 | Here is an example of how to create a diagram: 38 | {example} 39 | 40 | Based on this example, please create a diagram for the following text. 41 | 42 | error_list: | 43 | - Missing or invalid SVG code 44 | - Unclear or missing labels 45 | - Inconsistent styling 46 | - Overly complex or cluttered diagrams 47 | - Missing key elements from the text 48 | - Poor scaling or layout 49 | - Unprofessional appearance 50 | - Missing summary section 51 | - Missing SVG section 52 | - Invalid SVG syntax 53 | - Non-LaTeX compatible SVG elements 54 | - Incompatible fonts or symbols 55 | - Oversized diagrams 56 | - Poor figure caption content 57 | 58 | refinement_prompt: | 59 | Please review and refine the diagram you just created. 60 | Focus on: 61 | - Adding modular structure (e.g., labeled stages for "Input", "Processing", "Output") 62 | - Including visual indicators for dynamic processes (e.g., dashed arrows for adaptive flows) 63 | - Better layout alignment (use horizontal or vertical symmetry) 64 | - Distinguishing components visually (via grouping boxes or layout space) 65 | - Fixing SVG issues: alignment, font size, spacing, and arrow direction 66 | 67 | Your output must include both a SUMMARY and a valid SVG diagram. 68 | Pay particular attention to fixing any errors such as: 69 | {error_list} 70 | 71 | Here is the current diagram: 72 | """ 73 | {diagram_content} 74 | """ 75 | 76 | Please provide an improved version that addresses any issues while maintaining the core message and structure. 77 | Ensure the diagram is properly formatted for LaTeX integration. 78 | -------------------------------------------------------------------------------- /tiny_scientist/prompts/reviewer_prompt.yaml: -------------------------------------------------------------------------------- 1 | reviewer_system_prompt_base: > 2 | You are an AI researcher who is reviewing a paper that was submitted to a prestigious ML venue. Be critical and cautious in your decision. 3 | reviewer_system_prompt_neg: > 4 | You are an AI researcher who is reviewing a paper that was submitted to a prestigious ML venue. Be critical and cautious in your decision. If a paper is bad or you are unsure, give it bad scores and reject it. 
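Stepping back to the two diagram prompt files above before continuing with the reviewer prompts: a conforming model response pairs a short SUMMARY (reused as the figure caption) with a fenced ```svg block. The following is a hypothetical minimal SVG of the kind those prompts ask for; the two-stage encoder/classifier pipeline, dimensions, and ids are invented for illustration:

```svg
<svg xmlns="http://www.w3.org/2000/svg" width="420" height="120" font-family="Helvetica">
  <defs>
    <!-- Arrowhead reused by the flow line below -->
    <marker id="arrow" markerWidth="8" markerHeight="8" refX="8" refY="4" orient="auto">
      <path d="M0,0 L8,4 L0,8 z"/>
    </marker>
  </defs>
  <!-- Stage 1: a labeled, non-overlapping box -->
  <rect x="10" y="40" width="110" height="40" fill="none" stroke="black"/>
  <text x="65" y="65" text-anchor="middle" font-size="12">Encoder</text>
  <!-- Stage 2 -->
  <rect x="300" y="40" width="110" height="40" fill="none" stroke="black"/>
  <text x="355" y="65" text-anchor="middle" font-size="12">Classifier</text>
  <!-- Flow from stage 1 to stage 2 -->
  <line x1="120" y1="60" x2="300" y2="60" stroke="black" marker-end="url(#arrow)"/>
</svg>
```

Note how it stays within the guidelines the prompts enforce: self-contained markup, plain fonts, no overlapping elements, and no XML declaration that `_clean_svg` in `tool.py` would have to strip.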
5 | reviewer_system_prompt_pos: >
 6 | You are an AI researcher who is reviewing a paper that was submitted to a prestigious ML venue. Be critical and cautious in your decision. If a paper is good or you are unsure, give it good scores and accept it.
 7 | query_prompt: |
 8 | Here is the paper text:
 9 | 
 10 | ```
 11 | {paper_text}
 12 | ```
 13 | 
 14 | Generate a concise search query (e.g., "attention is all you need") that captures the main topics and any identified weaknesses of the paper.
 15 | This query will be used to retrieve additional literature to inform your review.
 16 | 
 17 | 
 18 | Respond in the following format:
 19 | 
 20 | RESPONSE:
 21 | ```json
 22 | <JSON>
 23 | ```
 24 | 
 25 | In <JSON>, respond in JSON format with ONLY the following field:
 26 | - "Query": The query you just generated
 27 | template_instructions: |
 28 | Respond in the following format:
 29 | 
 30 | THOUGHT:
 31 | <THOUGHT>
 32 | 
 33 | REVIEW JSON:
 34 | ```json
 35 | <JSON>
 36 | ```
 37 | 
 38 | In <THOUGHT>, first briefly discuss your intuitions and reasoning for the evaluation.
 39 | Detail your high-level arguments, necessary choices and desired outcomes of the review.
 40 | 
 41 | Before writing your review, please consider the following related works: {related_works_string}
 42 | 
 43 | Do not make generic comments here, but be specific to your current paper.
 44 | Treat this as the note-taking phase of your review.
 45 | 
 46 | In <JSON>, provide the review in JSON format with the following fields, in this order:
 47 | - "Summary": A summary of the paper content and its contributions.
 48 | - "Strengths": A list of strengths of the paper.
 49 | - "Weaknesses": A list of weaknesses of the paper.
 50 | - "Originality": A rating from 1 to 4 (low, medium, high, very high).
 51 | - "Quality": A rating from 1 to 4 (low, medium, high, very high).
 52 | - "Clarity": A rating from 1 to 4 (low, medium, high, very high).
 53 | - "Significance": A rating from 1 to 4 (low, medium, high, very high).
 54 | - "Questions": A set of clarifying questions to be answered by the paper authors.
 55 | - "Limitations": A set of limitations and potential negative societal impacts of the work.
 56 | - "Ethical Concerns": A boolean value indicating whether there are ethical concerns.
 57 | - "Soundness": A rating from 1 to 4 (poor, fair, good, excellent).
 58 | - "Presentation": A rating from 1 to 4 (poor, fair, good, excellent).
 59 | - "Contribution": A rating from 1 to 4 (poor, fair, good, excellent).
 60 | - "Overall": A rating from 1 to 10 (very strong reject to award quality).
 61 | - "Confidence": A rating from 1 to 5 (low, medium, high, very high, absolute).
 62 | - "Decision": A decision that has to be one of the following: Accept, Reject.
 63 | neurips_form: |
 64 | ## Review Form
 65 | Below is a description of the questions you will be asked on the review form for each paper and some guidelines on what to consider when answering these questions.
 66 | When writing your review, please keep in mind that after decisions have been made, reviews and meta-reviews of accepted papers and opted-in rejected papers will be made public.
 67 | 
 68 | 1. Summary: Briefly summarize the paper and its contributions.
 69 | 2. Strengths and Weaknesses: Provide a thorough assessment of the paper's strengths and weaknesses.
 70 | 3. Originality: Rate from 1 to 4.
 71 | 4. Quality: Rate from 1 to 4.
 72 | 5. Clarity: Rate from 1 to 4.
 73 | 6. Significance: Rate from 1 to 4.
 74 | 7. Questions: List any clarifying questions.
 75 | 8. Limitations: List any limitations or potential negative societal impacts.
 76 | 9.
Ethical Concerns: Indicate whether there are ethical concerns.
 77 | 10. Soundness: Rate from 1 to 4.
 78 | 11. Presentation: Rate from 1 to 4.
 79 | 12. Contribution: Rate from 1 to 4.
 80 | 13. Overall: Rate from 1 to 10.
 81 | 14. Confidence: Rate from 1 to 5.
 82 | 15. Decision: Accept or Reject.
 83 | 
 84 | {template_instructions}
 85 | 
 86 | meta_reviewer_system_prompt: |
 87 | You are an Area Chair at a machine learning conference.
 88 | You are in charge of meta-reviewing a paper that was reviewed by {reviewer_count} reviewers.
 89 | Your job is to aggregate the reviews into a single meta-review in the same format.
 90 | Be critical and cautious in your decision, find consensus, and respect the opinion of all the reviewers.
 91 | 
 92 | reviewer_reflection_prompt: |
 93 | In your thoughts, first carefully consider the accuracy and soundness of the review you just created.
 94 | Include any other factors that you think are important in evaluating the paper.
 95 | Ensure the review is clear and concise, and the JSON is in the correct format.
 96 | Do not make things overly complicated.
 97 | In the next attempt, try to refine and improve your review.
 98 | Stick to the spirit of the original review unless there are glaring issues.
 99 | 
 100 | Additionally, please consider the following related works obtained via a literature search:
 101 | 
 102 | ```
 103 | {related_works_string}
 104 | ```
 105 | 
 106 | Use these search results to assess the paper’s novelty, relevance, and significance.
 107 | Provide specific comments on how the paper aligns with or differs from these related works.
 108 | 
 109 | Respond in the same format as before:
 110 | THOUGHT:
 111 | <THOUGHT>
 112 | 
 113 | REVIEW JSON:
 114 | ```json
 115 | <JSON>
 116 | ```
 117 | 
 118 | If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
 119 | ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES.
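For reference, a review that satisfies this form — and that `_aggregate_scores` in `reviewer.py` below can average across reviewers — is a JSON object shaped like the following; all text and scores here are placeholders, not real output:

```json
{
  "Summary": "The paper proposes method X and evaluates it on benchmark Y.",
  "Strengths": ["Clear problem statement", "Simple, reproducible setup"],
  "Weaknesses": ["Limited baselines", "No ablation study"],
  "Originality": 3,
  "Quality": 2,
  "Clarity": 3,
  "Significance": 2,
  "Questions": ["How does the method scale to larger datasets?"],
  "Limitations": ["Evaluation is restricted to a single dataset."],
  "Ethical Concerns": false,
  "Soundness": 2,
  "Presentation": 3,
  "Contribution": 2,
  "Overall": 5,
  "Confidence": 4,
  "Decision": "Reject"
}
```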
120 | -------------------------------------------------------------------------------- /tiny_scientist/reviewer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Optional, Tuple 3 | 4 | from rich import print 5 | 6 | from .configs import Config 7 | from .tool import BaseTool, PaperSearchTool 8 | from .utils.error_handler import api_calling_error_exponential_backoff 9 | from .utils.input_formatter import InputFormatter 10 | from .utils.llm import ( 11 | create_client, 12 | extract_json_between_markers, 13 | get_response_from_llm, 14 | ) 15 | 16 | 17 | class Reviewer: 18 | def __init__( 19 | self, 20 | model: str, 21 | tools: List[BaseTool], 22 | num_reviews: int = 3, 23 | num_reflections: int = 2, 24 | temperature: float = 0.75, 25 | prompt_template_dir: Optional[str] = None, 26 | ): 27 | self.tools = tools 28 | self.num_reviews = num_reviews 29 | self.num_reflections = num_reflections 30 | self.client, self.model = create_client(model) 31 | self.temperature = temperature 32 | self.config = Config(prompt_template_dir) 33 | self.searcher = PaperSearchTool() 34 | self._query_cache: Dict[str, List[Dict[str, Any]]] = {} 35 | self.last_related_works_string = "" 36 | 37 | self.prompts = self.config.prompt_template.reviewer_prompt 38 | self.prompts.neurips_form = self.prompts.neurips_form.format( 39 | template_instructions=self.prompts.template_instructions 40 | ) 41 | 42 | def review(self, pdf_path: str) -> str: 43 | formatter = InputFormatter() 44 | text = formatter.parse_paper_pdf_to_json(pdf_path=pdf_path) 45 | paper_text = str(text) 46 | print(f"Using content from PDF file: {pdf_path}") 47 | 48 | if not paper_text: 49 | raise ValueError("No paper text provided for review.") 50 | 51 | query = self._generate_query(paper_text) 52 | 53 | related_works_string = self._get_related_works(query) 54 | self.last_related_works_string = related_works_string 55 | 56 | base_prompt = self._build_review_prompt(paper_text, related_works_string) 57 | system_prompt = self.prompts.reviewer_system_prompt_neg 58 | 59 | review, _ = self._generate_review(base_prompt, system_prompt, msg_history=[]) 60 | return json.dumps(review, indent=2) 61 | 62 | def re_review(self, review_json: str) -> str: 63 | current_review = json.loads(review_json) 64 | if not current_review: 65 | raise ValueError("No review provided for re-review.") 66 | 67 | system_prompt = self.prompts.reviewer_system_prompt_neg 68 | related_works_string = self.last_related_works_string 69 | 70 | new_review, _, _ = self._reflect_review( 71 | review=current_review, 72 | reviewer_system_prompt=system_prompt, 73 | related_works_string=related_works_string, 74 | msg_history=[], 75 | ) 76 | return json.dumps(new_review, indent=2) 77 | 78 | def run(self, pdf_path: str) -> Dict[str, Any]: 79 | all_reviews = [] 80 | 81 | for i in range(self.num_reviews): 82 | print(f"Generating {i + 1}/{self.num_reviews} review") 83 | current_review = self.review(pdf_path) 84 | 85 | # Apply tools to review 86 | for tool in self.tools: 87 | tool_input = json.dumps({"review": current_review}) 88 | tool_output = tool.run(tool_input) 89 | if "review" in tool_output: 90 | current_review = tool_output["review"]["review"] 91 | 92 | # Apply reflections 93 | for j in range(self.num_reflections): 94 | current_review = self.re_review(current_review) 95 | 96 | all_reviews.append(json.loads(current_review)) 97 | 98 | return self._write_meta_review(all_reviews) 99 | 100 | def _get_related_works(self, query: 
str) -> str: 101 | if query in self._query_cache: 102 | related_papers = self._query_cache[query] 103 | else: 104 | results_dict = self.searcher.run(query) 105 | related_papers = list(results_dict.values()) 106 | self._query_cache[query] = related_papers if related_papers else [] 107 | 108 | if related_papers: 109 | related_works_string = self._format_paper_results(related_papers) 110 | print("✅Related Works String Found") 111 | else: 112 | related_works_string = "No related works found." 113 | print("❎No Related Works Found") 114 | 115 | return related_works_string 116 | 117 | def _build_review_prompt(self, text: str, related_works_string: str) -> str: 118 | base_prompt = self.prompts.neurips_form.format( 119 | related_works_string=related_works_string 120 | ) 121 | return f"{base_prompt}\nHere is the paper you are asked to review:\n```\n{text}\n```" 122 | 123 | def _generate_query(self, text: str) -> str: 124 | query_prompt = self.prompts.query_prompt.format(paper_text=text) 125 | response, _ = get_response_from_llm( 126 | query_prompt, 127 | client=self.client, 128 | model=self.model, 129 | system_message=self.prompts.reviewer_system_prompt_neg, 130 | temperature=self.temperature, 131 | msg_history=[], 132 | ) 133 | query_data = extract_json_between_markers(response) 134 | return str(query_data.get("Query", "")) if query_data else "" 135 | 136 | @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) 137 | def _generate_review( 138 | self, 139 | base_prompt: str, 140 | reviewer_system_prompt: str, 141 | msg_history: Optional[List[Dict[str, Any]]] = None, 142 | ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: 143 | if msg_history is None: 144 | msg_history = [] 145 | 146 | llm_review, msg_history = get_response_from_llm( 147 | base_prompt, 148 | model=self.model, 149 | client=self.client, 150 | system_message=reviewer_system_prompt, 151 | print_debug=False, 152 | msg_history=msg_history, 153 | temperature=self.temperature, 154 | ) 155 | review = extract_json_between_markers(llm_review) 156 | return review if review is not None else {}, msg_history 157 | 158 | @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) 159 | def _reflect_review( 160 | self, 161 | review: Dict[str, Any], 162 | reviewer_system_prompt: str, 163 | related_works_string: str, 164 | msg_history: List[Dict[str, Any]], 165 | ) -> Tuple[Dict[str, Any], List[Dict[str, Any]], bool]: 166 | updated_prompt = ( 167 | f"Previous review: {json.dumps(review)}\n" 168 | + self.prompts.reviewer_reflection_prompt.format( 169 | related_works_string=related_works_string 170 | ) 171 | ) 172 | 173 | text, msg_history = get_response_from_llm( 174 | updated_prompt, 175 | client=self.client, 176 | model=self.model, 177 | system_message=reviewer_system_prompt, 178 | msg_history=msg_history, 179 | temperature=self.temperature, 180 | ) 181 | 182 | new_review = extract_json_between_markers(text) 183 | is_done = "I am done" in text 184 | 185 | return new_review or {}, msg_history, is_done 186 | 187 | def _write_meta_review(self, reviews: List[Dict[str, Any]]) -> Dict[str, Any]: 188 | if not reviews: 189 | raise ValueError("At least one review must be provided for meta-review.") 190 | 191 | formatted_reviews = "".join( 192 | f"\nReview {i + 1}:\n```\n{json.dumps(r)}\n```\n" 193 | for i, r in enumerate(reviews) 194 | ) 195 | 196 | meta_prompt = self.prompts.neurips_form + formatted_reviews 197 | meta_system_prompt = self.prompts.meta_reviewer_system_prompt.format( 198 | reviewer_count=len(reviews) 199 | ) 200 | 201 | 
llm_meta_review, _ = get_response_from_llm( 202 | meta_prompt, 203 | model=self.model, 204 | client=self.client, 205 | system_message=meta_system_prompt, 206 | msg_history=[], 207 | temperature=self.temperature, 208 | ) 209 | 210 | meta_review = extract_json_between_markers(llm_meta_review) 211 | if meta_review is None: 212 | return {} 213 | 214 | return self._aggregate_scores(meta_review, reviews) 215 | 216 | def _aggregate_scores( 217 | self, meta_review: Dict[str, Any], reviews: List[Dict[str, Any]] 218 | ) -> Dict[str, Any]: 219 | score_fields = { 220 | "Originality": (1, 4), 221 | "Quality": (1, 4), 222 | "Clarity": (1, 4), 223 | "Significance": (1, 4), 224 | "Soundness": (1, 4), 225 | "Presentation": (1, 4), 226 | "Contribution": (1, 4), 227 | "Overall": (1, 10), 228 | "Confidence": (1, 5), 229 | } 230 | 231 | for score, (min_val, max_val) in score_fields.items(): 232 | valid_scores = [ 233 | r[score] 234 | for r in reviews 235 | if score in r 236 | and isinstance(r[score], (int, float)) 237 | and min_val <= r[score] <= max_val 238 | ] 239 | 240 | if valid_scores: 241 | meta_review[score] = int(round(sum(valid_scores) / len(valid_scores))) 242 | 243 | return meta_review 244 | 245 | @staticmethod 246 | def _format_paper_results(papers: List[Dict[str, Any]]) -> str: 247 | if not papers: 248 | return "No papers found." 249 | 250 | paper_strings = [] 251 | for i, paper in enumerate(papers): 252 | paper_strings.append( 253 | f"{i}: {paper.get('title', 'No title')}. {paper.get('source', 'No authors')}. " 254 | f"{paper.get('info', 'No venue')}" 255 | ) 256 | 257 | return "\n\n".join(paper_strings) 258 | -------------------------------------------------------------------------------- /tiny_scientist/scientist.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | from rich import print 4 | 5 | from .coder import Coder 6 | from .reviewer import Reviewer 7 | from .thinker import Thinker 8 | from .utils.input_formatter import InputFormatter 9 | from .writer import Writer 10 | 11 | 12 | class TinyScientist: 13 | def __init__( 14 | self, 15 | model: str = "gpt-4o", 16 | output_dir: str = "./", 17 | template: str = "acl", 18 | prompt_template_dir: Optional[str] = None, 19 | ): 20 | self.model = model 21 | self.output_dir = output_dir 22 | self.template = template 23 | self.prompt_template_dir = prompt_template_dir 24 | self.input_formatter = InputFormatter() 25 | 26 | self.thinker = Thinker( 27 | model=model, 28 | output_dir=output_dir, 29 | prompt_template_dir=prompt_template_dir, 30 | tools=[], 31 | iter_num=3, 32 | search_papers=True, 33 | generate_exp_plan=True, 34 | ) 35 | 36 | self.coder = Coder( 37 | model=model, 38 | output_dir=output_dir, 39 | prompt_template_dir=prompt_template_dir, 40 | max_iters=4, 41 | max_runs=3, 42 | ) 43 | 44 | self.writer = Writer( 45 | model=model, 46 | output_dir=output_dir, 47 | prompt_template_dir=prompt_template_dir, 48 | template=template, 49 | ) 50 | 51 | self.reviewer = Reviewer( 52 | model=model, 53 | prompt_template_dir=prompt_template_dir, 54 | tools=[], 55 | ) 56 | 57 | def think( 58 | self, intent: str, num_ideas: int = 1, pdf_content: Optional[str] = None 59 | ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: 60 | print("🧠 Generating idea...") 61 | ideas = self.thinker.run( 62 | intent=intent, num_ideas=num_ideas, pdf_content=pdf_content 63 | ) 64 | print(ideas) 65 | print("✅ Idea generated.") 66 | return ideas 67 | 68 | def code( 69 | self, 70 | 
idea: Dict[str, Any], 71 | baseline_results: Optional[Dict[str, Any]] = {}, 72 | ) -> Tuple[bool, str]: 73 | print("💻 Running experiments...") 74 | status, exp_path = self.coder.run(idea=idea, baseline_results=baseline_results) 75 | if status: 76 | print(f"✅ Experiment completed successfully. Results saved at {exp_path}") 77 | else: 78 | print(f"❌ Experiment failed. Please check {exp_path} for details.") 79 | return status, exp_path 80 | 81 | def write(self, idea: Dict[str, Any], experiment_dir: str) -> str: 82 | print("📝 Writing paper...") 83 | pdf_path, paper_name = self.writer.run(idea=idea, experiment_dir=experiment_dir) 84 | print( 85 | f"Check the generated paper named as {paper_name} and saved at {pdf_path}" 86 | ) 87 | print("✅ Paper written.") 88 | return pdf_path 89 | 90 | def review(self, pdf_path: str) -> Dict[str, Any]: 91 | print("🔍 Reviewing paper...") 92 | review = self.reviewer.run(pdf_path=pdf_path) 93 | print(review) 94 | print("✅ Review complete.") 95 | return review 96 | -------------------------------------------------------------------------------- /tiny_scientist/tool.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import json 3 | import os 4 | import re 5 | import time 6 | from typing import Any, Dict, List, Optional, cast 7 | 8 | import requests 9 | import toml 10 | from rich import print 11 | 12 | from .configs import Config 13 | from .utils.error_handler import api_calling_error_exponential_backoff 14 | from .utils.llm import create_client, get_response_from_llm 15 | 16 | # Load config 17 | config_path = os.path.join(os.path.dirname(__file__), "config.toml") 18 | config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} 19 | 20 | 21 | class BaseTool(abc.ABC): 22 | @abc.abstractmethod 23 | def run(self, query: str) -> Dict[str, Dict[str, str]]: 24 | pass 25 | 26 | 27 | class CodeSearchTool(BaseTool): 28 | def __init__(self) -> None: 29 | self.github_token = config["core"].get("github_token", None) 30 | 31 | def run( 32 | self, query: str, search_type: str = "repositories" 33 | ) -> Dict[str, Dict[str, str]]: 34 | print(f"[github API calling] Searching for code with query: {query}") 35 | results = {} 36 | 37 | try: 38 | idea = json.loads(query) 39 | if isinstance(idea, dict) and any( 40 | k in idea for k in ["Title", "Experiment"] 41 | ): 42 | query = self.format_github_repo_query(idea) 43 | print(f"[github API calling] Formatted query from idea: {query}") 44 | except (json.JSONDecodeError, TypeError): 45 | pass 46 | 47 | repos = self._search_github(query=query, search_type=search_type) 48 | 49 | if repos: 50 | for i, repo in enumerate(repos): 51 | results[str(i)] = { 52 | "title": repo["name"], 53 | "source": repo["url"], 54 | "info": f"Stars: {repo['stars']}", 55 | } 56 | 57 | return results 58 | 59 | def format_github_repo_query( 60 | self, idea: Dict[str, Any], max_terms: int = 6, max_query_length: int = 250 61 | ) -> str: 62 | import re 63 | 64 | import spacy 65 | 66 | title = idea.get("Title", "") 67 | experiment = idea.get("Experiment", "") 68 | combined_text = f"{title}. 
{experiment}" 69 | 70 | nlp = spacy.load("en_core_web_sm") 71 | doc = nlp(combined_text) 72 | candidates = set() 73 | 74 | # Extract short noun phrases 75 | for chunk in doc.noun_chunks: 76 | phrase = chunk.text.strip().lower() 77 | if 1 <= len(phrase.split()) <= 4: 78 | candidates.add(phrase) 79 | 80 | # Add important standalone nouns and proper nouns 81 | for token in doc: 82 | if token.pos_ in {"NOUN", "PROPN"} and len(token.text) > 2: 83 | candidates.add(token.text.lower()) 84 | 85 | # Clean and deduplicate 86 | seen = set() 87 | keywords = [] 88 | for kw in candidates: 89 | cleaned = re.sub(r"[^\w\s]", "", kw) 90 | if cleaned not in seen: 91 | seen.add(cleaned) 92 | keywords.append(cleaned) 93 | if len(keywords) >= max_terms: 94 | break 95 | 96 | # Build query string 97 | quoted_keywords = [f'"{kw}"' if " " in kw else kw for kw in keywords] 98 | base_query = " ".join(quoted_keywords) 99 | suffix = " in:file language:python" 100 | full_query = f"{base_query} {suffix}" 101 | 102 | # Truncate if needed 103 | if len(full_query) > max_query_length: 104 | full_query = f"{' '.join(quoted_keywords[:max_terms//2])} {suffix}" 105 | 106 | return full_query 107 | 108 | def _search_github( 109 | self, query: str, search_type: str, result_limit: int = 10 110 | ) -> Optional[List[Dict[str, Any]]]: 111 | if search_type not in ["repositories", "code"]: 112 | raise ValueError("search_type must be either 'repositories' or 'code'.") 113 | 114 | url = f"https://api.github.com/search/{search_type}" 115 | headers = ( 116 | {"Authorization": f"token {self.github_token}"} if self.github_token else {} 117 | ) 118 | 119 | params = { 120 | "q": query, 121 | "sort": "stars" if search_type == "repositories" else "indexed", 122 | "order": "desc", 123 | "per_page": result_limit, 124 | } 125 | 126 | response = requests.get(url, headers=headers, params=params) 127 | print( 128 | f"GitHub {search_type.capitalize()} Response Status Code: {response.status_code}" 129 | ) 130 | response.raise_for_status() 131 | 132 | results = response.json() 133 | if "items" not in results: 134 | return None 135 | 136 | return ( 137 | self._extract_github_repo_info(results["items"]) 138 | if search_type == "repositories" 139 | else self._extract_github_code_info(results["items"]) 140 | ) 141 | 142 | @staticmethod 143 | def _extract_github_repo_info(repos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 144 | return [ 145 | { 146 | "name": repo["name"], 147 | "owner": repo["owner"]["login"], 148 | "stars": repo["stargazers_count"], 149 | "forks": repo["forks_count"], 150 | "url": repo["html_url"], 151 | "description": repo["description"] or "No description provided.", 152 | } 153 | for repo in repos 154 | ] 155 | 156 | @staticmethod 157 | def _extract_github_code_info( 158 | code_results: List[Dict[str, Any]] 159 | ) -> List[Dict[str, Any]]: 160 | return [ 161 | { 162 | "file_name": item["name"], 163 | "repository": item["repository"]["full_name"], 164 | "url": item["html_url"], 165 | } 166 | for item in code_results 167 | ] 168 | 169 | 170 | class PaperSearchTool(BaseTool): 171 | def __init__(self) -> None: 172 | self.s2_api_key = config["core"].get("s2_api_key", None) 173 | 174 | def run(self, query: str) -> Dict[str, Dict[str, str]]: 175 | results = {} 176 | papers = self.search_for_papers(query) 177 | 178 | if papers: 179 | for i, paper in enumerate(papers): 180 | paper_id = paper.get("paperId", None) 181 | bibtex = self.fetch_bibtex(paper_id) if paper_id else "N/A" 182 | 183 | if not bibtex or bibtex == "N/A": 184 | continue 185 | 186 
| results[paper["title"]] = {"title": paper["title"], "bibtex": bibtex} 187 | 188 | return results 189 | 190 | def search_for_papers( 191 | self, query: str, result_limit: int = 3 192 | ) -> Optional[List[Dict[str, Any]]]: 193 | if not query: 194 | return None 195 | 196 | engine = config["core"].get("engine", "semanticscholar") 197 | if engine == "semanticscholar": 198 | print( 199 | f"[semantic scholar API calling] Searching for papers with query: {query}" 200 | ) 201 | return self._search_semanticscholar(query, result_limit) 202 | elif engine == "openalex": 203 | print(f"[openalex API calling] Searching for papers with query: {query}") 204 | return self._search_openalex(query, result_limit) 205 | else: 206 | raise NotImplementedError(f"{engine=} not supported!") 207 | 208 | @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) 209 | def _search_semanticscholar( 210 | self, query: str, result_limit: int 211 | ) -> Optional[List[Dict[str, Any]]]: 212 | params: Dict[str, str | int] = { 213 | "query": query, 214 | "limit": result_limit, 215 | "fields": "title,authors,venue,year,abstract,citationStyles,citationCount", 216 | } 217 | 218 | headers = {"X-API-KEY": self.s2_api_key} if self.s2_api_key else {} 219 | rsp = requests.get( 220 | "https://api.semanticscholar.org/graph/v1/paper/search", 221 | headers=headers, 222 | params=params, 223 | ) 224 | rsp.raise_for_status() 225 | 226 | results = rsp.json() 227 | if not results.get("total"): 228 | return None 229 | 230 | time.sleep(1.0) 231 | return cast(Optional[List[Dict[str, Any]]], results.get("data")) 232 | 233 | def _search_openalex( 234 | self, query: str, result_limit: int 235 | ) -> Optional[List[Dict[str, Any]]]: 236 | import pyalex 237 | from pyalex import Works 238 | 239 | mail = os.environ.get("OPENALEX_MAIL_ADDRESS") 240 | if mail: 241 | pyalex.config.email = mail 242 | else: 243 | print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better API access") 244 | 245 | works = Works().search(query).get(per_page=result_limit) 246 | if not works: 247 | return None 248 | 249 | return [self._extract_work_info(work) for work in works] 250 | 251 | @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) 252 | def fetch_bibtex(self, paper_id: str) -> Any: 253 | headers = {"X-API-KEY": self.s2_api_key} if self.s2_api_key else {} 254 | rsp = requests.get( 255 | f"https://api.semanticscholar.org/graph/v1/paper/{paper_id}", 256 | headers=headers, 257 | params={"fields": "citationStyles"}, 258 | ) 259 | rsp.raise_for_status() 260 | citation_styles = rsp.json().get("citationStyles", {}) 261 | return citation_styles.get("bibtex", "N/A") 262 | 263 | @staticmethod 264 | def _extract_work_info( 265 | work: Dict[str, Any], max_abstract_length: int = 1000 266 | ) -> Dict[str, str]: 267 | venue = next( 268 | ( 269 | loc["source"]["display_name"] 270 | for loc in work["locations"] 271 | if loc["source"] 272 | ), 273 | "Unknown", 274 | ) 275 | 276 | authors_list = [ 277 | author["author"]["display_name"] for author in work["authorships"] 278 | ] 279 | authors = ( 280 | " and ".join(authors_list) 281 | if len(authors_list) < 20 282 | else f"{authors_list[0]} et al." 
283 | ) 284 | 285 | abstract = work.get("abstract", "") 286 | if len(abstract) > max_abstract_length: 287 | print(f"[WARNING] {work['title']}: Abstract is too long, truncating.") 288 | abstract = abstract[:max_abstract_length] 289 | 290 | return { 291 | "title": work["title"], 292 | "authors": authors, 293 | "venue": venue, 294 | "year": work.get("publication_year", "Unknown"), 295 | "abstract": abstract, 296 | "citationCount": work.get("cited_by_count", 0), 297 | } 298 | 299 | 300 | class DrawerTool(BaseTool): 301 | def __init__( 302 | self, 303 | model: Any, 304 | prompt_template_dir: Optional[str] = None, 305 | temperature: float = 0.75, 306 | ): 307 | self.client, self.model = create_client(model) 308 | self.temperature = temperature 309 | 310 | # Load prompt templates using Config 311 | self.config = Config(prompt_template_dir) 312 | self.prompts = self.config.prompt_template.drawer_prompt 313 | 314 | # Process template instructions 315 | if hasattr(self.prompts, "template_instructions") and hasattr( 316 | self.prompts, "few_shot_instructions" 317 | ): 318 | self.prompts.few_shot_instructions = ( 319 | self.prompts.few_shot_instructions.replace( 320 | "{{ template_instructions }}", self.prompts.template_instructions 321 | ) 322 | ) 323 | 324 | self.dir_path = os.path.dirname(os.path.realpath(__file__)) 325 | 326 | def run(self, query: str) -> Dict[str, Dict[str, str]]: 327 | diagram = self.draw_diagram(query) 328 | results = {} 329 | if diagram: 330 | results["diagram"] = { 331 | "summary": diagram.get("summary", ""), 332 | "svg": diagram.get("svg", ""), 333 | } 334 | return results 335 | 336 | def draw_diagram( 337 | self, 338 | text: str, 339 | example: Optional[str] = None, 340 | msg_history: Optional[List[Dict[str, Any]]] = None, 341 | return_msg_history: bool = False, 342 | drawer_system_prompt: Optional[str] = None, 343 | ) -> Any: 344 | # Use default system prompt if none provided 345 | drawer_system_prompt = ( 346 | drawer_system_prompt or self.prompts.diagram_system_prompt_base 347 | ) 348 | 349 | # Prepare prompt with the few-shot example 350 | base_prompt = self._prepare_diagram_prompt(text, example) 351 | 352 | # Generate diagram 353 | diagram, updated_msg_history = self._generate_diagram( 354 | base_prompt, drawer_system_prompt, msg_history 355 | ) 356 | 357 | return (diagram, updated_msg_history) if return_msg_history else diagram 358 | 359 | def _prepare_diagram_prompt(self, text: str, example: Optional[str] = None) -> str: 360 | if example: 361 | # Format with the example 362 | few_shot_prompt = self.prompts.few_shot_instructions.format(example=example) 363 | base_prompt = f"{few_shot_prompt}\n\nHere is the paper you are asked to create a diagram for:\n```\n{text}\n```" 364 | else: 365 | # Use just the template instructions 366 | base_prompt = f"{self.prompts.template_instructions}\n\nHere is the paper you are asked to create a diagram for:\n```\n{text}\n```" 367 | 368 | return str(base_prompt) 369 | 370 | @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) 371 | def _generate_diagram( 372 | self, 373 | base_prompt: str, 374 | drawer_system_prompt: str, 375 | msg_history: Optional[List[Dict[str, Any]]], 376 | ) -> tuple[Dict[str, Any], List[Dict[str, Any]]]: 377 | # Ensure msg_history is a list 378 | msg_history = msg_history or [] 379 | 380 | # Generate diagram 381 | llm_response, msg_history = get_response_from_llm( 382 | base_prompt, 383 | model=self.model, 384 | client=self.client, 385 | system_message=drawer_system_prompt, 386 | print_debug=False, 
387 | msg_history=msg_history,
 388 | temperature=self.temperature,
 389 | )
 390 | 
 391 | # Extract the diagram from the response
 392 | diagram = self._extract_diagram(llm_response)
 393 | 
 394 | return diagram, msg_history
 395 | 
 396 | def _extract_diagram(self, response: str) -> Dict[str, Any]:
 397 | result = {"summary": "", "svg": "", "full_response": response}
 398 | 
 399 | # Extract the summary
 400 | summary_start = response.find("SUMMARY:")
 401 | if summary_start != -1:
 402 | summary_end = response.find("DIAGRAM SVG:", summary_start)
 403 | if summary_end != -1:
 404 | result["summary"] = response[summary_start + 8 : summary_end].strip()
 405 | 
 406 | # Extract the SVG
 407 | svg_start = response.find("```svg", summary_start if summary_start != -1 else 0)
 408 | if svg_start == -1:
 409 | # Try without language specifier
 410 | svg_start = response.find(
 411 | "```", summary_start if summary_start != -1 else 0
 412 | )
 413 | if svg_start != -1:
 414 | svg_start += 3 # Skip past ```
 415 | else:
 416 | svg_start += 6 # Skip past ```svg
 417 | 
 418 | if svg_start != -1:
 419 | svg_end = response.find("```", svg_start)
 420 | if svg_end != -1:
 421 | raw_svg = response[svg_start:svg_end].strip()
 422 | result["svg"] = self._clean_svg(raw_svg)
 423 | 
 424 | return result
 425 | 
 426 | def _clean_svg(self, svg: str) -> str:
 427 | # Strip any outer code block delimiters
 428 | svg = svg.strip()
 429 | svg = re.sub(r"^```(?:svg)?", "", svg)
 430 | svg = re.sub(r"```$", "", svg)
 431 | 
 432 | # Escape bare ampersands that are not already part of an XML entity
 433 | svg = re.sub(r"&(?!amp;|lt;|gt;|quot;|apos;|#)", "&amp;", svg)
 434 | 
 435 | # Ensure no double XML declarations
 436 | svg = re.sub(r"<\?xml.*?\?>", "", svg, count=1)
 437 | 
 438 | # Remove extra whitespace lines
 439 | svg = "\n".join([line for line in svg.splitlines() if line.strip()])
 440 | 
 441 | return svg.strip()
 442 | 
-------------------------------------------------------------------------------- /tiny_scientist/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulab-uiuc/tiny-scientist/fec7ecf2dc9e30d77ac4cd96395d50b81eb452b0/tiny_scientist/utils/__init__.py
-------------------------------------------------------------------------------- /tiny_scientist/utils/bib_manager.py: --------------------------------------------------------------------------------
 1 | import os.path as osp
 2 | from typing import Any, Dict, Optional
 3 | 
 4 | from rich import print
 5 | 
 6 | from .llm import get_response_from_llm
 7 | 
 8 | 
 9 | class BibManager:
 10 | def __init__(self, model: str, client: Any) -> None:
 11 | self.model = model
 12 | self.client = client
 13 | 
 14 | def _update_bib_cite(
 15 | self, references: Dict[str, Any], dest_template_dir: str, template: str
 16 | ) -> None:
 17 | # Both supported templates (acl and iclr) keep their bibliography in
 18 | # custom.bib; for iclr you should create a custom.bib file in that
 19 | # folder. A single assignment avoids leaving bib_path unbound when a
 20 | # new template is added without updating a chain of if-statements.
 21 | bib_path = osp.join(dest_template_dir, "custom.bib")
 22 | 
 23 | bib_entries = []
 24 | for meta in references.values():
 25 | bibtex = meta.get("bibtex", "").strip()
 26 | if bibtex:
 27 | bib_entries.append(bibtex)
 28 | 
 29 | if not bib_entries:
 30 | print("No BibTeX entries to write.")
 31 | return
 32 | 
 33 | # Write all entries to the bib file
 34 | with open(bib_path, "w", encoding="utf-8") as f:
 35 | f.write("\n\n".join(bib_entries))
 36 | 
 37 | print(f"custom.bib created with {len(bib_entries)} entries.")
 38 | 
 39 | def _get_bibtex_for_key(self, key: str) -> Optional[str]:
 40 | prompt = f"Provide the bibtex entry for the paper
with citation key '{key}'. Output only the bibtex entry." 41 | try: 42 | result = get_response_from_llm( 43 | msg=prompt, 44 | client=self.client, 45 | model=self.model, 46 | system_message="You are an expert in academic citations. Please provide a valid bibtex entry.", 47 | ) 48 | 49 | if isinstance(result, tuple): 50 | bibtex_entry = result[0] 51 | else: 52 | bibtex_entry = result 53 | 54 | if ( 55 | isinstance(bibtex_entry, str) 56 | and "@" in bibtex_entry 57 | and key in bibtex_entry 58 | ): 59 | return bibtex_entry.strip() 60 | else: 61 | print(f"Invalid bibtex returned for key: {key}") 62 | return None 63 | 64 | except Exception as e: 65 | print(f"Error fetching bibtex for key '{key}': {e}") 66 | return None 67 | -------------------------------------------------------------------------------- /tiny_scientist/utils/error_handler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from functools import wraps 4 | 5 | from beartype.typing import Any, Callable, Optional, TypeVar, cast 6 | from pydantic import BaseModel 7 | from rich import print 8 | 9 | INF = float(math.inf) 10 | 11 | T = TypeVar("T", bound=Callable[..., Any]) 12 | 13 | 14 | def api_calling_error_exponential_backoff( 15 | retries: int = 5, base_wait_time: int = 1 16 | ) -> Callable[[T], T]: 17 | """ 18 | Decorator for applying exponential backoff to a function. 19 | :param retries: Maximum number of retries. 20 | :param base_wait_time: Base wait time in seconds for the exponential backoff. 21 | :return: The wrapped function with exponential backoff applied. 22 | """ 23 | 24 | def decorator(func: T) -> T: 25 | @wraps(func) 26 | def wrapper(*args: Any, **kwargs: Any) -> Any: 27 | error_handler_mode = kwargs.get("mode", None) 28 | if error_handler_mode == "TEST": 29 | modified_retries = 1 30 | modified_base_wait_time = 1 31 | else: 32 | modified_retries = retries 33 | modified_base_wait_time = base_wait_time 34 | 35 | attempts = 0 36 | while attempts < modified_retries: 37 | try: 38 | return func(*args, **kwargs) 39 | except Exception: 40 | wait_time = modified_base_wait_time * (2**attempts) 41 | print( 42 | f"[API calling error] Attempt {attempts + 1} failed. Waiting {wait_time} seconds before retrying..." 43 | ) 44 | time.sleep(wait_time) 45 | attempts += 1 46 | print( 47 | f"Failed to execute '{func.__name__}' after {modified_retries} retries." 48 | ) 49 | return None 50 | 51 | return cast(T, wrapper) 52 | 53 | return cast(Callable[[T], T], decorator) 54 | 55 | 56 | TBaseModel = TypeVar("TBaseModel", bound=Callable[..., BaseModel]) 57 | 58 | 59 | def parsing_error_exponential_backoff( 60 | retries: int = 5, base_wait_time: int = 1 61 | ) -> Callable[[TBaseModel], TBaseModel]: 62 | """ 63 | Decorator for retrying a function that returns a BaseModel with exponential backoff. 64 | :param retries: Maximum number of retries. 65 | :param base_wait_time: Base wait time in seconds for the exponential backoff. 66 | :return: The wrapped function with retry logic applied. 
67 | """ 68 | 69 | def decorator(func: TBaseModel) -> TBaseModel: 70 | @wraps(func) 71 | def wrapper(self: Any, *args: Any, **kwargs: Any) -> Optional[BaseModel]: 72 | attempts = 0 73 | while attempts < retries: 74 | try: 75 | return func(self, *args, **kwargs) 76 | except Exception as e: 77 | wait_time = base_wait_time * (2**attempts) 78 | print(f"Attempt {attempts + 1} failed: {e}") 79 | print(f"Waiting {wait_time} seconds before retrying...") 80 | time.sleep(wait_time) 81 | attempts += 1 82 | print( 83 | f"Failed to get valid input from {func.__name__} after {retries} retries." 84 | ) 85 | return None 86 | 87 | return cast(TBaseModel, wrapper) 88 | 89 | return cast(Callable[[TBaseModel], TBaseModel], decorator) 90 | -------------------------------------------------------------------------------- /tiny_scientist/utils/input_formatter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import pymupdf 6 | import pymupdf4llm 7 | from pypdf import PdfReader 8 | from rich import print 9 | 10 | 11 | class InputFormatter: 12 | def _load_paper( 13 | self, pdf_path: str, num_pages: Optional[int] = None, min_size: int = 100 14 | ) -> str: 15 | """ 16 | Loads a PDF, attempting to convert it to Markdown via pymupdf4llm. 17 | If that fails, falls back to direct pymupdf extraction, and then 18 | finally to PyPDF2. Returns the extracted text as a single string. 19 | """ 20 | try: 21 | if num_pages is None: 22 | text = pymupdf4llm.to_markdown(pdf_path) 23 | else: 24 | reader = PdfReader(pdf_path) 25 | min_pages = min(len(reader.pages), num_pages) 26 | text = pymupdf4llm.to_markdown(pdf_path, pages=list(range(min_pages))) 27 | if len(text) < min_size: 28 | raise Exception("Text too short") 29 | except Exception as e: 30 | print(f"Error with pymupdf4llm, falling back to pymupdf: {e}") 31 | try: 32 | doc = pymupdf.open(pdf_path) 33 | if num_pages: 34 | doc = doc[:num_pages] 35 | text = "" 36 | for page in doc: 37 | text += page.get_text() 38 | if len(text) < min_size: 39 | raise Exception("Text too short") 40 | except Exception as e: 41 | print(f"Error with pymupdf, falling back to PyPDF2: {e}") 42 | reader = PdfReader(pdf_path) 43 | if num_pages is None: 44 | text = "".join(page.extract_text() for page in reader.pages) 45 | else: 46 | text = "".join( 47 | page.extract_text() for page in reader.pages[:num_pages] 48 | ) 49 | if len(text) < min_size: 50 | raise Exception("Text too short") 51 | return str(text) 52 | 53 | def _extract_subsections( 54 | self, section_text: str 55 | ) -> Tuple[str, List[Dict[str, str]]]: 56 | """ 57 | Helper function to parse sub-subsections of the form: 58 | '**x.x** **Subsection Title**'. 59 | Returns a tuple (clean_text, subsections), where 'clean_text' is the 60 | remaining text outside these subsections, and 'subsections' is a list 61 | of dicts with the keys 'subsection_number', 'subsection_title', and 62 | 'subsection_content'. 
63 | """ 64 | subsections = [] 65 | subsec_pattern = re.compile( 66 | r"(?m)^\*\*(\d+\.\d+)\*\*\s+\*\*(.*?)\*\*\s*(.*?)(?=^\*\*\d+\.\d+\*\*|\Z)", 67 | re.DOTALL, 68 | ) 69 | matches = list(subsec_pattern.finditer(section_text)) 70 | 71 | if not matches: 72 | return section_text.strip(), [] 73 | 74 | leftover_parts = [] 75 | last_end = 0 76 | 77 | for m in matches: 78 | start_idx = m.start() 79 | leftover = section_text[last_end:start_idx] 80 | leftover_parts.append(leftover) 81 | 82 | subsection_number = m.group(1).strip() 83 | subsection_title = m.group(2).strip() 84 | subsection_content = m.group(3).strip() 85 | 86 | subsections.append( 87 | { 88 | "subsection_number": subsection_number, 89 | "subsection_title": subsection_title, 90 | "subsection_content": subsection_content, 91 | } 92 | ) 93 | 94 | last_end = m.end() 95 | 96 | leftover_parts.append(section_text[last_end:]) 97 | clean_text = "\n".join(part.strip() for part in leftover_parts).strip() 98 | 99 | return clean_text, subsections 100 | 101 | def _parse_markdown(self, markdown_str: str) -> Dict[str, Any]: 102 | """ 103 | Parses a markdown document with the following structure: 104 | 105 | 1) Optional document title of the form: 106 | ## My Document Title 107 | 108 | 2) Everything before '### Abstract' goes into 'header'. 109 | 110 | 3) From '### Abstract' onward, each '### Some Heading' is treated 111 | as a top-level section. That section's content is everything 112 | until the next '### ' heading or the end of the document. 113 | 114 | 4) Within each top-level section, sub-sections appear in lines 115 | of the form: 116 | **x.x** **Subsection Title** 117 | and continue until the next sub-section or the next top-level section. 118 | 119 | Returns a dictionary of the form: 120 | 121 | { 122 | "title": "...", 123 | "header": "...", 124 | "sections": [ 125 | { 126 | "section_name": "...", 127 | "content": "...", 128 | "subsections": [ 129 | { 130 | "subsection_number": "x.x", 131 | "subsection_title": "...", 132 | "subsection_content": "..." 133 | }, 134 | ... 135 | ] 136 | }, 137 | ... 
138 | ] 139 | } 140 | """ 141 | # 1) Extract optional document title 142 | title_pattern = re.compile(r"(?m)^##\s+(.*)") 143 | title_match = title_pattern.search(markdown_str) 144 | title = "" 145 | if title_match: 146 | title = title_match.group(1).strip() 147 | full_line = title_match.group(0) 148 | markdown_str = markdown_str.replace(full_line, "", 1) 149 | 150 | # 2) Split out "header" from everything after '### Abstract' 151 | split_pattern = r"(?s)(.*?)^### Abstract(.*)" 152 | match = re.search(split_pattern, markdown_str, re.MULTILINE) 153 | 154 | if not match: 155 | return {"title": title, "header": markdown_str.strip(), "sections": []} 156 | 157 | part_before = match.group(1) 158 | part_after = "### Abstract" + match.group(2) 159 | header = part_before.strip() 160 | 161 | # 3) Extract top-level sections from 'part_after' 162 | section_pattern = re.compile( 163 | r"(?m)^###\s+(.*?)\s*\n" r"(.*?)(?=^###\s+|\Z)", re.DOTALL 164 | ) 165 | 166 | raw_sections = section_pattern.findall(part_after) 167 | sections = [] 168 | 169 | # Parse each top-level section 170 | for section_name, section_text in raw_sections: 171 | section_name = section_name.strip() 172 | clean_text, subsections_list = self._extract_subsections(section_text) 173 | section_dict = { 174 | "section_name": section_name, 175 | "content": clean_text, 176 | "subsections": subsections_list, 177 | } 178 | sections.append(section_dict) 179 | 180 | return {"title": title, "header": header, "sections": sections} 181 | 182 | def _load_review(self, review_path: str) -> str: 183 | """ 184 | Loads a JSON file (at review_path) and returns the string under the 'review' key. 185 | The JSON is expected to have the structure: { "review": "..." }. 186 | """ 187 | with open(review_path, "r", encoding="utf-8") as f: 188 | data: Dict[str, str] = json.load(f) 189 | return data["review"] 190 | 191 | def parse_paper_pdf_to_json( 192 | self, pdf_path: str, num_pages: Optional[int] = None, min_size: int = 100 193 | ) -> Dict[str, Any]: 194 | """ 195 | Convenience method to load a PDF, convert it to text, parse the markdown, 196 | and return a structured JSON-like Python dictionary. 197 | 198 | If no sections are found during parsing, returns the raw PDF text in a 199 | compatible format with "title": "", "header": "", and a single section 200 | containing the full text. 201 | """ 202 | pdf_text = self._load_paper(pdf_path, num_pages=num_pages, min_size=min_size) 203 | parsed_result = self._parse_markdown(pdf_text) 204 | 205 | # If no sections were found, return the raw text in a compatible format 206 | if not parsed_result.get("sections"): 207 | print("No sections found in parsed result, returning raw text") 208 | return { 209 | "title": "", 210 | "header": "", 211 | "sections": [ 212 | { 213 | "section_name": "Full Text", 214 | "content": pdf_text, 215 | "subsections": [], 216 | } 217 | ], 218 | } 219 | 220 | return parsed_result 221 | 222 | def parse_review_json(self, review_path: str) -> Dict[str, Any]: 223 | """ 224 | Convenience method to load a JSON 'review' file, then parse it using the 225 | same markdown rules, returning a structured JSON-like dictionary. 
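For example (illustrative input), a file containing {"review": "### Abstract\nClear motivation, but weak baselines."} parses into a dict whose 'sections' list holds a single 'Abstract' section with that sentence as its content.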
226 | """ 227 | review_text = self._load_review(review_path) 228 | return self._parse_markdown(review_text) 229 | -------------------------------------------------------------------------------- /tiny_scientist/utils/output_formatter.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import os.path as osp 4 | import platform 5 | import re 6 | import shutil 7 | import subprocess 8 | import sys 9 | from typing import Any, Dict, Match 10 | 11 | import requests 12 | from rich import print 13 | 14 | from .bib_manager import BibManager 15 | from .water_marker import WaterMarker 16 | 17 | 18 | class BaseOutputFormatter(abc.ABC): 19 | @abc.abstractmethod 20 | def run( 21 | self, 22 | content: Dict[str, Any], 23 | references: Dict[str, Any], 24 | output_dir: str, 25 | output_pdf_path: str, 26 | name: str, 27 | timeout: int = 30, 28 | ) -> None: 29 | pass 30 | 31 | def strip_latex(self, text: str) -> str: 32 | text = re.sub(r"%.*", "", text) 33 | text = re.sub(r"\\[a-zA-Z]+\{.*?\}", "", text) 34 | text = re.sub(r"\\begin\{.*?\}.*?\\end\{.*?\}", "", text, flags=re.DOTALL) 35 | text = re.sub(r"\s+", " ", text).strip() 36 | return text 37 | 38 | def _clean_latex_content(self, content: str) -> str: 39 | match = re.search(r"```latex\s*(.*?)\s*```", content, flags=re.DOTALL) 40 | if match: 41 | return match.group(1) 42 | 43 | # If no code block is found, perform minimal cleaning: 44 | lines = content.splitlines() 45 | cleaned_lines = [] 46 | for line in lines: 47 | stripped = line.strip() 48 | # Remove lines that are exactly code fences (```), but keep inline backticks if any. 49 | if stripped in ["```"]: 50 | continue 51 | # Remove markdown header lines (starting with '#' and not a LaTeX comment) 52 | if stripped.startswith("#") and not stripped.startswith("%"): 53 | continue 54 | cleaned_lines.append(line) 55 | return "\n".join(cleaned_lines) 56 | 57 | def _wrap_tables_in_latex(self, content: str) -> str: 58 | def replacer(match: Match[str]) -> str: 59 | tabular_block = match.group(1) 60 | 61 | # Check if the tabular block is already inside a table environment 62 | if ( 63 | "\\begin{table}" in content[: match.start()] 64 | and "\\end{table}" in content[match.end() :] 65 | ): 66 | return tabular_block # Already inside a table, skip wrapping 67 | 68 | return ( 69 | "\\begin{table}[ht]\n" 70 | "\\centering\n" 71 | "\\resizebox{\\linewidth}{!}{%\n" 72 | f"{tabular_block}\n" 73 | "}\n" 74 | "\\caption{}\n" 75 | "\\label{}\n" 76 | "\\end{table}" 77 | ) 78 | 79 | return re.sub( 80 | r"(\\begin{tabular}.*?\\end{tabular})", replacer, content, flags=re.DOTALL 81 | ) 82 | 83 | def _assemble_body(self, contents: Dict[str, Dict[str, Any]]) -> str: 84 | section_order = [ 85 | "Abstract", 86 | "Introduction", 87 | "Related_Work", 88 | "Method", 89 | "Experimental_Setup", 90 | "Results", 91 | "Discussion", 92 | "Conclusion", 93 | ] 94 | 95 | section_titles = { 96 | "Abstract": None, 97 | "Introduction": "Introduction", 98 | "Related_Work": "Related Work", 99 | "Method": "Method", 100 | "Experimental_Setup": "Experimental Setup", 101 | "Results": "Results", 102 | "Discussion": "Discussion", 103 | "Conclusion": "Conclusion", 104 | } 105 | 106 | body = "" 107 | for section in section_order: 108 | raw = contents.get(section, "") 109 | content = raw.get("text", "") if isinstance(raw, dict) else raw 110 | if content: 111 | cleaned_content = self._clean_latex_content(content) 112 | cleaned_content = self._wrap_tables_in_latex(cleaned_content) 113 | 
section_title = section_titles[section] 114 | if section_title is not None: 115 | starts_with_section = re.match( 116 | rf"\\section\{{{re.escape(section_title)}\}}", 117 | cleaned_content, 118 | re.IGNORECASE, 119 | ) 120 | starts_with_text = cleaned_content.lower().startswith( 121 | section_title.lower() 122 | ) 123 | if not starts_with_section and not starts_with_text: 124 | body += f"\\section{{{section_title}}}\n" 125 | body += f"{cleaned_content}\n\n" 126 | 127 | body += "\n\n\\bibliography{custom}" 128 | return body 129 | 130 | def _insert_body_into_template( 131 | self, template_text: str, body_content: str, new_title: str 132 | ) -> str: 133 | template_text = re.sub( 134 | r"(\\title\{)[^}]*\}", r"\1" + new_title + r"}", template_text 135 | ) 136 | 137 | begin_doc_match = re.search(r"(\\begin{document})", template_text) 138 | if not begin_doc_match: 139 | raise ValueError("Template is missing \\begin{document}.") 140 | 141 | # Check if there's a \maketitle command after \begin{document} 142 | maketitle_match = re.search(r"(\\maketitle)", template_text) 143 | ending_match = re.search(r"(\\end{document})", template_text) 144 | if not ending_match: 145 | raise ValueError("Template is missing \\end{document}.") 146 | ending = template_text[ending_match.start() :] 147 | 148 | if maketitle_match: 149 | insertion_point = maketitle_match.end() 150 | return template_text[:insertion_point] + "\n" + body_content + "\n" + ending 151 | else: 152 | preamble = template_text[: begin_doc_match.end()] 153 | return preamble + "\n" + body_content + "\n" + ending 154 | 155 | 156 | class TemplateDownloader: 157 | @staticmethod 158 | def download_acl_template(output_dir: str) -> str: 159 | print(f"Downloading ACL template from GitHub to {output_dir}") 160 | dest_template_dir = osp.join(output_dir, "latex") 161 | os.makedirs(dest_template_dir, exist_ok=True) 162 | 163 | # GitHub repository URL for ACL 164 | acl_api_url = ( 165 | "https://api.github.com/repos/acl-org/acl-style-files/contents/latex" 166 | ) 167 | response = requests.get(acl_api_url) 168 | response.raise_for_status() 169 | 170 | files_data = response.json() 171 | for file_info in files_data: 172 | if file_info["type"] == "file": 173 | file_url = file_info["download_url"] 174 | filename = file_info["name"] 175 | 176 | print(f"Downloading {filename}...") 177 | file_response = requests.get(file_url) 178 | file_response.raise_for_status() 179 | 180 | with open(osp.join(dest_template_dir, filename), "wb") as f: 181 | f.write(file_response.content) 182 | 183 | return dest_template_dir 184 | 185 | @staticmethod 186 | def download_iclr_template(output_dir: str) -> str: 187 | print(f"Downloading ICLR template from GitHub to {output_dir}") 188 | dest_template_dir = osp.join(output_dir, "latex") 189 | os.makedirs(dest_template_dir, exist_ok=True) 190 | 191 | # Get list of files in the iclr2025 directory 192 | iclr_api_url = ( 193 | "https://api.github.com/repos/ICLR/Master-Template/contents/iclr2025" 194 | ) 195 | response = requests.get(iclr_api_url) 196 | response.raise_for_status() 197 | 198 | files_data = response.json() 199 | 200 | # Download each file in the directory 201 | for file_info in files_data: 202 | if file_info["type"] == "file": 203 | file_url = file_info["download_url"] 204 | filename = file_info["name"] 205 | 206 | print(f"Downloading {filename}...") 207 | file_response = requests.get(file_url) 208 | file_response.raise_for_status() 209 | 210 | with open(osp.join(dest_template_dir, filename), "wb") as f: 211 | 
f.write(file_response.content) 212 | 213 | return dest_template_dir 214 | 215 | 216 | class ACLOutputFormatter(BaseOutputFormatter): 217 | def __init__(self, model: str, client: Any) -> None: 218 | self.template = "acl" 219 | self.bib_manager = BibManager(model, client) 220 | self.watermarker = WaterMarker() 221 | 222 | def run( 223 | self, 224 | content: Dict[str, Any], 225 | references: Dict[str, Any], 226 | output_dir: str, 227 | output_pdf_path: str, 228 | name: str, 229 | timeout: int = 30, 230 | ) -> None: 231 | body_content = self._assemble_body(content) 232 | dest_template_dir = TemplateDownloader.download_acl_template(output_dir) 233 | 234 | self.bib_manager._update_bib_cite(references, dest_template_dir, self.template) 235 | 236 | main_tex_path = osp.join(dest_template_dir, "acl_latex.tex") 237 | 238 | with open(main_tex_path, "r", encoding="utf-8") as f: 239 | template_text = f.read() 240 | 241 | final_content = self._insert_body_into_template( 242 | template_text, body_content, name 243 | ) 244 | 245 | with open(main_tex_path, "w", encoding="utf-8") as f: 246 | f.write(final_content) 247 | 248 | # main_tex_path now holds the fully assembled paper; nothing to read back. 249 | 250 | 251 | self._compile_latex(dest_template_dir, output_pdf_path, timeout) 252 | self.watermarker._add_watermark( 253 | output_pdf_path, 254 | watermark_text="CAUTION!!! THIS PAPER WAS AUTONOMOUSLY GENERATED BY THE TINY_SCIENTIST", 255 | output_pdf_path=output_pdf_path, 256 | ) 257 | 258 | def _set_output_dir(self, output_dir: str) -> str: 259 | script_dir = osp.dirname(__file__) 260 | project_root = osp.abspath(osp.join(script_dir, "..", ".."))  # utils/ -> tiny_scientist/ -> repo root 261 | source_template_dir = osp.join( 262 | project_root, "tiny_scientist", f"{self.template}_latex" 263 | ) 264 | 265 | if osp.isdir(source_template_dir): 266 | dest_template_dir = osp.join(output_dir, "latex") 267 | 268 | if osp.exists(dest_template_dir): 269 | shutil.rmtree(dest_template_dir) 270 | shutil.copytree(source_template_dir, dest_template_dir) 271 | 272 | return dest_template_dir 273 | 274 | def _compile_latex(self, cwd: str, output_pdf_path: str, timeout: int) -> None: 275 | def _ensure_pdflatex() -> None: 276 | if shutil.which("pdflatex") is not None: 277 | return 278 | system = platform.system() 279 | print("[System] pdflatex not found. Attempting to install...") 280 | 281 | try: 282 | if system == "Darwin": 283 | subprocess.run(["brew", "install", "--cask", "mactex"], check=True) 284 | print("[System] Installed MacTeX via Homebrew.") 285 | elif system == "Linux": 286 | subprocess.run(["sudo", "apt-get", "update"], check=True) 287 | subprocess.run( 288 | ["sudo", "apt-get", "install", "-y", "texlive-full"], check=True 289 | ) 290 | print("[System] Installed TeX Live via apt.") 291 | else: 292 | raise RuntimeError( 293 | "Unsupported system for automatic pdflatex installation." 294 | ) 295 | except Exception as e: 296 | print(f"[Error] Automatic pdflatex installation failed: {e}") 297 | sys.exit(1) 298 | 299 | _ensure_pdflatex() 300 | 301 | fname = "acl_latex.tex" 302 | compile_target = fname 303 | 304 | if not osp.exists(osp.join(cwd, compile_target)): 305 | print(f"File {compile_target} not found in {cwd}.") 306 | return 307 | 308 | if not compile_target: 309 | print("Error: No .tex file found to compile. Aborting.")
Aborting.") 310 | return 311 | 312 | commands = [ 313 | ["pdflatex", "-interaction=nonstopmode", compile_target], 314 | ["bibtex", compile_target.replace(".tex", "")], 315 | ["pdflatex", "-interaction=nonstopmode", compile_target], 316 | ["pdflatex", "-interaction=nonstopmode", compile_target], 317 | ] 318 | for command in commands: 319 | try: 320 | result = subprocess.run( 321 | command, 322 | cwd=cwd, 323 | stdout=subprocess.PIPE, 324 | stderr=subprocess.PIPE, 325 | timeout=timeout, 326 | ) 327 | print("Standard Output:\n", result.stdout) 328 | print("Standard Error:\n", result.stderr) 329 | except subprocess.TimeoutExpired: 330 | print(f"Latex timed out after {timeout} seconds") 331 | except subprocess.CalledProcessError as e: 332 | print(f"Error running command {' '.join(command)}: {e}") 333 | print("FINISHED GENERATING LATEX") 334 | # The PDF name is the same as compile_target minus .tex, e.g. 'latex.pdf' or 'template.pdf' 335 | pdf_name = compile_target.replace(".tex", ".pdf") 336 | try: 337 | shutil.move(osp.join(cwd, pdf_name), output_pdf_path) 338 | except FileNotFoundError: 339 | print("Failed to rename PDF.") 340 | 341 | 342 | class ICLROutputFormatter(BaseOutputFormatter): 343 | def __init__(self, model: str, client: Any) -> None: 344 | self.template = "iclr" 345 | self.bib_manager = BibManager(model, client) 346 | self.watermarker = WaterMarker() 347 | 348 | def run( 349 | self, 350 | content: Dict[str, Any], 351 | references: Dict[str, Any], 352 | output_dir: str, 353 | output_pdf_path: str, 354 | name: str, 355 | timeout: int = 30, 356 | ) -> None: 357 | body_content = self._assemble_body(content) 358 | dest_template_dir = TemplateDownloader.download_iclr_template(output_dir) 359 | 360 | self.bib_manager._update_bib_cite(references, dest_template_dir, self.template) 361 | 362 | main_tex_path = osp.join(dest_template_dir, "iclr2025_conference.tex") 363 | 364 | with open(main_tex_path, "r", encoding="utf-8") as f: 365 | template_text = f.read() 366 | 367 | final_content = self._insert_body_into_template( 368 | template_text, body_content, name 369 | ) 370 | 371 | with open(main_tex_path, "w", encoding="utf-8") as f: 372 | f.write(final_content) 373 | 374 | with open(main_tex_path, "r") as f: 375 | final_content = f.read() 376 | 377 | self._compile_latex(dest_template_dir, output_pdf_path, timeout) 378 | self.watermarker._add_watermark( 379 | output_pdf_path, 380 | watermark_text="CAUTION!!! THIS PAPER WAS AUTONOMOUSLY GENERATED BY THE TINY_SCIENTIST", 381 | output_pdf_path=output_pdf_path, 382 | ) 383 | 384 | def _set_output_dir(self, output_dir: str) -> str: 385 | script_dir = osp.dirname(__file__) 386 | project_root = osp.abspath(osp.join(script_dir, "..")) 387 | source_template_dir = osp.join( 388 | project_root, "tiny_scientist", f"{self.template}_latex" 389 | ) 390 | 391 | if osp.isdir(source_template_dir): 392 | dest_template_dir = osp.join(output_dir, "latex") 393 | 394 | if osp.exists(dest_template_dir): 395 | shutil.rmtree(dest_template_dir) 396 | shutil.copytree(source_template_dir, dest_template_dir) 397 | 398 | return dest_template_dir 399 | 400 | def _compile_latex(self, cwd: str, output_pdf_path: str, timeout: int) -> None: 401 | def _ensure_pdflatex() -> None: 402 | if shutil.which("pdflatex") is not None: 403 | return 404 | system = platform.system() 405 | print("[System] pdflatex not found. 
Attempting to install...") 406 | 407 | try: 408 | if system == "Darwin": 409 | subprocess.run(["brew", "install", "--cask", "mactex"], check=True) 410 | print("[System] Installed MacTeX via Homebrew.") 411 | elif system == "Linux": 412 | subprocess.run(["sudo", "apt-get", "update"], check=True) 413 | subprocess.run( 414 | ["sudo", "apt-get", "install", "-y", "texlive-full"], check=True 415 | ) 416 | print("[System] Installed TeX Live via apt.") 417 | else: 418 | raise RuntimeError( 419 | "Unsupported system for automatic pdflatex installation." 420 | ) 421 | except Exception as e: 422 | print(f"[Error] Automatic pdflatex installation failed: {e}") 423 | sys.exit(1) 424 | 425 | _ensure_pdflatex() 426 | 427 | fname = "iclr2025_conference.tex" 428 | 429 | compile_target = fname 430 | if not osp.exists(osp.join(cwd, compile_target)): 431 | print(f"File {compile_target} not found in {cwd}.") 432 | return 433 | 434 | if not compile_target: 435 | print("Error: No .tex file found to compile. Aborting.") 436 | return 437 | 438 | commands = [ 439 | ["pdflatex", "-interaction=nonstopmode", compile_target], 440 | ["bibtex", compile_target.replace(".tex", "")], 441 | ["pdflatex", "-interaction=nonstopmode", compile_target], 442 | ["pdflatex", "-interaction=nonstopmode", compile_target], 443 | ] 444 | for command in commands: 445 | try: 446 | result = subprocess.run( 447 | command, 448 | cwd=cwd, 449 | stdout=subprocess.PIPE, 450 | stderr=subprocess.PIPE, 451 | text=True, 452 | timeout=timeout, 453 | ) 454 | print("Standard Output:\n", result.stdout) 455 | print("Standard Error:\n", result.stderr) 456 | except subprocess.TimeoutExpired: 457 | print(f"Latex timed out after {timeout} seconds") 458 | except subprocess.CalledProcessError as e: 459 | print(f"Error running command {' '.join(command)}: {e}") 460 | print("FINISHED GENERATING LATEX") 461 | # The PDF name is the same as compile_target minus .tex, e.g. 
'latex.pdf' or 'template.pdf' 462 | pdf_name = compile_target.replace(".tex", ".pdf") 463 | try: 464 | shutil.move(osp.join(cwd, pdf_name), output_pdf_path) 465 | except FileNotFoundError: 466 | print("Failed to rename PDF.") 467 | -------------------------------------------------------------------------------- /tiny_scientist/utils/pricing.py: -------------------------------------------------------------------------------- 1 | # Pricing data for each model (prices are in dollars per million (1,000,000) tokens) 2 | MODEL_PRICING = { 3 | # OpenAI models 4 | "gpt-3.5-turbo": (0.5, 1.5), 5 | "gpt-4o-mini": (0.15, 0.6), 6 | "gpt-4o": (2.5, 10), 7 | "o1-preview": (15, 60), 8 | "o1-mini": (1.1, 4.4), 9 | "o1": (15, 60), 10 | # OpenRouter models 11 | "llama3.1-405b": (3.5, 3.5), 12 | # Anthropic models 13 | "claude-3-sonnet-v1": (0.8, 4), 14 | "claude-3-sonnet": (3, 15), 15 | "claude-3-5-sonnet-v2": (3, 15), 16 | "claude-3-5-sonnet": (3, 15), 17 | "claude-3-haiku-v1": (0.25, 1.25), 18 | "claude-3-haiku": (0.25, 1.25), 19 | "claude-3-opus-v1": (15, 75), 20 | "claude-3-opus": (15, 75), 21 | # DeepSeek models 22 | "deepseek-chat": (0.07, 0.27), 23 | "deepseek-reasoner": (0.14, 0.55), 24 | # Google Gemini models 25 | "gemini-1.5-flash": (0.01875, 0.075), 26 | "gemini-1.5-pro": (0.3125, 1.25), 27 | } 28 | 29 | 30 | def calculate_pricing(model: str, input_tokens: int, output_tokens: int) -> float: 31 | # Resolve unknown names by prefix matching (e.g. dated variants); raise only if no key matches 32 | if model not in MODEL_PRICING: 33 | for m in MODEL_PRICING: 34 | if model.startswith(m): 35 | model = m 36 | break 37 | else: 38 | raise ValueError(f"Pricing for '{model}' is not found.") 39 | 40 | input_price, output_price = MODEL_PRICING[model] 41 | 42 | # Check if pricing data is available 43 | if input_price is None or output_price is None: 44 | raise ValueError(f"Pricing for '{model}' is unavailable.") 45 | 46 | # The pricing is per million (1,000,000) tokens. 47 | input_cost = (input_tokens / 1000000) * input_price 48 | output_cost = (output_tokens / 1000000) * output_price 49 | 50 | total_cost = input_cost + output_cost 51 | return total_cost 52 | -------------------------------------------------------------------------------- /tiny_scientist/utils/water_marker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import textwrap 4 | 5 | from pypdf import PageObject, PdfReader, PdfWriter 6 | from reportlab.lib.colors import Color 7 | from reportlab.lib.pagesizes import letter 8 | from reportlab.pdfgen import canvas 9 | from rich import print 10 | 11 | 12 | class WaterMarker: 13 | def _add_watermark( 14 | self, original_pdf_path: str, watermark_text: str, output_pdf_path: str 15 | ) -> None: 16 | with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file: 17 | watermark_pdf_path = tmp_file.name 18 | 19 | c = canvas.Canvas(watermark_pdf_path, pagesize=letter) 20 | c.saveState() 21 | c.translate(300, 400) 22 | c.rotate(45) 23 | c.setFillColor(Color(0.95, 0.95, 0.95)) 24 | c.setFont("Helvetica-Bold", 28) 25 | 26 | max_chars_per_line = 30 27 | lines = textwrap.wrap(watermark_text, width=max_chars_per_line) 28 | 29 | line_height = 35 30 | y_offset = 0 31 | for line in lines: 32 | c.drawCentredString(0, y_offset, line) 33 | y_offset -= line_height 34 | c.restoreState() 35 | c.showPage() 36 | c.save() 37 | 38 | original_reader = PdfReader(original_pdf_path) 39 | watermark_reader = PdfReader(watermark_pdf_path) 40 | if len(watermark_reader.pages) == 0: 41 | print("Warning: Watermark PDF is empty. No watermark will be applied.")
No watermark will be applied.") 42 | return 43 | 44 | watermark_page = watermark_reader.pages[0] 45 | writer = PdfWriter() 46 | 47 | for orig_page in original_reader.pages: 48 | new_page = PageObject.create_blank_page( 49 | width=orig_page.mediabox.width, height=orig_page.mediabox.height 50 | ) 51 | 52 | new_page.merge_page(watermark_page) 53 | new_page.merge_page(orig_page) 54 | 55 | writer.add_page(new_page) 56 | 57 | with open(output_pdf_path, "wb") as out_f: 58 | writer.write(out_f) 59 | print(f"Watermarked PDF saved to: {output_pdf_path}") 60 | os.remove(watermark_pdf_path) 61 | --------------------------------------------------------------------------------