├── .gitattributes ├── .github ├── CODEOWNERS ├── cache_mybinder.py └── workflows │ ├── build.yml │ └── repo2docker.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── docs ├── _config.yml ├── _toc.yml ├── contributing.md ├── index.md ├── overview.png ├── references.bib ├── release_history.rst ├── set_version.py └── tutorials │ ├── _init.py │ ├── basics_0_spps.ipynb │ ├── basics_1_using_components.ipynb │ ├── basics_2_working_with_data.ipynb │ ├── basics_3_parallelization.ipynb │ ├── basics_4_minibatching.ipynb │ ├── basics_5_rag.ipynb │ ├── example_instruction_optimization.ipynb │ ├── example_prompt_engineering.ipynb │ ├── example_sammo_express.ipynb │ ├── quickstart.ipynb │ └── special_topics │ ├── 0_llm_apis.ipynb │ ├── 1_custom_runners.ipynb │ ├── 2_handling_failures.ipynb │ ├── 3_rate_limiting.ipynb │ └── 4_structured_outputs.ipynb ├── environment.yml ├── examples ├── blog │ └── stop_wasting_tokens.ipynb ├── paper_instruction_tuning │ ├── data_splits.json │ ├── instruction_tuning_dspy.py │ └── instruction_tuning_sammo.py ├── paper_prompt_compression │ ├── data_splits.json │ └── prompt_compression.py └── paper_rag │ ├── rag_tuning_dspy.py │ └── rag_tuning_sammo.py ├── pyproject.toml └── sammo ├── __init__.py ├── base.py ├── base_test.py ├── compactbars.py ├── compactbars_test.py ├── components.py ├── components_test.py ├── css_matching.py ├── css_matching_test.py ├── data.py ├── data_tests.py ├── dataformatters.py ├── dataformatters_test.py ├── express.py ├── express_test.py ├── extractors.py ├── extractors_test.py ├── instructions.py ├── instructions_test.py ├── integration_test.py ├── mutators.py ├── mutators_test.py ├── runners.py ├── runners_test.py ├── scheduler.py ├── search.py ├── search_op.py ├── search_op_test.py ├── store.py ├── store_test.py ├── throttler.py ├── throttler_test.py └── utils.py /.gitattributes: 
-------------------------------------------------------------------------------- 1 | * text=auto 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in 2 | # the repo. Unless a later match takes precedence, 3 | # @global-owner1 and @global-owner2 will be requested for 4 | # review when someone opens a pull request. 5 | * @t-schn @pbourke 6 | -------------------------------------------------------------------------------- /.github/cache_mybinder.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import playwright 3 | from playwright.async_api import async_playwright 4 | 5 | TIMEOUT = 10 * 60 * 1000 # 8 minutes 6 | 7 | 8 | async def handle_route(route): 9 | print("Replacing bundle.js") 10 | response = await route.fetch() 11 | body = await response.text() 12 | body = body.replace("var i=[];e.o", "var i=window.myglobal=[];e.o") 13 | await route.fulfill( 14 | response=response, 15 | body=body, 16 | headers=response.headers, 17 | ) 18 | 19 | 20 | async def prefetch_binder( 21 | url="https://mybinder.org/v2/gh/microsoft/sammo/main?urlpath=tree/docs/tutorials/quickstart.ipynb", 22 | ): 23 | async with async_playwright() as pw: 24 | browser = await pw.chromium.launch() 25 | page = await browser.new_page() 26 | await page.route("**/bundle.js*", handle_route) 27 | await page.goto(url) 28 | 29 | old_log = [] 30 | while True: 31 | try: 32 | current_log = await page.evaluate("() => window.myglobal || []") 33 | current_log = [x for x in current_log if x.strip() != ""] 34 | if current_log != old_log: 35 | print("".join(current_log[len(old_log) :]), flush=True) 36 | old_log = current_log 37 | await asyncio.sleep(1) 38 | except playwright._impl._errors.Error: 39 | print(f"Redirected to {page.url}") 40 | break 41 | await 
browser.close() 42 | 43 | 44 | asyncio.get_event_loop().run_until_complete(prefetch_binder()) 45 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | run-name: Build Pipeline 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request_target: 8 | branches: 9 | - main 10 | release: 11 | types: [published] 12 | workflow_dispatch: 13 | jobs: 14 | build: 15 | name: Build 16 | 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: 21 | - ubuntu-latest 22 | - windows-latest 23 | - macos-latest 24 | python: 25 | - "3.11" 26 | - "3.12" 27 | 28 | runs-on: ${{ matrix.os }} 29 | 30 | defaults: 31 | run: 32 | shell: bash 33 | 34 | steps: 35 | - name: Checkout repository 36 | uses: actions/checkout@v4 37 | - name: Setup Python 38 | uses: actions/setup-python@v5 39 | id: setup-python 40 | with: 41 | python-version: ${{ matrix.python }} 42 | - name: Setup pipx for build tool isolation 43 | run: | 44 | pip install --user pipx 45 | pipx ensurepath 46 | - name: Set up poetry and install dependencies 47 | run: | 48 | pipx install --python '${{ steps.setup-python.outputs.python-path }}' poetry 49 | poetry config virtualenvs.create true --local 50 | poetry config virtualenvs.in-project true --local 51 | poetry install --with dev 52 | - name: Run pre-commit checks 53 | run: | 54 | poetry run poe pre-commit 55 | - name: Run static type checks 56 | continue-on-error: true 57 | run: | 58 | poetry run poe type-check 59 | - name: Run tests 60 | run: | 61 | poetry run poe test -v 62 | - name: Build wheel and sdist for distribution 63 | run: | 64 | pipx install --python '${{ steps.setup-python.outputs.python-path }}' build 65 | pyproject-build 66 | - name: Check distribution with twine 67 | run: | 68 | pipx install --python '${{ steps.setup-python.outputs.python-path }}' twine 69 | twine check --strict dist/* 70 | - name: Store the 
distribution files 71 | uses: actions/upload-artifact@v4 72 | if: matrix.os == 'ubuntu-latest' && matrix.python == '3.11' 73 | with: 74 | name: python-package-distributions 75 | path: dist/ 76 | - name: Build documentation artifacts 77 | run: | 78 | poetry run poe build-docs 79 | - name: Zip documentation 80 | if: matrix.os == 'ubuntu-latest' && matrix.python == '3.11' 81 | run: | 82 | DOCS_VERSION=`poetry version -s` 83 | mkdir _build_docs/dist 84 | pushd _build_docs/_build/html 85 | zip -r ../../dist/sammo-docs-$DOCS_VERSION.zip . 86 | popd 87 | - name: Store the documentation artifacts for GitHub Pages 88 | uses: actions/upload-pages-artifact@v3 89 | if: matrix.os == 'ubuntu-latest' && matrix.python == '3.11' 90 | with: 91 | path: _build_docs/_build/html 92 | - name: Store the zipped documentation 93 | uses: actions/upload-artifact@v4 94 | if: matrix.os == 'ubuntu-latest' && matrix.python == '3.11' 95 | with: 96 | name: zipped-documentation 97 | path: _build_docs/dist/ 98 | 99 | release-publish-artifacts: 100 | name: Publish Artifacts 101 | runs-on: ubuntu-latest 102 | needs: build 103 | if: github.event_name == 'release' 104 | defaults: 105 | run: 106 | shell: bash 107 | 108 | environment: 109 | name: pypi 110 | url: https://pypi.org/p/sammo 111 | 112 | permissions: 113 | id-token: write # IMPORTANT: mandatory for trusted publishing 114 | contents: write # Allow artifacts to be uploaded to release on GH 115 | 116 | steps: 117 | - name: Download the distribution artifacts 118 | uses: actions/download-artifact@v4 119 | with: 120 | name: python-package-distributions 121 | path: dist/ 122 | - name: Publish distribution to PyPI 123 | uses: pypa/gh-action-pypi-publish@release/v1 124 | - name: Download zipped documentation 125 | uses: actions/download-artifact@v4 126 | with: 127 | name: zipped-documentation 128 | path: dist/ 129 | - name: Generate SHA256 checksums for all artifacts 130 | run: | 131 | sha256sum dist/*.whl > checksums.txt 132 | sha256sum dist/*.tar.gz 
>> checksums.txt 133 | sha256sum dist/*.zip >> checksums.txt 134 | cat checksums.txt 135 | - name: Update release with SHA256 and Artifacts 136 | uses: softprops/action-gh-release@v1 137 | with: 138 | token: ${{ secrets.GITHUB_TOKEN }} 139 | files: | 140 | dist/* 141 | checksums.txt 142 | 143 | release-publish-pages: 144 | name: Publish Documentation Site 145 | runs-on: ubuntu-latest 146 | needs: build 147 | if: github.event_name == 'release' || github.event_name == 'workflow_dispatch' 148 | defaults: 149 | run: 150 | shell: bash 151 | 152 | permissions: 153 | pages: write # to deploy to Pages 154 | id-token: write # to verify the deployment originates from an appropriate source 155 | 156 | environment: 157 | name: github-pages 158 | url: ${{ steps.deployment.outputs.page_url }} 159 | 160 | steps: 161 | - name: Download zipped documentation 162 | uses: actions/download-artifact@v4 163 | with: 164 | name: zipped-documentation 165 | path: dist/ 166 | - name: Deploy to GitHub Pages 167 | id: deployment 168 | uses: actions/deploy-pages@v4 169 | -------------------------------------------------------------------------------- /.github/workflows/repo2docker.yml: -------------------------------------------------------------------------------- 1 | name: Cache MyBinder 2 | on: 3 | release: 4 | types: 5 | - published 6 | permissions: 7 | contents: read 8 | jobs: 9 | prefetch: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: 3.x 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | - name: Install Playwright 21 | run: | 22 | pip install playwright 23 | playwright install chromium # Install necessary browsers for Playwright 24 | - name: Run upload.py script 25 | run: | 26 | python .github/cache_mybinder.py 27 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | /config 2 | build 3 | **/cache/ 4 | *.lock 5 | .*/ 6 | !.github/ 7 | _*/ 8 | /dist 9 | /cache 10 | /deprecated 11 | /data 12 | /utils 13 | /examples 14 | /*.ipynb 15 | /*.json 16 | *.txt 17 | .python-version 18 | sammo-env/ 19 | docs/api/ 20 | poetry.toml 21 | *.swp 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # See https://pre-commit.com for more information 3 | # See https://pre-commit.com/hooks.html for more hooks 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.5.0 7 | hooks: 8 | - id: mixed-line-ending 9 | args: [--fix=no] 10 | - id: trailing-whitespace 11 | args: [--markdown-linebreak-ext=md] 12 | - repo: https://github.com/psf/black-pre-commit-mirror 13 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 14 | rev: 23.12.1 15 | hooks: 16 | - id: black 17 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | Welcome, and thank you for contributing to the project! 4 | 5 | ## Setting up your dev environment 6 | 7 | This project uses [Poetry](https://python-poetry.org/) for project management. Some tasks have been standardized to execute via the [Poe](https://poethepoet.natn.io/) task runner. 8 | 9 | We recommend that you install Poetry using [pipx](https://pipx.pypa.io/stable/) so that it's isolated from the sammo codebase. 10 | ### Step 1: Install poetry 11 | ``` 12 | pipx install poetry 13 | ``` 14 | 15 | ### Step 2: Create a separate environment 16 | 17 | #### Option 1: Let poetry create a venv 18 | Optional, but recommended: have poetry create a venv in the project folder rather than in its cache dir 19 | 20 | ``` 21 | poetry config virtualenvs.in-project true --local 22 | ``` 23 | 24 | #### Option 2: Let poetry use a conda env 25 | 26 | ``` 27 | conda create --name sammo python=3.11 28 | conda activate sammo 29 | ``` 30 | 31 | Note: you can skip the `poetry run` prefix for the rest of the commands if you use a conda env. 
32 | 33 | ### Step 3: Install library and tooling 34 | Check out and install the dev dependencies 35 | ``` 36 | # assume HTTPS, adjust for SSH 37 | git clone https://github.com/microsoft/sammo.git 38 | cd sammo 39 | poetry install --with dev 40 | ``` 41 | 42 | Set up pre-commit hooks 43 | 44 | ``` 45 | poetry run pre-commit install 46 | ``` 47 | 48 | 49 | Show the configured tasks available through the Poe runner: 50 | 51 | ``` 52 | poetry run poe 53 | ``` 54 | ## Running Tests 55 | 56 | The [pytest](https://docs.pytest.org/) tests can be run using the following command 57 | 58 | ``` 59 | poetry run poe test 60 | ``` 61 | 62 | arguments can be appended (ie for verbose mode) 63 | 64 | ``` 65 | poetry run poe test -v 66 | ``` 67 | 68 | ## Running Type Checks 69 | 70 | ``` 71 | poetry run poe type-check 72 | ``` 73 | 74 | ## Building and previewing documentation 75 | 76 | This project uses [Jupyter Book](https://jupyterbook.org/) for documentation. The documentation configuration and contents are contained in the `docs` folder. 77 | 78 | To build the documentation, run the following command: 79 | 80 | ``` 81 | poetry run poe build-docs 82 | ``` 83 | 84 | to preview it using Python's built-in HTTP server, run: 85 | 86 | ``` 87 | poetry run poe serve-docs 88 | ``` 89 | 90 | This will open a server accessible at http://localhost:8000 to preview the documentation site. You can change the host and port as needed (these arguments just pass through to the call to `http.server`): 91 | 92 | ``` 93 | poetry run poe serve-docs -b 0.0.0.0 8001 94 | ``` 95 | 96 | ## PR workflow 97 | 98 | All changes must come through a pull request on a feature branch. 99 | 100 | 1. If there isn't an existing issue for your change, please make one 101 | 1. Ensure your local main branch is up to date 102 | 1. Create a new branch to hold your changes. Suggested branch naming convention is `/`. For example `pbourke/update-contributor-docs`. 103 | 1. Run `poetry version`. 
If the current version is **not** a pre-release (ie 0.1.0.6 vs 0.1.0.6rc0), then bump to the next pre-release version: 104 | 105 | ``` 106 | # example version bump 107 | $ poetry version 108 | sammo 0.1.0.6 109 | $ poetry version 0.1.0.7rc0 110 | Bumping version from 0.1.0.6 to 0.1.0.7rc0 111 | ``` 112 | 1. Make your changes and commit to your feature branch. 113 | 1. Push to GitHub as appropriate (to your fork for non-maintainers) 114 | 1. Open a Pull Request to the project and reference the associated issue from your PR 115 | 1. GitHub Actions will run automated checks and tests 116 | 1. When you're ready, request review from the maintainers 117 | 118 | ## Release Process 119 | 120 | The following instructions are for maintainers 121 | 122 | 1. Each release should begin with a PR to (at the least) update the version from pre-release to final 123 | 1. Decide on the new version number by following [Semantic Versioning](https://semver.org/) principles 124 | 1. After the release PR is merged, the release can be made from the main branch. Each release is given a tag on the main branch with the version number (this happens automatically via the GH release mechanism) 125 | 1. Go to [the sammo project releases page](https://github.com/microsoft/sammo/releases) and click "Draft a new release" 126 | 1. Enter the new version number as the tag and release title and give a brief description 127 | 1. Click "Publish release" 128 | 1. A GitHub Actions release hook will run the automated checks and tests, publish the package to PyPI and publish the documentation to the GitHub Pages site 129 | 130 | 131 | ## Mcrosoft Contributor License Agreement 132 | 133 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 134 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 135 | the rights to use your contribution. For details, visit . 
136 | 137 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 138 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 139 | provided by the bot. You will only need to do this once across all repos using our CLA. 140 | 141 | ## Code of Conduct 142 | 143 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 144 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 145 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com>) with any additional questions or comments. 146 | 147 | ## Submitting an Issue 148 | 149 | Please [search for your issue](https://github.com/microsoft/sammo/issues?q=is%3Aissue) before submitting a new one. 150 | 151 | If nothing relevant shows up, please do [open a new issue](https://github.com/microsoft/sammo/issues/new) and provide as much detail as you can (ie: OS, python version, data formats, etc). Outputs of commands, error logs, source code snippets, etc are welcomed and will help to trace down the issue. Questions are also welcomed as they provide an opportunity for us to improve the documentation. 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAMMO ([📘User Guide](https://microsoft.github.io/sammo/)) 2 | 3 | [![Latest PyPI version](https://img.shields.io/pypi/v/sammo.svg)](https://pypi.python.org/pypi/sammo) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/microsoft/sammo/main?urlpath=tree/docs/tutorials/quickstart.ipynb) 6 | 7 | A flexible, easy-to-use library for running and optimizing prompts for Large Language Models (LLMs). 
8 | 9 | ## 🎉 News 10 | - Nov 13, 2024: Turn Markdown into prompt programs: First version of SAMMO express released 11 | - Nov 1, 2024: Use CSS selectors to query and modify prompt programs! 12 | - Oct 15, 2024: SAMMO now supports structured outputs! 13 | 14 | ## How to Get Started 15 | Go to the [user guide](https://microsoft.github.io/sammo/) for examples, how-tos, and API reference. 16 | 17 | Just want to have a quick look? Try the [live demo on Binder](https://mybinder.org/v2/gh/microsoft/sammo/main?urlpath=tree/docs/tutorials/quickstart.ipynb). 18 | 19 | 20 | ### Install library only 21 | 22 | ```bash 23 | pip install sammo 24 | ``` 25 | 26 | ### Install and run tutorials 27 | 28 | ***Prerequisites*** 29 | * Python 3.9+ 30 | 31 | The following commands will install sammo and jupyter and launch jupyter notebook. It's recommended that you create and activate a virtualenv prior to installing packages. 32 | 33 | ```bash 34 | pip install sammo jupyter 35 | 36 | # clone sammo to a local directory 37 | git clone https://github.com/microsoft/sammo.git 38 | cd sammo 39 | 40 | # launch jupyter notebook and open tutorials directory 41 | jupyter notebook --notebook-dir docs/tutorials 42 | ``` 43 | 44 | ## Example 45 | This example shows how easy it is to optimize a prompt with SAMMO. The full example is in the [user guide](https://microsoft.github.io/sammo/). 46 | ```python 47 | runner = OpenAIChat(model_id="gpt-3.5-turbo", api_config=API_CONFIG) 48 | PROMPT_IN_MARKDOWN = """ 49 | # Instructions 50 | Convert the following user queries into a SQL query. 
51 | 52 | # Table 53 | Users: 54 | - user_id (INTEGER, PRIMARY KEY) 55 | - name (TEXT) 56 | - age (INTEGER) 57 | - city (TEXT) 58 | 59 | # Complete this 60 | Input: {{{input}}} 61 | Output: 62 | """ 63 | 64 | spp = MarkdownParser(PROMPT_IN_MARKDOWN).get_sammo_program() 65 | mutation_operators = BagOfMutators( 66 | Output(GenerateText(spp)), 67 | Paraphrase("#instr"), 68 | Rewrite("#instr", "Make this more verbose.\n\n {{{{text}}}}") 69 | ) 70 | prompt_optimizer = BeamSearch(runner, mutation_operators, accuracy) 71 | prompt_optimizer.fit(d_train) 72 | prompt_optimizer.show_report() 73 | ``` 74 | 75 | ## Use Cases 76 | ![Overview](https://microsoft.github.io/sammo/_images/overview.png) 77 | 78 | SAMMO is designed to support 79 | - **Efficient data labeling**: Supports minibatching by packing and parsing multiple datapoints into a single prompt. 80 | - **Prompt prototyping and engineering**: Re-usable components and prompt structures to quickly build and test new prompts. 81 | - **Instruction optimization**: Optimize instructions to do better on a given task. 82 | - **Prompt compression**: Compress prompts while maintaining performance. 83 | - **Large-scale prompt execution**: parallelization 84 | and rate-limiting out-of-the-box so you can run many queries in parallel and at scale without overwhelming the LLM API. 85 | 86 | It is less useful if you want to build 87 | - Interactive, agent-based LLM applications (→ check out [AutoGen](https://microsoft.github.io/autogen/)) 88 | - Interactive, production-ready LLM applications (→ check out [LangChain](https://www.langchain.com/)) 89 | 90 | 91 | 92 | ## Licence 93 | 94 | This project is licensed under [MIT](https://choosealicense.com/licenses/mit/). 
95 | 96 | To cite this paper, you can use the following BibTeX entry: 97 | 98 | ```bibtex 99 | @inproceedings{schnabel-neville-2024-symbolic, 100 | title = "Symbolic Prompt Program Search: A Structure-Aware Approach to Efficient Compile-Time Prompt Optimization", 101 | author = "Schnabel, Tobias and Neville, Jennifer", 102 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2024", 103 | year = "2024", 104 | url = "https://aclanthology.org/2024.findings-emnlp.37", 105 | pages = "670--686" 106 | } 107 | ``` 108 | 109 | ## Authors 110 | 111 | `SAMMO` was written by [Tobias Schnabel](mailto:sammo@microsoft.com). 112 | 113 | ## Contributing 114 | 115 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 116 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 117 | the rights to use your contribution. For details, visit . 118 | 119 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 120 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 121 | provided by the bot. You will only need to do this once across all repos using our CLA. 122 | 123 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 124 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 125 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com>) with any additional questions or comments. 
126 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please head to the [discussions page](https://github.com/microsoft/sammo/discussions). We will try and support the project as best as we can, but please keep in mind that with the current resources, we may not be able to respond to all questions. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for SAMMO is limited to the resources listed above. 
14 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | # Learn more at https://jupyterbook.org/customize/config.html 3 | 4 | title: SAMMO 5 | author: Tobias Schnabel 6 | copyright: "2023" 7 | #exclude_patterns: 8 | # - examples* 9 | # - .* 10 | ## - sammo* 11 | # - dev* 12 | # - utils* 13 | # - data* 14 | # - deprecated* 15 | # - cache* 16 | ## - "*.md" 17 | 18 | # Force re-execution of notebooks on each build. 19 | # See https://jupyterbook.org/content/execute.html 20 | execute: 21 | execute_notebooks: off 22 | 23 | # Add a bibtex file so that we can create citations 24 | bibtex_bibfiles: 25 | - references.bib 26 | 27 | # Information about where the book exists on the web 28 | repository: 29 | url: https://github.com/microsoft/sammo/ # Online location of your book 30 | 31 | launch_buttons: 32 | binderhub_url: https://mybinder.org 33 | 34 | # Add GitHub buttons to your book 35 | # See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository 36 | html: 37 | use_issues_button: true 38 | use_repository_button: true 39 | 40 | sphinx: 41 | extra_extensions: 42 | - autodoc2 43 | local_extensions: 44 | set_version: . 
45 | config: 46 | nb_merge_streams: true 47 | autodoc2_index_template: null 48 | autodoc2_output_dir: api 49 | autodoc2_module_all_regexes: 50 | - "sammo.*(throttler|store|utils|search_op|data|compactbars|components).*" 51 | autodoc2_packages: 52 | - "../sammo" 53 | autodoc2_hidden_objects: 54 | - inherited 55 | - private 56 | - dunder 57 | autodoc2_skip_module_regexes: 58 | - .*test.* 59 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | format: jb-book 2 | root: index 3 | parts: 4 | - chapters: 5 | - file: tutorials/quickstart 6 | - caption: 🎓 Learning SAMMO 7 | numbered: true 8 | chapters: 9 | - glob: tutorials/basics_* 10 | - caption: 💡 Using SAMMO 11 | numbered: true 12 | chapters: 13 | - glob: tutorials/example_* 14 | - caption: 🔌 LLM APIs 15 | chapters: 16 | - glob: tutorials/special_topics/* 17 | - caption: 📖 API Reference 18 | chapters: 19 | - file: api/sammo/sammo 20 | - caption: ℹ️ Project Info 21 | chapters: 22 | - file: contributing 23 | - file: release_history 24 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CONTRIBUTING.md 2 | ``` 3 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # 🏠 Overview 2 | 3 | ```{eval-rst} 4 | Version: 5 | |version| 6 | ``` 7 | A flexible, easy-to-use library for running and optimizing prompts for Large Language Models (LLMs). 
8 | 9 | ```{image} overview.png 10 | :alt: overview 11 | :class: bg-primary mb-1 12 | :align: center 13 | ``` 14 | 15 | ```{include} ../README.md 16 | :start-after: 17 | :end-before: 18 | ``` 19 | 20 | -------------------------------------------------------------------------------- /docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/sammo/7ad76482f5776ed83f608f39ef9be4b96cab8329/docs/overview.png -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/sammo/7ad76482f5776ed83f608f39ef9be4b96cab8329/docs/references.bib -------------------------------------------------------------------------------- /docs/release_history.rst: -------------------------------------------------------------------------------- 1 | Release History 2 | =============== 3 | 4 | The release history is available on `PyPI `_ and `GitHub `_ 5 | -------------------------------------------------------------------------------- /docs/set_version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
def setup(app):
    """Sphinx extension entry point.

    Reads the project version from pyproject.toml (via poetry-core) and
    injects it into the Sphinx configuration under the ``version`` key
    once configuration initialization has completed.
    """
    from poetry.core.factory import Factory

    # Parse pyproject.toml through poetry-core to obtain the version string.
    project = Factory().create_poetry()
    project_version = project.package.pretty_version

    def _on_config_inited(_app, config):
        # Fired after Sphinx has finished building its config object.
        config["version"] = project_version

    app.connect("config-inited", _on_config_inited)

    # Standard Sphinx extension metadata.
    return {
        "version": "0.1",
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }
def accuracy(y_true: DataTable, y_pred: DataTable) -> EvaluationScore:
    """Return the fraction of predictions whose normalized output matches the gold output."""
    gold = y_true.outputs.normalized_values()
    predicted = y_pred.outputs.normalized_values()

    # Count position-wise exact matches between predictions and gold labels.
    n_correct = sum(1 for p, t in zip(predicted, gold) if p == t)

    return EvaluationScore(n_correct / len(gold))
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 7, 45 | "metadata": { 46 | "editable": true, 47 | "slideshow": { 48 | "slide_type": "" 49 | }, 50 | "tags": [ 51 | "hide-input" 52 | ] 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "# %load -r 3:18 _init.py\n", 57 | "import pathlib\n", 58 | "import sammo\n", 59 | "from sammo.runners import OpenAIChat\n", 60 | "from sammo.base import Template, EvaluationScore\n", 61 | "from sammo.components import Output, GenerateText, ForEach, Union\n", 62 | "from sammo.extractors import ExtractRegex\n", 63 | "from sammo.data import DataTable\n", 64 | "import json\n", 65 | "import requests\n", 66 | "import os\n", 67 | "\n", 68 | "if not \"OPENAI_API_KEY\" in os.environ:\n", 69 | " raise ValueError(\"Please set the environment variable 'OPENAI_API_KEY'.\")\n", 70 | "\n", 71 | "_ = sammo.setup_logger(\"WARNING\") # we're only interested in warnings for now" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 12, 77 | "metadata": { 78 | "editable": true, 79 | "slideshow": { 80 | "slide_type": "" 81 | }, 82 | "tags": [] 83 | }, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "minibatches[#################################################################################]5/5[00:00<00:00, 333.33it/s]\n" 90 | ] 91 | }, 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "+---------+----------+\n", 96 | "| input | output |\n", 97 | "+=========+==========+\n", 98 | "| 1 | I |\n", 99 | "+---------+----------+\n", 100 | "| 2 | II |\n", 101 | "+---------+----------+\n", 102 | "| 3 | III |\n", 103 | "+---------+----------+\n", 104 | "| 4 | IV |\n", 105 | "+---------+----------+\n", 106 | "| 5 | V |\n", 107 | "+---------+----------+\n", 108 | "Constants: None" 109 | ] 110 | }, 111 | "execution_count": 12, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "runner = OpenAIChat(\n", 118 | " 
model_id=\"gpt-3.5-turbo\",\n", 119 | " api_config={\"api_key\": os.environ[\"OPENAI_API_KEY\"]},\n", 120 | " cache=os.getenv(\"CACHE_FILE\", \"cache.tsv\"),\n", 121 | ")\n", 122 | "numbers = list(range(1,6))\n", 123 | "spp = Output(GenerateText(Template(\"Output as a latin numeral: {{input}}\")))\n", 124 | "spp.run(runner, DataTable(numbers))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": { 130 | "editable": true, 131 | "slideshow": { 132 | "slide_type": "" 133 | }, 134 | "tags": [] 135 | }, 136 | "source": [ 137 | "Here, SAMMO automatically runs queries for the six inputs in parallel while adhering to query limits (by default, 2 queries per second). We can change this when constructing the runner. We can also skip constructing the `DataTable` and just pass the list directly." 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 13, 143 | "metadata": { 144 | "editable": true, 145 | "slideshow": { 146 | "slide_type": "" 147 | }, 148 | "tags": [] 149 | }, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "minibatches[###################################################################################]5/5[00:00 LLMResult:\n", 107 | " formatted_prompt = f\"[INST] {prompt} [/INST]\"\n", 108 | " request = dict(\n", 109 | " input=formatted_prompt,\n", 110 | " max_new_tokens=self._max_context_window or max_tokens,\n", 111 | " temperature=randomness,\n", 112 | " )\n", 113 | " fingerprint = serialize_json({\"seed\": seed, \"generative_model_id\": self._model_id, **request})\n", 114 | " return await self._execute_request(request, fingerprint, priority)\n", 115 | "\n", 116 | " async def _call_backend(self, request: dict) -> dict:\n", 117 | " async with self._get_session() as session:\n", 118 | " async with session.post(\n", 119 | " f\"https://api.deepinfra.com/v1/inference/{self._model_id}\",\n", 120 | " json=request,\n", 121 | " headers={\"Authorization\": f\"Bearer 
{self._api_config['api_key']}\"}\n", 122 | " ) as response:\n", 123 | " return await response.json()\n", 124 | "\n", 125 | " def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes):\n", 126 | " return LLMResult(\n", 127 | " json_data[\"results\"][0][\"generated_text\"],\n", 128 | " costs=Costs(json_data[\"num_input_tokens\"], json_data[\"num_tokens\"]),\n", 129 | " )" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "id": "a4504c53-757b-4f83-8e08-a584ad85e14c", 136 | "metadata": { 137 | "editable": true, 138 | "slideshow": { 139 | "slide_type": "" 140 | }, 141 | "tags": [] 142 | }, 143 | "outputs": [ 144 | { 145 | "name": "stdin", 146 | "output_type": "stream", 147 | "text": [ 148 | "Enter your API key ········\n" 149 | ] 150 | }, 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "+---------+-------------------------------------------------------------+\n", 156 | "| input | output |\n", 157 | "+=========+=============================================================+\n", 158 | "| None | Horses, majestic creatures, have accompanied humans for |\n", 159 | "| | thousands of years, serving in transportation, agriculture, |\n", 160 | "| | and warfare. Today, they are cherished for companionship, |\n", 161 | "| | sport, and therapy. With their powerful build, graceful |\n", 162 | "| | movements, and intuitive nature, horses continue to inspire |\n", 163 | "| | and connect us to the natural world. Their enduring bond |\n", 164 | "| | with humans is a testament to their intelligence and |\n", 165 | "| | emotional depth. 
|\n", 166 | "+---------+-------------------------------------------------------------+\n", 167 | "Constants: None\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "runner = DeepInfraChat(\n", 173 | " \"mistralai/Mixtral-8x7B-Instruct-v0.1\", api_config={\"api_key\": getpass.getpass(\"Enter your API key\")}\n", 174 | ")\n", 175 | "print(Output(GenerateText(\"Generate a 50 word essay about horses.\")).run(runner))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "cd43b229-ba94-4b24-81df-66456f84f442", 181 | "metadata": { 182 | "editable": true, 183 | "slideshow": { 184 | "slide_type": "" 185 | }, 186 | "tags": [] 187 | }, 188 | "source": [ 189 | "The three things we had to implement were\n", 190 | "\n", 191 | "1. `generate_text()`: To format the prompt into a dictionary and compute a fingerprint for \n", 192 | "2. `_call_backend()`: To make the actual REST request\n", 193 | "3. `_to_llm_result()`: To convert the JSON object into an LLM result instance.\n", 194 | "\n", 195 | "That's it! The parent class will take care of all caching." 
196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3 (ipykernel)", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.11.9" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 5 220 | } 221 | -------------------------------------------------------------------------------- /docs/tutorials/special_topics/3_rate_limiting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "61643fe1-aa70-4f99-a3c4-a4d0577b70f3", 7 | "metadata": { 8 | "editable": true, 9 | "slideshow": { 10 | "slide_type": "" 11 | }, 12 | "tags": [ 13 | "remove-cell" 14 | ] 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "## Load from parent directory if not installed\n", 19 | "import importlib\n", 20 | "\n", 21 | "if not importlib.util.find_spec(\"sammo\"):\n", 22 | " import sys\n", 23 | "\n", 24 | " sys.path.append(\"../../../\")\n", 25 | "\n", 26 | "CACHE_FILE = \"cache/special_topics.tsv\"" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "id": "65550492-591f-44f3-85a1-9bee6db2316e", 33 | "metadata": { 34 | "editable": true, 35 | "slideshow": { 36 | "slide_type": "" 37 | }, 38 | "tags": [ 39 | "hide-input" 40 | ] 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "# %load -r :19 ../_init.py\n", 45 | "import pathlib\n", 46 | "import sammo\n", 47 | "from sammo.runners import OpenAIChat\n", 48 | "from sammo.base import Template, EvaluationScore\n", 49 | "from sammo.components import Output, GenerateText, ForEach, Union\n", 50 | "from sammo.extractors import ExtractRegex\n", 51 | "from sammo.data 
import DataTable\n", 52 | "import json\n", 53 | "import requests\n", 54 | "\n", 55 | "API_CONFIG_FILE = pathlib.Path().cwd().parent.parent / \"config\" / \"personal.openai\"\n", 56 | "API_CONFIG = \"\"\n", 57 | "if API_CONFIG_FILE.exists():\n", 58 | " API_CONFIG = API_CONFIG_FILE\n", 59 | "if not API_CONFIG:\n", 60 | " raise ValueError('Please set API_CONFIG to {\"api_key\": \"YOUR_KEY\"}')\n", 61 | "\n", 62 | "_ = sammo.setup_logger(\"WARNING\") # we're only interested in warnings for now" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "51dcda8e-55b7-4b73-a81a-90397af50b8b", 68 | "metadata": {}, 69 | "source": [ 70 | "# Rate Limiting\n", 71 | "\n", 72 | "Many APIs have rate limits, often in terms of number of requests within a certain time period or a total cost.\n", 73 | "\n", 74 | "You have three options to specify rate limits in {class}`~sammo.runners.Runner` (in increasing order of flexibility):\n", 75 | "\n", 76 | "1. Specify a number for the ``rate_limit`` parameter. This will enforce a requests per second limit equal to that number.\n", 77 | "2. Specify a list of {class}``~sammo.throttler.AtMost`` objects that are applied in an logical AND\n", 78 | " fashion.\n", 79 | "3. Pass an instance of {class}`~sammo.throttler.Throttler` (or a subclass of it). 
This allows you to fine-tune some settings, e.g., how costs are calculated.\n", 80 | "\n", 81 | "## Simple rate limit (qps)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 14, 87 | "id": "8752663a-a4c2-4321-8655-75353dfe72d4", 88 | "metadata": { 89 | "editable": true, 90 | "slideshow": { 91 | "slide_type": "" 92 | }, 93 | "tags": [] 94 | }, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "minibatches[###################################################################################]5/5[00:04<00:00, 1.13it/s]\n" 101 | ] 102 | }, 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "+---------+----------+\n", 107 | "| input | output |\n", 108 | "+=========+==========+\n", 109 | "| 1 | I |\n", 110 | "+---------+----------+\n", 111 | "| 2 | II |\n", 112 | "+---------+----------+\n", 113 | "| 3 | III |\n", 114 | "+---------+----------+\n", 115 | "| 4 | IV |\n", 116 | "+---------+----------+\n", 117 | "| 5 | V |\n", 118 | "+---------+----------+\n", 119 | "Constants: None" 120 | ] 121 | }, 122 | "execution_count": 14, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "runner = OpenAIChat(model_id=\"gpt-3.5-turbo-16k\", api_config=API_CONFIG, rate_limit=1)\n", 129 | "Output(GenerateText(Template(\"Output as a latin numeral: {{input}}\"))).run(\n", 130 | " runner, list(range(1,6))\n", 131 | ")" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "id": "fb99566f-0e56-4110-b63b-e61983ffefa7", 137 | "metadata": {}, 138 | "source": [ 139 | "As specified, `SAMMO` issued exactly one prompt request per second." 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "a64614da-c49e-44ae-84a2-5b824957fe13", 145 | "metadata": {}, 146 | "source": [ 147 | "## Advanced rate limits\n", 148 | "\n", 149 | "Let's say we want to make sure we never have more than 1 running request." 
150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 18, 155 | "id": "16613f90-ce2b-4545-975d-6ac7133facae", 156 | "metadata": { 157 | "editable": true, 158 | "slideshow": { 159 | "slide_type": "" 160 | }, 161 | "tags": [] 162 | }, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "minibatches[###################################################################################]5/5[00:02<00:00, 1.88it/s]\n" 169 | ] 170 | }, 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "+---------+----------+\n", 175 | "| input | output |\n", 176 | "+=========+==========+\n", 177 | "| 1 | I |\n", 178 | "+---------+----------+\n", 179 | "| 2 | II |\n", 180 | "+---------+----------+\n", 181 | "| 3 | III |\n", 182 | "+---------+----------+\n", 183 | "| 4 | IV |\n", 184 | "+---------+----------+\n", 185 | "| 5 | V |\n", 186 | "+---------+----------+\n", 187 | "Constants: None" 188 | ] 189 | }, 190 | "execution_count": 18, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "from sammo.throttler import AtMost\n", 197 | "\n", 198 | "runner = OpenAIChat(model_id=\"gpt-3.5-turbo-16k\", api_config=API_CONFIG, rate_limit=AtMost(1, \"running\"))\n", 199 | "\n", 200 | "Output(GenerateText(Template(\"Output as a latin numeral: {{input}}\"))).run(\n", 201 | " runner, list(range(1,6))\n", 202 | ")" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "f20316f8-8588-4c9b-aace-b063bf6fcd1f", 208 | "metadata": { 209 | "editable": true, 210 | "slideshow": { 211 | "slide_type": "" 212 | }, 213 | "tags": [] 214 | }, 215 | "source": [ 216 | "Or, you want to run five queries every 10 seconds, but make sure they have at least 100ms breaks." 
217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 19, 222 | "id": "1f7d7dc3-cb56-4a61-8e18-4c129f2d46fb", 223 | "metadata": { 224 | "editable": true, 225 | "slideshow": { 226 | "slide_type": "" 227 | }, 228 | "tags": [] 229 | }, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "minibatches[###################################################################################]5/5[00:00<00:00, 5.08it/s]\n" 236 | ] 237 | }, 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "+---------+----------+\n", 242 | "| input | output |\n", 243 | "+=========+==========+\n", 244 | "| 1 | I |\n", 245 | "+---------+----------+\n", 246 | "| 2 | II |\n", 247 | "+---------+----------+\n", 248 | "| 3 | III |\n", 249 | "+---------+----------+\n", 250 | "| 4 | IV |\n", 251 | "+---------+----------+\n", 252 | "| 5 | V |\n", 253 | "+---------+----------+\n", 254 | "Constants: None" 255 | ] 256 | }, 257 | "execution_count": 19, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "limits = [AtMost(1, \"calls\", 0.1), AtMost(5, \"calls\", 10)]\n", 264 | "runner = OpenAIChat(model_id=\"gpt-3.5-turbo-16k\", api_config=API_CONFIG, rate_limit=limits)\n", 265 | "\n", 266 | "Output(GenerateText(Template(\"Output as a latin numeral: {{input}}\"))).run(\n", 267 | " runner, list(range(1,6))\n", 268 | ")" 269 | ] 270 | } 271 | ], 272 | "metadata": { 273 | "kernelspec": { 274 | "display_name": "Python 3 (ipykernel)", 275 | "language": "python", 276 | "name": "python3" 277 | }, 278 | "language_info": { 279 | "codemirror_mode": { 280 | "name": "ipython", 281 | "version": 3 282 | }, 283 | "file_extension": ".py", 284 | "mimetype": "text/x-python", 285 | "name": "python", 286 | "nbconvert_exporter": "python", 287 | "pygments_lexer": "ipython3", 288 | "version": "3.11.9" 289 | } 290 | }, 291 | "nbformat": 4, 292 | "nbformat_minor": 5 293 | } 294 | 
-------------------------------------------------------------------------------- /docs/tutorials/special_topics/4_structured_outputs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "id": "61643fe1-aa70-4f99-a3c4-a4d0577b70f3", 7 | "metadata": { 8 | "editable": true, 9 | "slideshow": { 10 | "slide_type": "" 11 | }, 12 | "tags": [ 13 | "remove-cell" 14 | ] 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "# Load from parent directory if not installed\n", 19 | "import importlib\n", 20 | "import os\n", 21 | "\n", 22 | "if not importlib.util.find_spec(\"sammo\"):\n", 23 | " import sys\n", 24 | " \n", 25 | " sys.path.insert(0, \"../../../\")\n", 26 | "os.environ[\"CACHE_FILE\"] = \"cache/structured_outputs.tsv\"" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 10, 32 | "id": "65550492-591f-44f3-85a1-9bee6db2316e", 33 | "metadata": { 34 | "editable": true, 35 | "slideshow": { 36 | "slide_type": "" 37 | }, 38 | "tags": [ 39 | "hide-input" 40 | ] 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "# %load -r 3:18 ../_init.py\n", 45 | "import pathlib\n", 46 | "import sammo\n", 47 | "from sammo.runners import OpenAIChat\n", 48 | "from sammo.base import Template, EvaluationScore\n", 49 | "from sammo.components import Output, GenerateText, ForEach, Union\n", 50 | "from sammo.extractors import ExtractRegex\n", 51 | "from sammo.data import DataTable\n", 52 | "import json\n", 53 | "import requests\n", 54 | "import os\n", 55 | "\n", 56 | "if not \"OPENAI_API_KEY\" in os.environ:\n", 57 | " raise ValueError(\"Please set the environment variable 'OPENAI_API_KEY'.\")\n", 58 | "\n", 59 | "_ = sammo.setup_logger(\"WARNING\") # we're only interested in warnings for now" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "51dcda8e-55b7-4b73-a81a-90397af50b8b", 65 | "metadata": {}, 66 | "source": [ 67 | "# Structured Outputs\n", 68 | 
"\n", 69 | "There are two ways in which models offer parseable JSON objects:\n", 70 | "\n", 71 | "1. By setting a flag that ensures that the output is *some* JSON object\n", 72 | "2. By specifying the exact JSON schema that the output needs to adhere to\n", 73 | "\n", 74 | "Option 2 is preferrable in general, the first option will likely disappear in future API versions.\n", 75 | "\n", 76 | "## Setting a flag\n", 77 | "\n", 78 | "For this, simply pass `json_mode = True` to {class}`~sammo.components.GenerateText`." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 8, 84 | "id": "8752663a-a4c2-4321-8655-75353dfe72d4", 85 | "metadata": { 86 | "editable": true, 87 | "slideshow": { 88 | "slide_type": "" 89 | }, 90 | "tags": [] 91 | }, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "+---------+---------------------------------------------------------+\n", 97 | "| input | output |\n", 98 | "+=========+=========================================================+\n", 99 | "| None | { \"names\": [ \"Emma Johnson\", \"Liam Smith\", |\n", 100 | "| | \"Olivia Brown\", \"Noah Davis\", \"Ava Wilson\", |\n", 101 | "| | \"Elijah Martinez\", \"Sophia Anderson\", \"Lucas |\n", 102 | "| | Taylor\", \"Isabella Thomas\", \"Mason Moore\" ] } |\n", 103 | "+---------+---------------------------------------------------------+\n", 104 | "Constants: None" 105 | ] 106 | }, 107 | "execution_count": 8, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "runner = OpenAIChat(\n", 114 | " model_id=\"gpt-4o\",\n", 115 | " api_config={\"api_key\": os.environ[\"OPENAI_API_KEY\"]},\n", 116 | " cache=os.getenv(\"CACHE_FILE\", \"cache.tsv\"),\n", 117 | " timeout=30,\n", 118 | ")\n", 119 | "Output(GenerateText(\"Generate a list of 10 full names in JSON format.\", json_mode=True)).run(runner)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "a64614da-c49e-44ae-84a2-5b824957fe13", 125 | "metadata": 
{}, 126 | "source": [ 127 | "What if we actually wanted first and last names as separate fields? We could provide the model with an example output, or:\n", 128 | "\n", 129 | "## Specifying a JSON schema\n", 130 | "\n", 131 | "Say we want something like " 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 4, 137 | "id": "16613f90-ce2b-4545-975d-6ac7133facae", 138 | "metadata": { 139 | "editable": true, 140 | "slideshow": { 141 | "slide_type": "" 142 | }, 143 | "tags": [] 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "example = {\"names\": [{\"first\": \"John\", \"last\": \"Smith\"}]}" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "f20316f8-8588-4c9b-aace-b063bf6fcd1f", 153 | "metadata": { 154 | "editable": true, 155 | "slideshow": { 156 | "slide_type": "" 157 | }, 158 | "tags": [] 159 | }, 160 | "source": [ 161 | "While you can manually write a schema, `SAMMO` provides you with a convenience function that works in many cases." 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "id": "1f7d7dc3-cb56-4a61-8e18-4c129f2d46fb", 168 | "metadata": { 169 | "editable": true, 170 | "slideshow": { 171 | "slide_type": "" 172 | }, 173 | "tags": [] 174 | }, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "{\n", 181 | " \"type\": \"object\",\n", 182 | " \"properties\": {\n", 183 | " \"names\": {\n", 184 | " \"type\": \"array\",\n", 185 | " \"items\": {\n", 186 | " \"type\": \"object\",\n", 187 | " \"properties\": {\n", 188 | " \"first\": {\n", 189 | " \"type\": \"string\"\n", 190 | " },\n", 191 | " \"last\": {\n", 192 | " \"type\": \"string\"\n", 193 | " }\n", 194 | " },\n", 195 | " \"required\": [\n", 196 | " \"first\",\n", 197 | " \"last\"\n", 198 | " ],\n", 199 | " \"additionalProperties\": false\n", 200 | " }\n", 201 | " }\n", 202 | " },\n", 203 | " \"required\": [\n", 204 | " \"names\"\n", 205 | " ],\n", 206 | " 
\"additionalProperties\": false\n", 207 | "}\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "schema = runner.guess_json_schema(example)\n", 213 | "print(schema)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "599f916e-cf1b-4edb-ba01-63c14c28a935", 219 | "metadata": {}, 220 | "source": [ 221 | "That would have been quite some work! Let's pass this to {class}`~sammo.components.GenerateText`." 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 7, 227 | "id": "03e559e4-f8c5-48c0-9b18-087c720ddf6b", 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/plain": [ 233 | "+---------+--------------------------------------------------------------+\n", 234 | "| input | output |\n", 235 | "+=========+==============================================================+\n", 236 | "| None | {\"names\":[{\"first\":\"Liam\",\"last\":\"Johnson\"},{\"first\":\"Emma\", |\n", 237 | "| | \"last\":\"Williams\"},{\"first\":\"Noah\",\"last\":\"Brown\"},{\"first\": |\n", 238 | "| | \"Olivia\",\"last\":\"Jones\"},{\"first\":\"Ava\",\"last\":\"Garcia\"},{\"f |\n", 239 | "| | irst\":\"Sophia\",\"last\":\"Martinez\"},{\"first\":\"Isabella\",\"last\" |\n", 240 | "| | :\"Davis\"},{\"first\":\"Mia\",\"last\":\"Rodriguez\"},{\"first\":\"Charl |\n", 241 | "| | otte\",\"last\":\"Hernandez\"},{\"first\":\"Amelia\",\"last\":\"Lopez\"}] |\n", 242 | "| | } |\n", 243 | "+---------+--------------------------------------------------------------+\n", 244 | "Constants: None" 245 | ] 246 | }, 247 | "execution_count": 7, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "Output(GenerateText(\"Generate a list of 10 full names in JSON format.\", json_mode=schema)).run(runner)" 254 | ] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3 (ipykernel)", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | 
"codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.11.9" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 5 278 | } 279 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: binder-environment 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - pandas 7 | - pip 8 | - pip: 9 | - sammo==0.2.1 -------------------------------------------------------------------------------- /examples/paper_instruction_tuning/instruction_tuning_dspy.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import pathlib 4 | import dspy 5 | import orjson 6 | import click 7 | 8 | from dspy.evaluate import Evaluate 9 | from dspy.teleprompt import COPRO 10 | 11 | CONFIG_PATH = pathlib.Path(__file__).parent.parent.parent / "config" 12 | MODEL_CONFIGS = { 13 | "gpt-3.5": { 14 | "config": {"model": "gpt-3.5-turbo-16k-0613"}, 15 | "credentials": CONFIG_PATH / "personal.openai", 16 | "class": "OpenAI", 17 | }, 18 | "gpt-4": { 19 | "config": {"model": "gpt-4-0613"}, 20 | "credentials": CONFIG_PATH / "personal.openai", 21 | "class": "OpenAI", 22 | }, 23 | "llama-2": { 24 | "config": { 25 | "model": "meta-llama/Llama-2-70b-chat-hf", 26 | "api_base": "https://api.deepinfra.com/v1/openai/", 27 | }, 28 | "credentials": CONFIG_PATH / "personal.deepinfra", 29 | "class": "DeepInfra", 30 | }, 31 | "mixtral": { 32 | "config": { 33 | "model": "cognitivecomputations/dolphin-2.6-mixtral-8x7b", 34 | "api_base": "https://api.deepinfra.com/v1/openai/", 35 | }, 36 | "credentials": CONFIG_PATH / "personal.deepinfra", 37 | "class": "DeepInfra", 38 | }, 39 | } 40 | MODELS = 
def normalize(x):
    """Canonicalize an answer string for comparison: lowercase, spaces removed."""
    lowered = x.lower()
    return "".join(lowered.split(" "))
def load_program(path):
    """Reconstruct a saved :class:`SimpleTaskPipeline` from *path*.

    :param path: filesystem path of a previously dumped program state.
    :return: the loaded pipeline instance.
    """
    loaded_program = SimpleTaskPipeline(None)
    loaded_program.load(path)
    # Bug fix: the original built and loaded the program but never returned
    # it, so every caller received None instead of the restored pipeline.
    return loaded_program
**eval_params)(optimized_program) 167 | print(y_test_score) 168 | y_train_score, y_train = Evaluate(devset=task["d_train"], **eval_params)(optimized_program) 169 | 170 | state = orjson.dumps( 171 | { 172 | "y_test_score": y_test_score / 100.0, 173 | "y_train_score": y_train_score / 100.0, 174 | "y_test_input": [v[0].toDict() for v in y_test], 175 | "y_test_output": [v[1].toDict() for v in y_test], 176 | "y_train_input": [v[0].toDict() for v in y_train], 177 | "y_train_output": [v[1].toDict() for v in y_train], 178 | "run_id": run_id, 179 | "model": optimized_program.dump_state(), 180 | }, 181 | option=orjson.OPT_INDENT_2, 182 | ) 183 | (RESULTS_DIR / f"{run_id}.dspy").write_bytes(state) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /examples/paper_instruction_tuning/instruction_tuning_sammo.py: -------------------------------------------------------------------------------- 1 | import click 2 | import sammo 3 | import orjson 4 | 5 | from sammo.base import EvaluationScore 6 | from sammo.mutators import ( 7 | BagOfMutators, 8 | APE, 9 | InduceInstructions, 10 | SyntaxTreeMutator, 11 | APO, 12 | Paraphrase, 13 | ) 14 | from sammo.runners import OpenAIChat 15 | from sammo.throttler import AtMost 16 | 17 | logger = sammo.setup_logger(log_prompts_to_file=True) 18 | 19 | from sammo import search_op 20 | from sammo.data import DataTable 21 | from sammo.instructions import MetaPrompt, Paragraph, InputData 22 | from sammo.components import Output 23 | from sammo.dataformatters import PlainFormatter 24 | from sammo.search import EnumerativeSearch, BeamSearch, Optimizer 25 | from sammo.store import PersistentDict 26 | 27 | import pathlib 28 | 29 | MAIN_FOLDER = sammo.utils.DEFAULT_SAVE_PATH 30 | CONFIG_PATH = MAIN_FOLDER.parent.parent.parent / "config" 31 | MODEL_CONFIGS = { 32 | "gpt-3.5": { 33 | "full_id": "gpt-3.5-turbo-16k-0613", 34 | "equivalence_class": "gpt-3.5-turbo-16k", 
35 | "credentials": CONFIG_PATH / "personal.openai", 36 | "rate_limit": 10, 37 | "timeout": 90, 38 | "max_context_window": None, 39 | }, 40 | "gpt-4": { 41 | "full_id": "gpt-4-0613", 42 | "equivalence_class": "gpt-4-0613", 43 | "credentials": CONFIG_PATH / "personal.openai", 44 | "rate_limit": 10, 45 | "timeout": 90, 46 | "max_context_window": None, 47 | }, 48 | "llama-2": { 49 | "full_id": "meta-llama/Llama-2-70b-chat-hf", 50 | "equivalence_class": "meta-llama/Llama-2-70b-chat-hf", 51 | "credentials": CONFIG_PATH / "personal.deepinfra", 52 | "rate_limit": [AtMost(10, "running"), AtMost(2, "rejected", 1)], 53 | "timeout": 180, 54 | "max_context_window": 4096, 55 | }, 56 | "mixtral": { 57 | "full_id": "cognitivecomputations/dolphin-2.6-mixtral-8x7b", 58 | "equivalence_class": "dolphin-2.6-mixtral-8x7b", 59 | "credentials": CONFIG_PATH / "personal.deepinfra", 60 | "rate_limit": [AtMost(10, "running"), AtMost(2, "rejected", 1)], 61 | "timeout": 180, 62 | "max_context_window": None, 63 | }, 64 | } 65 | MODELS = list(MODEL_CONFIGS.keys()) 66 | TASKS = [ 67 | "implicatures", 68 | "metaphor_boolean", 69 | "navigate", 70 | "presuppositions_as_nli", 71 | "sports_understanding", 72 | "vitaminc_fact_verification", 73 | "winowhy", 74 | "word_sorting", 75 | ] 76 | DATA = "data_splits.json" 77 | METHODS = ["sammo", "apo", "ape", "grips"] 78 | 79 | 80 | def accuracy(y_true: DataTable, y_pred: DataTable) -> EvaluationScore: 81 | def normalize(x): 82 | if isinstance(x, dict): 83 | print(x) 84 | return x.lower().replace(" ", "") 85 | 86 | mistakes = list() 87 | 88 | y_in = y_true.inputs.raw_values 89 | y_true, y_pred = y_true.outputs.normalized_values(), y_pred.outputs.normalized_values(on_empty="") 90 | 91 | for i in range(len(y_true)): 92 | is_mistake = normalize(y_true[i]) != normalize(y_pred[i]) 93 | is_mistake = is_mistake and normalize(y_in[i] + y_true[i]) != normalize(y_pred[i]) 94 | if is_mistake: 95 | mistakes.append(i) 96 | 97 | accuracy = 1 - len(mistakes) / len(y_true) 
98 | return EvaluationScore(accuracy, mistakes) 99 | 100 | 101 | class InstructionTuningSearchSpace: 102 | def __init__(self, dtrain): 103 | self.dtrain = dtrain 104 | 105 | def __call__(self): 106 | example_formatter = PlainFormatter(all_labels=self.dtrain.outputs.unique(), orient="item") 107 | 108 | labels = self.dtrain.outputs.unique() 109 | instructions = MetaPrompt( 110 | [ 111 | Paragraph("Instructions: "), 112 | Paragraph( 113 | search_op.one_of( 114 | [ 115 | self.dtrain.constants["instructions"], 116 | "", 117 | "Find the best output label given the input.", 118 | self.dtrain.constants["instructions"] * 2, 119 | ] 120 | ), 121 | reference_id="instructions", 122 | ), 123 | Paragraph("\n"), 124 | Paragraph(f"Output labels: {', '.join(labels)}\n" if len(labels) <= 10 else ""), 125 | Paragraph(InputData()), 126 | Paragraph("Output: "), 127 | ], 128 | render_as="raw", 129 | data_formatter=example_formatter, 130 | ) 131 | 132 | return Output( 133 | instructions.with_extractor("raise"), 134 | minibatch_size=1, 135 | on_error="empty_result", 136 | ) 137 | 138 | 139 | @click.command() 140 | @click.option("--llm", default=MODELS[0], type=click.Choice(MODELS), prompt=True) 141 | @click.option("--task-id", default=TASKS[0], type=click.Choice(TASKS), prompt=True) 142 | @click.option("--method", default=METHODS[0], type=click.Choice(METHODS), prompt=True) 143 | @click.option("--uuid", default=None, type=str) 144 | @click.option("--confirmed", is_flag=True, default=None) 145 | def main(llm, task_id, method, uuid=None, confirmed=None, debug=False): 146 | if confirmed is None: 147 | click.confirm(f"Do you want to run {task_id} with {llm}?", abort=True, default=True) 148 | model_config = MODEL_CONFIGS[llm] 149 | run_id = f"{task_id}_{model_config['equivalence_class'].replace('/', '_')}" 150 | runner = OpenAIChat( 151 | model_id=model_config["full_id"], 152 | api_config=model_config["credentials"], 153 | equivalence_class=model_config["equivalence_class"], 154 | 
rate_limit=model_config["rate_limit"], 155 | cache=sammo.store.PersistentDict(MAIN_FOLDER / f"{run_id}.cache.tsv"), 156 | timeout=model_config["timeout"], 157 | max_retries=50000, 158 | max_context_window=model_config["max_context_window"], 159 | ) 160 | all_tasks = {x["task_id"]: x for x in orjson.loads(pathlib.Path(DATA).read_bytes())} 161 | task = all_tasks[task_id] 162 | 163 | data = dict() 164 | for k, v in task.items(): 165 | if k.startswith("d_"): 166 | data[k] = DataTable.from_records(v, constants=dict(instructions=task["instructions"])) 167 | 168 | search_space = InstructionTuningSearchSpace(data["d_train"]) 169 | baseline_performance = EnumerativeSearch(runner, search_space, accuracy, max_candidates=1) 170 | baseline_performance.fit_transform(data["d_train"]) 171 | baseline_performance.transform(data["d_test"]) 172 | baseline_performance.show_report() 173 | baseline_performance.save_json(MAIN_FOLDER / "baseline" / f"{run_id}.model.json") 174 | 175 | if method == "ape": 176 | prompt_optimizer = BeamSearch( 177 | runner, 178 | APE("#instructions", search_space, data["d_train"], 5), 179 | accuracy, 180 | maximize=True, 181 | n_initial_candidates=12, 182 | depth=3, 183 | mutations_per_beam=2, 184 | beam_width=4, 185 | add_previous=True, 186 | ) 187 | elif method == "apo": 188 | prompt_optimizer = BeamSearch( 189 | runner, 190 | APO( 191 | "#instructions content", 192 | search_space, 193 | num_gradients=2, 194 | steps_per_gradient=1, 195 | num_rewrites=1, 196 | ), 197 | accuracy, 198 | maximize=True, 199 | depth=7, 200 | mutations_per_beam=2, 201 | beam_width=4, 202 | add_previous=True, 203 | ) 204 | elif method == "grips": 205 | mutation_operators = SyntaxTreeMutator( 206 | "#instructions", 207 | search_space, 208 | PersistentDict(MAIN_FOLDER / "trees" / f"{run_id}.cache.json"), 209 | ) 210 | prompt_optimizer = BeamSearch( 211 | runner, 212 | mutation_operators, 213 | accuracy, 214 | maximize=True, 215 | depth=7, 216 | mutations_per_beam=2, 217 | 
n_initial_candidates=1, 218 | beam_width=4, 219 | add_previous=True, 220 | ) 221 | elif method == "sammo": 222 | mutation_operators = BagOfMutators( 223 | search_space, 224 | InduceInstructions("#instructions", data["d_incontext"]), 225 | APO( 226 | "#instructions content", 227 | None, 228 | num_gradients=2, 229 | steps_per_gradient=1, 230 | num_rewrites=0, 231 | ), 232 | Paraphrase("#instructions"), 233 | sample_for_init_candidates=True, 234 | ) 235 | prompt_optimizer = BeamSearch( 236 | runner, 237 | mutation_operators, 238 | accuracy, 239 | maximize=True, 240 | depth=4, 241 | mutations_per_beam=2, 242 | n_initial_candidates=4, 243 | beam_width=4, 244 | add_previous=True, 245 | ) 246 | prompt_optimizer.fit(data["d_train"]) 247 | prompt_optimizer.show_report() 248 | 249 | if not debug: 250 | dtest_pred = prompt_optimizer.transform(data["d_test"]) 251 | print(f"Test score: {accuracy(data['d_test'], dtest_pred)}") 252 | prompt_optimizer.save_json(MAIN_FOLDER / method / f"{run_id}.model.json") 253 | 254 | 255 | if __name__ == "__main__": 256 | main() 257 | -------------------------------------------------------------------------------- /examples/paper_rag/rag_tuning_dspy.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import pathlib 4 | 5 | import click 6 | from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction 7 | import chromadb 8 | import pandas as pd 9 | import dspy 10 | from dspy.retrieve.chromadb_rm import ChromadbRM 11 | from dspy.evaluate import Evaluate 12 | from dspy.teleprompt import MIPRO 13 | 14 | BASE_DIR = pathlib.Path(__file__).parent 15 | RESULTS_DIR = BASE_DIR / "dspy" 16 | RESULTS_DIR.mkdir(exist_ok=True) 17 | DB_PATH = str(BASE_DIR / "chroma") 18 | DATA = BASE_DIR / "data_splits.json" 19 | CONFIG_PATH = pathlib.Path(__file__).parent.parent.parent / "config" 20 | MODEL_CONFIGS = { 21 | "gpt-3.5": { 22 | "config": {"model": "gpt-3.5-turbo-16k-0613"}, 23 | 
"credentials": CONFIG_PATH / "personal.openai", 24 | "class": "OpenAI", 25 | }, 26 | "gpt-4": { 27 | "config": {"model": "gpt-4-0613"}, 28 | "credentials": CONFIG_PATH / "personal.openai", 29 | "class": "OpenAI", 30 | }, 31 | "llama-2": { 32 | "config": { 33 | "model": "meta-llama/Llama-2-70b-chat-hf", 34 | "api_base": "https://api.deepinfra.com/v1/openai/", 35 | }, 36 | "credentials": CONFIG_PATH / "personal.deepinfra", 37 | "class": "DeepInfra", 38 | }, 39 | "llama-2-alt": { 40 | "config": {"model": "meta-llama/Llama-2-70b-chat-hf", "api_base": "https://api.together.xyz/v1/"}, 41 | "credentials": CONFIG_PATH / "personal.together", 42 | "class": "Together", 43 | }, 44 | "mixtral": { 45 | "config": { 46 | "model": "cognitivecomputations/dolphin-2.6-mixtral-8x7b", 47 | "api_base": "https://api.deepinfra.com/v1/openai/", 48 | }, 49 | "credentials": CONFIG_PATH / "personal.deepinfra", 50 | "class": "DeepInfra", 51 | }, 52 | } 53 | MODELS = list(MODEL_CONFIGS.keys()) 54 | TASKS = ["smcalflow", "geo880", "overnight"] 55 | 56 | 57 | class GenerateAnswer(dspy.Signature): 58 | context = dspy.InputField(desc="may contain relevant facts") 59 | input = dspy.InputField() 60 | answer = dspy.OutputField() 61 | 62 | 63 | class RAG(dspy.Module): 64 | def __init__( 65 | self, 66 | n_fewshot=10, 67 | instructions="Answer questions with short factoid answers.", 68 | ): 69 | super().__init__() 70 | my_module = copy.copy(GenerateAnswer) 71 | my_module.__doc__ = instructions 72 | 73 | self.retrieve = dspy.Retrieve(k=n_fewshot) 74 | self.generate_answer = dspy.Predict(my_module) 75 | 76 | def forward(self, input): 77 | context = self.retrieve(input).passages 78 | prediction = self.generate_answer(context=context, input=input) 79 | return dspy.Prediction(context=context, answer=prediction.answer) 80 | 81 | 82 | class DeepInfra(dspy.OpenAI): 83 | MAX_BATCH_SIZE = 1 84 | 85 | def __call__( 86 | self, 87 | prompt: str, 88 | **kwargs, 89 | ): 90 | n = kwargs.get("n", 1) 91 | if n > 
def init_retriever(coll_name, docs, overwrite=False):
    """Create (or recreate) a chroma collection holding formatted few-shot documents.

    Existing collections are left untouched unless *overwrite* is set, in which
    case the collection is deleted and rebuilt from *docs*.
    """
    client = chromadb.PersistentClient(path=DB_PATH)
    existing_names = {c.name for c in client.list_collections()}
    if coll_name in existing_names:
        if not overwrite:
            return
        client.delete_collection(coll_name)

    # NOTE(review): EMBEDDING_FUNC is a module-level global initialized in the
    # __main__ guard — this function assumes it has been set up already.
    collection = client.create_collection(name=coll_name, embedding_function=EMBEDDING_FUNC)
    collection.add(
        documents=[f"Input: {d['input']}\nAnswer:{d['output']}" for d in docs],
        ids=[str(i) for i in range(len(docs))],
    )
json.loads(pathlib.Path(DATA).read_bytes())[task_id] 154 | task["task_id"] = task_id 155 | 156 | model_config = MODEL_CONFIGS[llm] 157 | config = json.loads(model_config["credentials"].read_text()) 158 | llm_class = {"OpenAI": dspy.OpenAI, "DeepInfra": DeepInfra, "Together": TogetherPatched}[model_config["class"]] 159 | runner = llm_class(api_key=config["api_key"], **model_config["config"]) 160 | num_threads = 1 if debug else num_threads 161 | run_id = f"{llm}_{task['task_id']}" 162 | 163 | init_retriever(task["task_id"], task["incontext"]["records"]) 164 | retriever_model = ChromadbRM( 165 | task["task_id"], 166 | DB_PATH, 167 | embedding_function=EMBEDDING_FUNC, 168 | k=n_fewshot, 169 | ) 170 | dspy.settings.configure(lm=runner, rm=retriever_model) 171 | 172 | # Tell DSPy that the 'input' field is the input. Any other fields are labels and/or metadata. 173 | trainset = [ 174 | dspy.Example(input=x["input"], answer=x["output"]).with_inputs("input") for x in task["train"]["records"] 175 | ] 176 | testset = [dspy.Example(input=x["input"], answer=x["output"]).with_inputs("input") for x in task["test"]["records"]] 177 | if debug: 178 | trainset = trainset[:5] 179 | testset = testset[:5] 180 | dspy_program = RAG(n_fewshot=n_fewshot, instructions=task["train"]["constants"]["full_dd"]) 181 | if show_example or debug: 182 | dspy_program(input=trainset[0].input) 183 | runner.inspect_history(n=1) 184 | 185 | # Set up a basic teleprompter, which will compile our RAG program. 186 | teleprompter = MIPRO( 187 | metric=accuracy, 188 | num_candidates=2 if debug else 10, 189 | track_stats=True, 190 | ) 191 | 192 | # Compile! 
193 | optimized_program = teleprompter.compile( 194 | dspy_program, 195 | trainset=trainset, 196 | num_trials=1 if debug else 24, 197 | max_bootstrapped_demos=1 if debug else 5, 198 | view_examples=False, 199 | max_labeled_demos=1 if debug else 5, 200 | eval_kwargs=dict(num_threads=num_threads, display_progress=True, display_table=0), 201 | requires_permission_to_run=False, 202 | ) 203 | runner.inspect_history(n=1) 204 | 205 | for name, parameter in optimized_program.named_predictors(): 206 | print(name) 207 | print(parameter) 208 | 209 | eval_params = dict( 210 | metric=accuracy, 211 | num_threads=num_threads, 212 | display_progress=True, 213 | display_table=0, 214 | return_outputs=True, 215 | ) 216 | 217 | y_test_score, y_test = Evaluate(devset=testset, **eval_params)(optimized_program) 218 | print(y_test_score) 219 | runner.inspect_history(n=1) 220 | y_train_score, y_train = Evaluate(devset=trainset, **eval_params)(optimized_program) 221 | try: 222 | logs = str(optimized_program.trial_logs) 223 | llm = json.dumps(optimized_program.dump_state()) 224 | except Exception as e: 225 | print("Failed to dump model state.", e) 226 | logs = str(optimized_program.trial_logs) 227 | llm = str(optimized_program.dump_state()) 228 | 229 | state = json.dumps( 230 | { 231 | "y_test_score": y_test_score / 100.0, 232 | "y_train_score": y_train_score / 100.0, 233 | "y_test_input": [v[0].toDict() for v in y_test], 234 | "y_test_output": [v[1].toDict() for v in y_test], 235 | "y_train_input": [v[0].toDict() for v in y_train], 236 | "y_train_output": [v[1].toDict() for v in y_train], 237 | "run_id": run_id, 238 | "logs": logs, 239 | "model": llm, 240 | }, 241 | indent=4, 242 | ) 243 | (RESULTS_DIR / f"{run_id}.json").write_text(state) 244 | 245 | 246 | if __name__ == "__main__": 247 | EMBEDDING_FUNC = OpenAIEmbeddingFunction( 248 | api_key=json.loads(MODEL_CONFIGS["gpt-3.5"]["credentials"].read_bytes())["api_key"], 249 | model_name="text-embedding-3-small", 250 | ) 251 | main() 252 | 
-------------------------------------------------------------------------------- /examples/paper_rag/rag_tuning_sammo.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from sammo.base import EvaluationScore 4 | from sammo.mutators import * 5 | from sammo.runners import OpenAIEmbedding, OpenAIChat 6 | from sammo.throttler import AtMost 7 | 8 | logger = sammo.setup_logger(log_prompts_to_file=True) 9 | import pandas as pd 10 | import click 11 | import sammo.store 12 | from sammo import search_op 13 | from sammo.instructions import * 14 | from sammo.components import * 15 | from sammo.dataformatters import JSONDataFormatter, QuestionAnswerFormatter, XMLDataFormatter 16 | import json 17 | from sammo.search import ( 18 | EnumerativeSearch, 19 | ) 20 | 21 | MAIN_FOLDER = sammo.utils.DEFAULT_SAVE_PATH 22 | CONFIG_PATH = MAIN_FOLDER.parent.parent.parent / "config" 23 | MODEL_CONFIGS = { 24 | "gpt-3.5": { 25 | "full_id": "gpt-3.5-turbo-16k-0613", 26 | "equivalence_class": "gpt-3.5-turbo-16k", 27 | "credentials": CONFIG_PATH / "personal.openai", 28 | "rate_limit": 10, 29 | "timeout": 90, 30 | }, 31 | "gpt-4": { 32 | "full_id": "gpt-4-0613", 33 | "equivalence_class": "gpt-4-0613", 34 | "credentials": CONFIG_PATH / "personal.openai", 35 | "rate_limit": 10, 36 | "timeout": 90, 37 | }, 38 | "llama-2": { 39 | "full_id": "meta-llama/Llama-2-70b-chat-hf", 40 | "equivalence_class": "Llama-2-70b-chat-hf", 41 | "credentials": CONFIG_PATH / "personal.deepinfra", 42 | "rate_limit": [AtMost(10, "running"), AtMost(2, "rejected", 1)], 43 | "timeout": 180, 44 | }, 45 | "mixtral": { 46 | "full_id": "cognitivecomputations/dolphin-2.6-mixtral-8x7b", 47 | "equivalence_class": "dolphin-2.6-mixtral-8x7b", 48 | "credentials": CONFIG_PATH / "personal.deepinfra", 49 | "rate_limit": [AtMost(10, "running"), AtMost(2, "rejected", 1)], 50 | "timeout": 180, 51 | }, 52 | } 53 | MODELS = list(MODEL_CONFIGS.keys()) 54 | DATA = MAIN_FOLDER.parent / 
"data_splits_new.json" 55 | TASKS = ["smcalflow", "geo880", "overnight"] 56 | 57 | 58 | def accuracy(y_true: DataTable, y_pred: DataTable) -> EvaluationScore: 59 | y_true, y_pred = y_true.outputs.normalized_values(on_empty=""), y_pred.outputs.normalized_values(on_empty="") 60 | mistakes = list() 61 | for i in range(len(y_true)): 62 | if y_true[i].lower() != str(y_pred[i]).lower(): 63 | mistakes.append(i) 64 | 65 | return EvaluationScore(1 - len(mistakes) / len(y_true), mistakes) 66 | 67 | 68 | class RagSearchSpace: 69 | def __init__(self, dtrain, examples, embedding_runner): 70 | self.examples = examples 71 | self.dtrain = dtrain 72 | self._embedding_runner = embedding_runner 73 | 74 | def __call__(self, return_raw=False): 75 | orientation = search_op.one_of(["item", "kind"], reference_id="orientation") 76 | example_formatter = search_op.one_of( 77 | [ 78 | QuestionAnswerFormatter( 79 | all_labels=self.dtrain.outputs.unique(), orient=orientation, attributes_processor=None 80 | ), 81 | XMLDataFormatter(orient=orientation, attributes_processor=None), 82 | JSONDataFormatter(orient=orientation, attributes_processor=None), 83 | ] 84 | ) 85 | 86 | instr = search_op.one_of(["full_dd", "list_of_operators"], reference_id="instructions") 87 | structure = [ 88 | Section("Syntax", f"{self.dtrain.constants[instr]}"), 89 | Section( 90 | "Examples", 91 | EmbeddingFewshotExamples( 92 | self._embedding_runner, 93 | self.examples, 94 | search_op.one_of([10, 5], reference_id="n_examples"), 95 | budget="relative", 96 | ), 97 | ), 98 | Section( 99 | "Complete and output in the same format as above", 100 | InputData(id_offset=len(self.examples)), 101 | ), 102 | ] 103 | instructions = MetaPrompt(structure, render_as="markdown", data_formatter=example_formatter) 104 | return Output(instructions.with_extractor("empty_result"), minibatch_size=1, on_error="empty_result") 105 | 106 | 107 | def load_task(task_id, data_path=DATA): 108 | task_info = 
json.loads(pathlib.Path(data_path).read_bytes())[task_id] 109 | return {k: DataTable.from_json(v) for k, v in task_info.items()} 110 | 111 | 112 | @click.command() 113 | @click.option("--llm", default=MODELS[0], type=click.Choice(MODELS), prompt=True) 114 | @click.option("--task-id", default=TASKS[0], type=click.Choice(TASKS), prompt=True) 115 | @click.option("--uuid", default=None, type=str) 116 | @click.option("--confirmed", is_flag=True, default=None) 117 | def main(llm, task_id, uuid, confirmed): 118 | if confirmed is None: 119 | click.confirm(f"Do you want to run {task_id} with {llm}?", abort=True, default=True) 120 | loaded_data = load_task(task_id) 121 | d_incontext, d_test = loaded_data["incontext"], loaded_data["test"] 122 | d_train = loaded_data["train"] 123 | 124 | print("Duplicates in train:", len(d_incontext) - len(d_incontext.inputs.unique())) 125 | print(f"Dataset sizes: {len(d_train)} (train), {len(d_incontext)} (incontext), {len(d_test)} (test)") 126 | 127 | model_config = MODEL_CONFIGS[llm] 128 | run_id = f"{task_id}_{model_config['equivalence_class'].replace('/', '_')}" 129 | runner = OpenAIChat( 130 | model_id=model_config["full_id"], 131 | api_config=model_config["credentials"], 132 | equivalence_class=model_config["equivalence_class"], 133 | rate_limit=model_config["rate_limit"], 134 | cache=sammo.store.PersistentDict(MAIN_FOLDER / f"{run_id}.cache.tsv"), 135 | timeout=model_config["timeout"], 136 | max_retries=50000, 137 | ) 138 | 139 | embedder = OpenAIEmbedding( 140 | model_id="text-embedding-3-small", 141 | api_config=CONFIG_PATH / "personal.openai", 142 | rate_limit=10, 143 | cache=sammo.store.SqlLiteDict(MAIN_FOLDER / f"{task_id}" / "fewshotcache"), 144 | ) 145 | search_space = RagSearchSpace(d_train, d_incontext, embedding_runner=embedder) 146 | 147 | # Baseline 148 | baseline_model = EnumerativeSearch(runner, search_space, accuracy, maximize=True, max_candidates=1) 149 | baseline_model.fit_transform(d_train) 150 | dtest_baseline = 
baseline_model.transform(d_test) 151 | baseline_model.save_json(MAIN_FOLDER / "baseline" / f"{run_id}.model.json") 152 | 153 | # SAMMO 154 | sammo_model = EnumerativeSearch(runner, search_space, accuracy, maximize=True) 155 | sammo_model.fit(d_train) 156 | sammo_model.show_report() 157 | dtest_sammo = sammo_model.transform(d_test) 158 | sammo_model.save_json(MAIN_FOLDER / "sammo" / f"{run_id}.model.json") 159 | 160 | print(f"Baseline (test):\n {accuracy(d_test, dtest_baseline)}") 161 | print(f"SAMMO (test):\n {accuracy(d_test, dtest_sammo)}") 162 | 163 | 164 | if __name__ == "__main__": 165 | main() 166 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sammo" 3 | version = "0.3.2" 4 | description = "A flexible, easy-to-use library for running and optimizing prompts for Large Language Models (LLMs)." 5 | authors = ["Tobias Schnabel"] 6 | license = "MIT" 7 | readme = "README.md" 8 | repository = "https://github.com/microsoft/sammo/" 9 | documentation = "https://microsoft.github.io/sammo/docs/" 10 | packages = [ 11 | { include = "sammo" } 12 | ] 13 | 14 | [tool.poetry.dependencies] 15 | python = "^3.9,<3.13" 16 | beartype = "^0.15" 17 | benepar = {version = "^0.2", optional = true} 18 | filelock = "^3.12" 19 | frozendict = "^2.3" 20 | jsonpath_ng = "^1.5" 21 | markdown-it-py = "^2.2" 22 | more-itertools = "^10.1" 23 | numpy = "^1.25" 24 | orjson = "^3.9" 25 | pybars3 = "^0.9" 26 | pyglove = "^0.4" 27 | spacy = "^3.6" 28 | tabulate = "^0.9" 29 | xmltodict = "^0.13" 30 | PyYAML = "^6.0" 31 | aiohttp = "^3.6" 32 | diskcache = "^5.2" 33 | dill = "^0.3" 34 | quattro = "^24" 35 | async-timeout = "^4.0.3" 36 | lxml = "^5.3" 37 | cssselect = "^1.2" 38 | mistletoe = "^1.4" 39 | 40 | [tool.poetry.extras] 41 | parser = ["benepar"] 42 | 43 | [tool.poetry.group.dev] 44 | optional = true 45 | 46 | 
[tool.poetry.group.dev.dependencies] 47 | pytest = "^7.4" 48 | pytest-skip-slow = "*" 49 | pytest-mock = "*" 50 | pytest-asyncio = "*" 51 | black = "^23" 52 | pandas = "*" 53 | pre-commit = "^3.6.0" 54 | mypy = "^1.8.0" 55 | poethepoet = "^0.24.4" 56 | jupyter-book = "^0.15" 57 | astroid = "^3.0.2" 58 | sphinx-autodoc2 = "*" 59 | poetry-core = "^1.8.1" # (see set_version.py) 60 | 61 | [tool.black] 62 | line-length = 120 63 | 64 | [tool.poe.tasks.build-docs] 65 | help = "Build the documentation site" 66 | cmd = "jb build --path-output _build_docs docs" 67 | 68 | [tool.poe.tasks.serve-docs] 69 | help = "Preview the documentation site using python's built-in http server" 70 | cmd = "python -m http.server -d _build_docs/_build/html/" 71 | 72 | [tool.poe.tasks.type-check] 73 | help = "Run static type checking" 74 | cmd = "mypy sammo" 75 | 76 | [tool.poe.tasks.test] 77 | help = "Run tests" 78 | cmd = "pytest" 79 | 80 | [tool.poe.tasks.pre-commit] 81 | help = "Run all pre-commit checks" 82 | cmd = "pre-commit run --all --show-diff-on-failure --color=always" 83 | 84 | [build-system] 85 | requires = ["poetry-core"] 86 | build-backend = "poetry.core.masonry.api" 87 | 88 | [tool.pytest.ini_options] 89 | asyncio_mode = "auto" 90 | -------------------------------------------------------------------------------- /sammo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
@beartype.beartype
def setup_logger(
    default_level: Union[int, str] = "DEBUG",
    log_prompts_to_file: bool = False,
    prompt_level: Union[int, str] = "DEBUG",
    prompt_logfile_name: Union[str, None] = None,
) -> logging.Logger:
    """Configure and return SAMMO's console logger.

    :param default_level: Log level for the returned console logger.
    :param log_prompts_to_file: If True, additionally set up a dedicated file
        logger (named PROMPT_LOGGER_NAME) that records prompt requests.
    :param prompt_level: Log level for the prompt file logger.
    :param prompt_logfile_name: Path of the prompt log file; when None, defaults
        to <MAIN_PATH>/logs/<MAIN_NAME>.log.
    :return: The configured console logger.

    Fixes: the previous annotation ``prompt_logfile_name: str = None`` made an
    explicit ``None`` argument a beartype violation; the prompt logger also
    accumulated duplicate file handlers across repeated calls, unlike the
    console logger, which resets its handlers.
    """
    if log_prompts_to_file:
        if prompt_logfile_name is None:
            prompt_logfile_name = (utils.MAIN_PATH / "logs" / utils.MAIN_NAME).with_suffix(".log")
        log_prompts_to_file = Path(prompt_logfile_name)
        log_prompts_to_file.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_prompts_to_file, mode="w", delay=0, encoding="utf-8")
        file_handler.setFormatter(logging.Formatter("===%(asctime)s===\n%(message)s"))

        # Logger dedicated to prompt requests; reset handlers so repeated
        # setup_logger() calls do not stack duplicate file handlers.
        prompt_logger = logging.getLogger(PROMPT_LOGGER_NAME)
        prompt_logger.setLevel(prompt_level)
        prompt_logger.handlers = list()
        prompt_logger.addHandler(file_handler)

    logger = logging.getLogger(__name__)
    logger.setLevel(default_level)
    logger.handlers = list()
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter("%(asctime)s,%(msecs)d: %(message)s", datefmt="%H:%M:%S"))
    logger.addHandler(console)

    return logger
def test_simple_query():
    """find_first resolves children by reference id and by keyword name."""
    hello = VerbatimText("Hello", reference_id="1")
    world = VerbatimText("World", reference_id="2")
    nested = Template("{{a}} {{b}}", a=hello, b=world)

    # Both "#<reference_id>" and plain keyword-name queries resolve to the child node.
    for query, expected in [("#1", hello), ("#2", world), ("a", hello), ("b", world)]:
        assert nested.find_first(query).node == expected

    # Unknown ids produce no match; type queries return every matching node.
    assert nested.find_first("#3") is None
    assert len(nested.find_all("verbatimtext")) == 2
29 | """ 30 | 31 | BACKSPACE = "\b" * 1000 32 | 33 | def __init__(self, out: io.TextIOBase = sys.stdout): 34 | self._out = out 35 | self._clear_prefix = "" 36 | self._is_interactive = utils.is_interactive() 37 | self._is_finalized = False 38 | 39 | @staticmethod 40 | def get_terminal_width(default_width: int = 120) -> int: 41 | width, _ = shutil.get_terminal_size(fallback=(default_width, 80)) 42 | return width 43 | 44 | def print(self, value: str): 45 | print(self._clear_prefix + value, file=self._out, flush=True, end="") 46 | if self._is_interactive: 47 | self._clear_prefix = "\r" 48 | else: 49 | self._clear_prefix = self.BACKSPACE[: len(value)] 50 | 51 | def finalize(self): 52 | if not self._is_finalized: 53 | print(file=self._out, flush=True) 54 | self._is_finalized = True 55 | 56 | 57 | @beartype 58 | class SubProgressBar: 59 | """ 60 | A class that represents an individual progress bar. 61 | 62 | :param total: The total number of items to process. 63 | :param parent: The parent progress bar. 64 | :param moving_avg_size: The size of the moving average window for calculating the rate. 65 | :param width: The width of the progress bar in characters. 66 | :param prefix: The prefix to display before the progress bar. 67 | :param show_rate: Whether to show the rate of progress. 68 | :param show_time: Whether to show the elapsed time and ETA. 69 | :param ascii: Whether to use ASCII (or UTF-8) characters for the progress bar. If "auto", uses ASCII if pdb is imported. 
70 | """ 71 | 72 | phases = {True: (" ", "_", "*", "#"), False: (" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█")} 73 | 74 | def __init__( 75 | self, 76 | total: int, 77 | parent: "CompactProgressBars", 78 | moving_avg_size: int = 10, 79 | width: int = 100, 80 | prefix: str = "", 81 | show_rate: bool = True, 82 | show_time: bool = True, 83 | ascii: str = "auto", 84 | ): 85 | self._start = time.monotonic() 86 | self._now = time.monotonic() 87 | self._last_updates = collections.deque(maxlen=moving_avg_size) 88 | self._last_updates.appendleft(self._now) 89 | if ascii == "auto": 90 | ascii = "pdb" in sys.modules # Use ascii if debugging 91 | self.phases = SubProgressBar.phases[ascii] 92 | self.total = total 93 | self._n_done = 0 94 | self._prefix = prefix 95 | self._parent = parent 96 | self._show_time = show_time 97 | self._width = width 98 | self.max_width = width 99 | self._show_rate = show_rate 100 | 101 | @property 102 | def total(self): 103 | return self._total 104 | 105 | @total.setter 106 | def total(self, value): 107 | if value <= 0: 108 | raise ValueError("total must be positive") 109 | self._total = value 110 | 111 | @property 112 | def done(self): 113 | return self._n_done == self._total 114 | 115 | @property 116 | def elapsed_long(self): 117 | return datetime.timedelta(seconds=int(self._now - self._start)) 118 | 119 | @property 120 | def elapsed(self): 121 | return self._shorten(self.elapsed_long) 122 | 123 | @classmethod 124 | def _shorten(cls, val: datetime.timedelta): 125 | """ 126 | Shortens a timedelta object representation by removing leading hours. 
127 | """ 128 | if val.total_seconds() < 3600: 129 | return str(val)[2:] 130 | else: 131 | return val 132 | 133 | @property 134 | def phase(self): 135 | remainder = (self._width * self._n_done / self._total) % 1 136 | phase_index = int(round((len(self.phases) - 1) * remainder)) 137 | if remainder == 0: 138 | return "" 139 | else: 140 | return self.phases[phase_index] 141 | 142 | @property 143 | def barwidth(self): 144 | return int(self._width * self._n_done / self._total) 145 | 146 | @property 147 | def rate(self): 148 | if self._now == self._start: 149 | return 0 150 | elif self._last_updates[0] == self._last_updates[-1]: 151 | return 1 / (self._now - self._start) 152 | else: 153 | return (len(self._last_updates) - 1) / (self._last_updates[0] - self._last_updates[-1]) 154 | 155 | @property 156 | def eta(self): 157 | if self.rate > 0: 158 | return self._shorten(datetime.timedelta(seconds=math.ceil((self._total - self._n_done) / self.rate))) 159 | else: 160 | return "??:??" 161 | 162 | def update(self, *args, **kwargs): 163 | """ 164 | Increases the number of completed tasks by one for the progress bar. 165 | """ 166 | self._n_done += 1 167 | self._now = time.monotonic() 168 | self._last_updates.appendleft(self._now) 169 | self._parent._refresh_display(force=self._n_done == self._total) 170 | 171 | def __str__(self): 172 | rate = "" 173 | time = "" 174 | if self._show_time: 175 | if self._show_rate: 176 | rate = f", {self.rate:.2f}it/s" 177 | time = f"[{self.elapsed}<{self.eta}{rate}]" 178 | 179 | template = f"{self._prefix}{{x}}{self._n_done}/{self._total}{time}" 180 | self._width = max(5, self.max_width - (len(template) - 3)) 181 | return template.format( 182 | x=f"[{self.phases[-1] * self.barwidth}" 183 | f"{self.phase}{(self.phases[0] * (self._width - self.barwidth - len(self.phase)))}]" 184 | ) 185 | 186 | 187 | @beartype 188 | class CompactProgressBars: 189 | """ 190 | A class that represents a set of progress bars drawn next to each in a single line. 
@beartype
class CompactProgressBars:
    """
    A class that represents a set of progress bars drawn next to each in a single line.

    :param width: The total width of the progress bar layout in characters.
    :param refresh_interval: The minimum time interval between display refreshes.
    """

    def __init__(self, width: Union[int, None] = None, refresh_interval: float = 1 / 50):
        self._bars = collections.OrderedDict()
        self._printer = LinePrinter()
        self._last_update = 0
        self._refresh_interval = refresh_interval

        if width is None:
            self._width = self._printer.get_terminal_width()
        else:
            self._width = width

    def _refresh_display(self, force: bool = False):
        """Redraws the whole line, rate-limited unless ``force`` is set."""
        if self._should_refresh() or force:
            self._printer.print(str(self))
            # Once the first (outermost) bar completes, finish the line.
            if self._bars and next(iter(self._bars.values())).done:
                self.finalize()

    def _should_refresh(self) -> bool:
        # Throttle redraws to at most one per refresh_interval seconds.
        if time.monotonic() - self._last_update > self._refresh_interval:
            self._last_update = time.monotonic()
            return True
        return False

    def get(
        self,
        id: str,
        total: Union[int, None] = None,
        position: Union[int, None] = None,
        display_name: Union[str, None] = None,
        **kwargs,
    ) -> SubProgressBar:
        """
        Gets existing or creates a new progress bar given an id.

        :param id: The id of the progress bar for later reference.
        :param total: Number of increments.
        :param position: Truncate existing bars beyond index and insert this one at the position.
        :param display_name: The name to display for the progress bar. Defaults to id.
        :param **kwargs: Additional arguments to pass to the SubProgressbar constructor.
        :return: New bar if it doesn't exist, otherwise a reference to the existing one.
        """
        if id in self._bars:
            existing_bar = self._bars[id]
            # Fix: go through the validating `total` property and only update it when
            # a value was actually supplied. Previously `_total` was overwritten
            # unconditionally, which set it to None whenever `total` was omitted and
            # silently bypassed the positivity check.
            if total is not None:
                existing_bar.total = total
            return existing_bar

        if position is not None:
            # Drop any bars at or beyond `position` before inserting the new one.
            self._bars = {k: v for i, (k, v) in enumerate(self._bars.items()) if i < position}
        # Re-divide the available width evenly among all bars, including the new one.
        new_width = self._width // (len(self._bars) + 1)
        for bar in self._bars.values():
            bar.max_width = new_width
        if display_name is None:
            display_name = id
        new_bar = SubProgressBar(total, parent=self, width=new_width, prefix=display_name, **kwargs)
        self._bars[id] = new_bar
        self._refresh_display(True)
        return new_bar

    def finalize(self) -> None:
        """Finishes the line and moves the cursor to the next line."""
        self._printer.finalize()

    def __str__(self) -> str:
        return " >> ".join(str(b) for b in self._bars.values())
import io

import pytest

from sammo.compactbars import LinePrinter, SubProgressBar, CompactProgressBars


class TestLinePrinter:
    """Unit tests for LinePrinter writing to an in-memory stream."""

    def setup_method(self, method):
        self.out = io.StringIO()
        self.lp = LinePrinter(out=self.out)

    def test_get_terminal_width(self):
        width = self.lp.get_terminal_width()
        assert width > 0

    def test_print(self):
        # In non-interactive mode the previous text is erased with backspaces
        # ("Hello" is 5 characters, hence 5 \b). NOTE(review): assumes the test
        # process is detected as non-interactive — confirm utils.is_interactive().
        self.lp.print("Hello")
        assert self.out.getvalue() == "Hello"
        self.lp.print("World")
        assert self.out.getvalue() == "Hello\b\b\b\b\bWorld"

    def test_finalize(self):
        # finalize emits exactly one newline.
        self.lp.finalize()
        assert self.out.getvalue() == "\n"


class TestSubProgressBar:
    """Unit tests for a single SubProgressBar."""

    def setup_method(self, method):
        self.parent = CompactProgressBars()
        self.spb = SubProgressBar(100, parent=self.parent)

    def test_invalid_total(self):
        # total must be strictly positive.
        with pytest.raises(ValueError):
            SubProgressBar(0, parent=self.parent)

    def test_default_values(self):
        assert self.spb._n_done == 0
        assert self.spb._total == 100

    def test_update(self):
        # One update increments the completed-item counter by one.
        self.spb.update()
        assert self.spb._n_done == 1

    def test_str(self):
        # The rendered bar is bracketed.
        result = str(self.spb)
        assert "[" in result
        assert "]" in result


class TestCompactProgressBars:
    """Unit tests for the multi-bar container."""

    def setup_method(self, method):
        self.cb = CompactProgressBars()

    def test_get_new_bar(self):
        bar = self.cb.get("test", 100)
        assert "test" in self.cb._bars

    def test_get_existing_bar(self):
        # Fetching an existing id with a new total updates the bar in place.
        self.cb.get("test", 100)
        bar = self.cb.get("test", 200)
        assert bar.total == 200

    def test_str_representation(self):
        # Multiple bars are joined with the ">>" separator.
        self.cb.get("test1", 100)
        self.cb.get("test2", 200)
        result = str(self.cb)
        assert ">>" in result

    def test_init_default_width(self):
        # With no explicit width, the terminal width is used.
        assert self.cb._width is not None

    def test_should_refresh(self):
        # First call after construction is always allowed to refresh.
        result = self.cb._should_refresh()
        assert result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pytest

from sammo.base import Template
from sammo.components import Output, Union, ForEach, GenerateText
from sammo.runners import MockedRunner


@pytest.mark.asyncio
async def test_union():
    # Union evaluates all children and collects their values into a list.
    res = await Union("a", "b", "c")(None, dict())
    assert res.value == ["a", "b", "c"]


@pytest.mark.asyncio
async def test_union_run():
    # arun yields one result object per child instead of one combined result.
    res = await Union("a", "b", "c").arun(None)
    assert [r.value for r in res] == ["a", "b", "c"]


@pytest.mark.asyncio
async def test_for_each():
    # Each value produced by the Union is bound to "x" inside the template.
    res = await ForEach("x", Union("a", "b", "c"), Template(".{{x}}"))(None, dict())
    assert res.value == [".a", ".b", ".c"]


@pytest.mark.asyncio
async def test_generate_text():
    # MockedRunner echoes its canned response regardless of the prompt.
    runner = MockedRunner("Return value.")
    res = await GenerateText("This is a simple test.")(runner, dict())
    assert res.value == "Return value."


@pytest.mark.asyncio
async def test_override_runner():
    # The child bound to runner1 answers "test1" (substituted into the template);
    # the outer, unbound component uses the call-time runner2 and returns "test2".
    runner1 = MockedRunner("test1")
    runner2 = MockedRunner("test2")
    res1 = GenerateText("Get test1", runner=runner1)
    res2 = GenerateText(Template("I got {{res1}}", res1=res1))
    res = await res2(runner2, dict())
    assert runner2.prompt_log[0] == "I got test1"
    assert res.value == "test2"


@pytest.mark.asyncio
async def test_child_runner_not_overridden():
    # The outer component's constructor-bound runner (runner2) handles the outer
    # prompt even though runner1 is passed at call time, while the unbound child
    # falls back to the call-time runner (runner1) and yields "test1".
    runner1 = MockedRunner("test1")
    runner2 = MockedRunner("test2")
    res2 = GenerateText(Template("I got {{res1}}", res1=GenerateText("Get test1")), runner=runner2)
    res = await res2(runner1, dict())
    assert runner2.prompt_log[0] == "I got test1"
    assert res.value == "test2"
import pyglove as pg
from lxml import etree
from lxml.cssselect import CSSSelector


class XmlTree:
    """An XML mirror of a PyGlove object tree that supports CSS-selector matching.

    Every XML element is mapped back to the symbolic path (``sym_path``) of the
    PyGlove node it was built from, so CSS matches can be translated into
    PyGlove paths via :meth:`find_all`.
    """

    def __init__(self, root, sym_path=None):
        self.root = root
        # Mapping from lxml element -> sym_path string of the originating node.
        self.sym_path = sym_path

    @classmethod
    def from_pyglove(cls, pg_node, treat_as_attributes=({"reference_id": "id", "reference_classes": "class"})):
        """Builds an XmlTree from a fully instantiated PyGlove node.

        :param pg_node: The PyGlove object (or list/dict/scalar) to convert.
        :param treat_as_attributes: Field names rendered as XML attributes instead
            of child elements (this shared default dict is only read, never mutated).
        :raises ValueError: If the PyGlove object is abstract, or a value has an
            unsupported type.
        """
        if pg.is_abstract(pg_node):
            # Fix: this previously *returned* the ValueError instance instead of
            # raising it, so abstract inputs slipped through silently.
            raise ValueError("PyGlove object needs to be fully instantiated.")

        root = etree.Element("root")
        sym_path = {root: ""}

        # Do a breadth-first search
        fifo_queue = [(pg_node, root)]
        while fifo_queue:
            node, parent = fifo_queue.pop(0)

            if isinstance(node, list):
                # List items become elements named after their class.
                for i, v in enumerate(node):
                    xml_node = etree.SubElement(parent, v.__class__.__name__.lower())
                    sym_path[xml_node] = v.sym_path if hasattr(v, "sym_path") else node.sym_path + f"[{i}]"
                    fifo_queue.append((v, xml_node))
            elif isinstance(node, (pg.Object, dict)):
                for k, v in node.sym_items() if isinstance(node, pg.Object) else node.items():
                    if k in treat_as_attributes:
                        # e.g. reference_id -> id="...", reference_classes -> class="a b"
                        if v is not None:
                            parent.attrib[treat_as_attributes[k]] = " ".join(v) if isinstance(v, list) else str(v)
                    else:
                        xml_node = etree.SubElement(parent, k)
                        sym_path[xml_node] = v.sym_path if hasattr(v, "sym_path") else node.sym_path + f".{k}"

                        if v is not None and not isinstance(v, (str, int, list, float, pg.List)):
                            # Wrap complex values in an extra element named after their class.
                            xml_node = etree.SubElement(xml_node, v.__class__.__name__.lower())

                            sym_path[xml_node] = v.sym_path if hasattr(v, "sym_path") else node.sym_path + f"[{k}]"
                        fifo_queue.append((v, xml_node))

            elif isinstance(node, (str, int, float)):
                parent.text = str(node)
            elif node is not None:
                raise ValueError(f"Unsupported type: {type(node)}")
        return cls(root, sym_path)

    @staticmethod
    def to_string(xml_node) -> str:
        """Pretty-prints an lxml node; returns "None" for a missing node."""
        if xml_node is not None:
            return etree.tostring(xml_node, pretty_print=True).decode()
        else:
            return "None"

    def __repr__(self):
        return self.to_string(self.root)

    def __str__(self):
        return self.to_string(self.root)

    def find_all(self, css_expression: str) -> list:
        """Returns the sym_paths of all nodes matching the CSS expression."""
        selector = CSSSelector(css_expression)
        matches = [match for match in selector(self.root)]
        return [self.sym_path[match] for match in matches]
pg.List(["yes", "no"]) 51 | xml_tree = XmlTree.from_pyglove(list) 52 | assert len(xml_tree.find_all("str")) == 2 53 | 54 | 55 | def test_sammo_integration(sammo_example): 56 | xml_tree = XmlTree.from_pyglove(sammo_example) 57 | matches = xml_tree.find_all("#42") 58 | assert len(matches) == 1 59 | matches = xml_tree.find_all("#A") 60 | assert len(matches) == 1 61 | matches = xml_tree.find_all(".class1") 62 | assert len(matches) == 2 63 | -------------------------------------------------------------------------------- /sammo/data_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import pytest 4 | import pandas as pd 5 | from sammo.data import DataTable, MinibatchIterator, Accessor 6 | 7 | 8 | @pytest.fixture 9 | def sample_dict_data(): 10 | attributes = [{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}, {"name": "Charlie", "age": 35}] 11 | labels = [1, 0, 1] 12 | return DataTable(attributes, labels) 13 | 14 | 15 | @pytest.fixture 16 | def sample_scalar_data(): 17 | attributes = ["Alice", "Bob", "Charlie"] 18 | labels = [1, 0, 1] 19 | return DataTable(attributes, labels) 20 | 21 | 22 | def test_datatable_filtered(sample_dict_data): 23 | filtered = sample_dict_data.inputs.filtered_on(lambda x: x["name"] == "Alice") 24 | assert len(filtered) == 1 25 | 26 | 27 | def test_datatable_to_string(sample_dict_data): 28 | string = sample_dict_data.to_string() 29 | assert isinstance(string, str) 30 | DataTable([]).to_string() 31 | 32 | 33 | def test_from_pandas(sample_dict_data): 34 | df = pd.DataFrame(sample_dict_data.to_records()) 35 | dt = DataTable.from_pandas(df, input_fields="input", output_fields="output") 36 | assert dt.to_records() == sample_dict_data.to_records() 37 | 38 | 39 | def test_datatable_indexing(sample_dict_data): 40 | sample_dict_data.outputs[0] 41 | sample_dict_data.outputs[[0, 1]] 42 | sample_dict_data.outputs[0:1] 43 | 
sample_dict_data.inputs[0] 44 | sample_dict_data.inputs[[0, 1]] 45 | sample_dict_data.inputs[0:1] 46 | 47 | 48 | def test_datatable_only_rows(): 49 | DataTable([1, 2, 3]) 50 | 51 | 52 | def test_datatable_to_dicts(sample_dict_data): 53 | dicts = sample_dict_data.to_records() 54 | expected = [ 55 | {"input": {"name": "Alice", "age": 25}, "output": 1}, 56 | {"input": {"name": "Bob", "age": 30}, "output": 0}, 57 | {"input": {"name": "Charlie", "age": 35}, "output": 1}, 58 | ] 59 | assert dicts == expected 60 | 61 | 62 | def test_datatable_scalars_to_dicts(sample_scalar_data): 63 | dicts = sample_scalar_data.to_records() 64 | expected = [ 65 | {"input": "Alice", "output": 1}, 66 | {"input": "Bob", "output": 0}, 67 | {"input": "Charlie", "output": 1}, 68 | ] 69 | assert dicts == expected 70 | 71 | 72 | def test_accessor_init(): 73 | parent_data = {"input": [1, 2, 3], "output": [4, 5, 6]} 74 | accessor = Accessor(parent_data, "input") 75 | assert hasattr(accessor, "_parent") 76 | assert hasattr(accessor, "_group") 77 | assert accessor._group == "input" 78 | 79 | 80 | def test_accessor_safe_get(): 81 | y = {"test": 1} 82 | result = Accessor._safe_get(y, "test") 83 | assert result == 1 84 | 85 | 86 | def test_datatable_init(): 87 | attributes = [1, 2, 3] 88 | labels = [4, 5, 6] 89 | table = DataTable(attributes, labels) 90 | assert len(table) == 3 91 | 92 | 93 | def test_datatable_persistent_hash(): 94 | attributes = [1, 2, 3] 95 | labels = [4, 5, 6] 96 | table = DataTable(attributes, labels) 97 | assert isinstance(table.persistent_hash(), int) 98 | 99 | 100 | def test_hash_changes(sample_dict_data): 101 | hash_before = sample_dict_data.persistent_hash() 102 | sample_dict_data.outputs[[0, 1]] = [1, 0] 103 | hash_after = sample_dict_data.persistent_hash() 104 | assert hash_before == hash_after 105 | sample_dict_data.outputs[0] = [2] 106 | hash_after = sample_dict_data.persistent_hash() 107 | assert hash_before != hash_after 108 | 109 | 110 | def 
test_datatable_from_records(): 111 | records = [{"input": 4, "output": 1}, {"input": 5, "output": 2}, {"input": 6, "output": 3}] 112 | table = DataTable.from_records(records) 113 | assert len(table) == 3 114 | 115 | 116 | def test_datable_set_all(sample_dict_data): 117 | sample_dict_data.outputs[:] = 3 118 | sample_dict_data.outputs[:] = "hello" 119 | 120 | 121 | def test_datatable_set_to_list(sample_dict_data): 122 | sample_dict_data.outputs[0] = [1, 2, 3] 123 | assert sample_dict_data.outputs.raw_values[0] == [1, 2, 3] 124 | 125 | 126 | def test_datatable_getitem_single(): 127 | attributes = [1, 2, 3] 128 | labels = [4, 5, 6] 129 | table = DataTable(attributes, labels) 130 | sliced = table[1] 131 | assert len(sliced) == 1 132 | assert sliced.outputs.raw_values == [5] 133 | 134 | 135 | def test_datatable_getitem_multiple(): 136 | attributes = [1, 2, 3] 137 | labels = [4, 5, 6] 138 | table = DataTable(attributes, labels) 139 | sliced = table[[1, 2]] 140 | assert len(sliced) == 2 141 | assert sliced.outputs.raw_values == [5, 6] 142 | 143 | 144 | def test_datatable_sample(): 145 | attributes = [1, 2, 3, 4, 5] 146 | labels = [6, 7, 8, 9, 10] 147 | table = DataTable(attributes, labels) 148 | sampled = table.sample(3) 149 | assert len(sampled) == 3 150 | 151 | 152 | def test_datatable_shuffle(): 153 | attributes = [1, 2, 3, 4, 5] 154 | labels = [6, 7, 8, 9, 10] 155 | table = DataTable(attributes, labels) 156 | shuffled = table.shuffle() 157 | assert len(shuffled) == 5 158 | 159 | 160 | def test_datatable_random_split(): 161 | attributes = [1, 2, 3, 4, 5] 162 | labels = [6, 7, 8, 9, 10] 163 | table = DataTable(attributes, labels) 164 | split1, split2 = table.random_split(3, 2) 165 | assert len(split1) == 3 166 | assert len(split2) == 2 167 | 168 | 169 | def test_minibatch_iterator(): 170 | attributes = [1, 2, 3, 4, 5] 171 | labels = [6, 7, 8, 9, 10] 172 | table = DataTable(attributes, labels) 173 | iterator = MinibatchIterator(table, 2) 174 | batches = list(iterator) 175 
| assert len(batches) == 3 176 | assert len(batches[0]) == 2 177 | assert len(batches[-1]) == 1 178 | iterator = MinibatchIterator(table, 1) 179 | batches = list(iterator) 180 | assert len(batches) == 5 181 | assert batches[0] == 0 182 | 183 | 184 | def test_persistent_hashing_consistency(sample_dict_data): 185 | hash1 = sample_dict_data.persistent_hash() 186 | hash2 = sample_dict_data.persistent_hash() 187 | assert hash1 == hash2 188 | -------------------------------------------------------------------------------- /sammo/dataformatters_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import pytest 4 | from sammo.dataformatters import DataFormatter, JSONDataFormatter, XMLDataFormatter 5 | from sammo.base import VerbatimText 6 | 7 | 8 | def test_format_batch_typical_data(): 9 | test = DataFormatter() 10 | test._dump = lambda x: x.records 11 | # call the format_batch function 12 | returned = test.format_batch([{"key": "value1"}], [{"label": 1}], [{"label": 0}]) 13 | expected = [ 14 | {"id": 0, "kind": "input", "kind_alias": "input", "kind_order": 0, "value": "value1"}, 15 | {"id": 0, "kind": "gold_label", "kind_alias": "output", "kind_order": 1, "value": 1}, 16 | {"id": 0, "kind": "predicted_label", "kind_alias": "predicted_output", "kind_order": 2, "value": 0}, 17 | ] 18 | 19 | assert returned == expected 20 | 21 | 22 | def test_format_not_flatten_data(): 23 | test = DataFormatter(flatten_1d_dicts=False) 24 | test._dump = lambda x: x.records 25 | # call the format_batch function 26 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 27 | expected = [ 28 | {"id": 0, "kind": "input", "kind_alias": "input", "kind_order": 0, "value": {"key": "value1"}}, 29 | {"id": 0, "kind": "gold_label", "kind_alias": "output", "kind_order": 1, "value": {"l": 1}}, 30 | {"id": 0, "kind": "predicted_label", "kind_alias": 
"predicted_output", "kind_order": 2, "value": {"l": 0}}, 31 | ] 32 | assert returned == expected 33 | 34 | returned_single = test.format_single({"key": "value1"}, {"l": 1}, {"l": 0}) 35 | assert returned_single == expected 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_format_flat_json_item_orient(): 40 | test = JSONDataFormatter(newline_delimited=False, flatten_1d_dicts=True, indent=None) 41 | 42 | # call the format_batch function 43 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 44 | expected = '[{"id": 0, "input": "value1", "output": 1, "predicted_output": 0}]' 45 | 46 | assert returned == expected 47 | 48 | result = await test.get_extractor(VerbatimText(returned))(None, {}) 49 | assert test._unwrap_results(result) == [1] 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_format_flat_ndjson_item_orient(): 54 | test = JSONDataFormatter(newline_delimited=True, flatten_1d_dicts=True, indent=None) 55 | 56 | # call the format_batch function 57 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 58 | expected = '{"id": 0, "input": "value1", "output": 1, "predicted_output": 0}' 59 | assert returned == expected 60 | 61 | result = await test.get_extractor(VerbatimText(returned))(None, {}) 62 | assert test._unwrap_results(result) == [1] 63 | 64 | 65 | @pytest.mark.asyncio 66 | async def test_format_flat_ndjson_kind_orient(): 67 | test = JSONDataFormatter(newline_delimited=True, flatten_1d_dicts=True, indent=None, orient="kind") 68 | 69 | # call the format_batch function 70 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 71 | expected = ( 72 | 'input: [{"id": 0, "value": "value1"}]\n' 73 | 'output: [{"id": 0, "value": 1}]\n' 74 | 'predicted_output: [{"id": 0, "value": 0}]' 75 | ) 76 | assert returned == expected 77 | 78 | 79 | def test_format_flat_json_kind_orient(): 80 | test = JSONDataFormatter(newline_delimited=False, flatten_1d_dicts=True, indent=None, orient="kind") 81 | 82 
| # call the format_batch function 83 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 84 | expected = ( 85 | '{"input": [{"id": 0, "value": "value1"}], "output": [{"id": 0, "value": 1}], ' 86 | '"predicted_output": [{"id": 0, "value": 0}]}' 87 | ) 88 | assert returned == expected 89 | 90 | 91 | @pytest.mark.asyncio 92 | async def test_format_nested_json_item_orient(): 93 | test = JSONDataFormatter( 94 | newline_delimited=False, flatten_1d_dicts=False, indent=None, include_ids=False, orient="item" 95 | ) 96 | 97 | # call the format_batch function 98 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 99 | expected = '[{"input": {"key": "value1"}, "output": {"l": 1}, "predicted_output": {"l": 0}}]' 100 | assert returned == expected 101 | 102 | result = await test.get_extractor(VerbatimText(returned))(None, {}) 103 | assert result.value == [{"l": 1}] 104 | 105 | 106 | @pytest.mark.asyncio 107 | async def test_format_nested_json_item_orient_with_ids(): 108 | test = JSONDataFormatter( 109 | newline_delimited=False, flatten_1d_dicts=False, indent=None, include_ids=True, orient="item" 110 | ) 111 | 112 | # call the format_batch function 113 | returned = test.format_batch([{"key": "value1"}], [{"l": 1}], [{"l": 0}]) 114 | expected = '[{"id": 0, "input": {"key": "value1"}, "output": {"l": 1}, "predicted_output": {"l": 0}}]' 115 | assert returned == expected 116 | 117 | result = await test.get_extractor(VerbatimText(returned))(None, {}) 118 | assert result.value == [{"l": 1}] 119 | 120 | 121 | def test_xml_flat_item_orient(): 122 | test = XMLDataFormatter(flatten_1d_dicts=True, orient="item") 123 | 124 | # call the format_batch function 125 | returned = test.format_batch([{"key": "row1"}, {"key": "row2"}], [{"l": [1, 2]}, {"l": [0, 1]}]) 126 | expected = ( 127 | '\n' 128 | "\tvalue1\n" 129 | '\n' 130 | "\t1\n" 131 | "\t2\n" 132 | '\n' 133 | "\t0\n" 134 | "\t1\n" 135 | "" 136 | ) 137 | returned == expected 138 | 
-------------------------------------------------------------------------------- /sammo/express.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from collections import namedtuple 4 | from mistletoe.block_token import List 5 | from mistletoe.markdown_renderer import MarkdownRenderer 6 | 7 | from mistletoe import Document, block_token 8 | 9 | from sammo.base import Template 10 | from sammo.instructions import MetaPrompt, Section, Paragraph 11 | 12 | HTML_COMMENT = re.compile(r"") 13 | HTML_IDS = re.compile(r"#(\w+)($|\s)") 14 | HTML_CLASSES = re.compile(r"\.([\w-]+)($|\s)") 15 | 16 | 17 | def _extract_html_comment(text): 18 | rest = text 19 | inner_comment = "" 20 | 21 | if HTML_COMMENT.search(text) is not None: 22 | inner_comment = HTML_COMMENT.search(text).group(1) 23 | rest = HTML_COMMENT.sub("", text) 24 | 25 | return inner_comment, rest 26 | 27 | 28 | def _get_ids_and_classes(text): 29 | comment, rest = _extract_html_comment(text) 30 | ids = HTML_IDS.findall(comment) or list() 31 | ids = [i[0] for i in ids] 32 | 33 | classes = HTML_CLASSES.findall(comment) or list() 34 | classes = [c[0] for c in classes] 35 | 36 | return {"text": rest, "ids": ids, "classes": classes} 37 | 38 | 39 | class MarkdownParser: 40 | def __init__(self, input_text: str): 41 | self._input_text = input_text 42 | self._sammo_tree, self._sammo_config = None, None 43 | 44 | def _parse(self): 45 | if self._sammo_tree is None: 46 | json_tree, config = self._parse_annotated_markdown(self._input_text) 47 | self._sammo_tree = self._json_to_sammo(json_tree) 48 | self._sammo_config = config 49 | 50 | def get_sammo_program(self): 51 | self._parse() 52 | return self._sammo_tree 53 | 54 | def get_sammo_config(self): 55 | self._parse() 56 | return self._sammo_config 57 | 58 | @staticmethod 59 | def from_file(file_path): 60 | with open(file_path, "r", encoding="utf-8") as file: 61 | return MarkdownParser(file.read()) 62 | 63 | 
@staticmethod 64 | def _parse_annotated_markdown(text): 65 | doc = Document(text) 66 | sammo_config = dict() 67 | State = namedtuple("State", ["current", "parent", "level"]) 68 | with MarkdownRenderer() as mrender: 69 | processed = list() 70 | stack = [State(processed, processed, 0)] 71 | for element in doc.children: 72 | last = stack[-1] 73 | if isinstance(element, List): 74 | list_elements = list() 75 | classes = set() 76 | ids = set() 77 | 78 | for c in element.children: 79 | d = _get_ids_and_classes(mrender.render(c)) 80 | classes.update(d["classes"]) 81 | ids.update(d["ids"]) 82 | list_elements.append(d["text"]) 83 | 84 | last.current.append( 85 | {"type": "list", "children": list_elements, "class": list(classes), "id": list(ids)} 86 | ) 87 | elif isinstance(element, block_token.Heading): 88 | d = _get_ids_and_classes(mrender.render(element)) 89 | new = { 90 | "type": "section", 91 | "title": d["text"], 92 | "children": list(), 93 | "id": d["ids"], 94 | "class": d["classes"], 95 | } 96 | if element.level < last.level: 97 | while stack[-1].level >= element.level: 98 | stack.pop() 99 | scope = stack[-1].current 100 | elif element.level == last.level: 101 | scope = last.parent 102 | else: 103 | scope = last.current 104 | stack.append(State(new["children"], scope, element.level)) 105 | scope.append(new) 106 | elif isinstance(element, block_token.CodeFence) and element.language.lower() == "{sammo/mutators}": 107 | sammo_config = json.loads(element.children[0].content) 108 | else: 109 | last.current.append( 110 | {"type": element.__class__.__name__.lower(), "children": [mrender.render(element)]} 111 | ) 112 | return {"type": "root", "children": processed}, sammo_config 113 | 114 | @classmethod 115 | def _json_to_sammo(cls, node): 116 | def _empty_to_none(x): 117 | return None if len(x) == 0 else x 118 | 119 | def _unwrap_list(x): 120 | if not isinstance(x, list) or len(x) > 1: 121 | return ValueError(f"Expected list of length 0 or 1, got {len(x)}") 122 | elif len(x) 
== 1: 123 | return x[0] 124 | return x 125 | 126 | def _get_annotations(x): 127 | return dict( 128 | reference_id=_empty_to_none(_unwrap_list(x.get("id", []))), 129 | reference_classes=_empty_to_none(x.get("class", [])), 130 | ) 131 | 132 | if isinstance(node, str) and "{{" in node: 133 | return Template(node) 134 | elif not isinstance(node, dict): 135 | return node 136 | elif node["type"] == "root": 137 | return MetaPrompt([cls._json_to_sammo(child) for child in node["children"]], render_as="raw") 138 | elif node["type"] == "section": 139 | return Section( 140 | title=node["title"], 141 | content=[cls._json_to_sammo(child) for child in node["children"]], 142 | **_get_annotations(node), 143 | ) 144 | elif node["type"] in ["paragraph", "list", "blockcode", "codefence", "quote"]: 145 | return Paragraph( 146 | content=[cls._json_to_sammo(child) for child in node["children"]], **_get_annotations(node) 147 | ) 148 | else: 149 | raise ValueError(f"Unsupported type: {type(node)} with node: {node}") 150 | -------------------------------------------------------------------------------- /sammo/express_test.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from sammo.express import MarkdownParser, _extract_html_comment, _get_ids_and_classes 3 | import pyglove as pg 4 | 5 | 6 | def test_extract_html_comment(): 7 | text = "Some text more text" 8 | comment, rest = _extract_html_comment(text) 9 | assert comment == " This is a comment " 10 | assert rest == "Some text more text" 11 | 12 | 13 | def test_extract_html_comment_no_comment(): 14 | text = "Some text more text" 15 | comment, rest = _extract_html_comment(text) 16 | assert comment == "" 17 | assert rest == text 18 | 19 | 20 | def test_get_ids_and_classes(): 21 | text = "Some text more text" 22 | result = _get_ids_and_classes(text) 23 | assert result["text"] == "Some text more text" 24 | assert set(result["ids"]) == {"id1", "id2"} 25 | assert set(result["classes"]) == 
def test_lists():
    """A markdown bullet list parses into a single 'list' node with no annotations."""
    source = textwrap.dedent(
        """
        * list item 1
        * list item 2
        """
    )

    want = (
        {
            "children": [{"children": ["* list item 1\n", "* list item 2\n"], "class": [], "id": [], "type": "list"}],
            "type": "root",
        },
        {},
    )
    assert MarkdownParser._parse_annotated_markdown(source) == want
"reference_id": None, 112 | "reference_classes": None, 113 | }, 114 | { 115 | "_type": "sammo.instructions.Section", 116 | "title": "# Heading 2\n", 117 | "content": [ 118 | { 119 | "_type": "sammo.instructions.Paragraph", 120 | "content": ["Final content\n"], 121 | "reference_id": None, 122 | "reference_classes": None, 123 | } 124 | ], 125 | "reference_id": None, 126 | "reference_classes": None, 127 | }, 128 | ], 129 | "render_as": "raw", 130 | "data_formatter": None, 131 | "reference_id": None, 132 | "seed": 0, 133 | } 134 | ) 135 | parser = MarkdownParser(input_text) 136 | assert parser.get_sammo_program() == expected 137 | 138 | 139 | def test_express_parser_parse_annotated_markdown(): 140 | input_text = textwrap.dedent( 141 | """ 142 | # Heading 1 143 | Some content 144 | * list item 1 145 | * list item 2 146 | 147 | ## Heading 1.2 148 | {{{input}}} 149 | """ 150 | ) 151 | parser = MarkdownParser(input_text) 152 | expected = pg.from_json( 153 | { 154 | "_type": "sammo.instructions.MetaPrompt", 155 | "child": [ 156 | { 157 | "_type": "sammo.instructions.Section", 158 | "title": "# Heading 1\n", 159 | "content": [ 160 | { 161 | "_type": "sammo.instructions.Paragraph", 162 | "content": ["Some content\n"], 163 | "reference_id": None, 164 | "reference_classes": None, 165 | }, 166 | { 167 | "_type": "sammo.instructions.Paragraph", 168 | "content": ["* list item 1 \n", "* list item 2\n"], 169 | "reference_id": "id1", 170 | "reference_classes": ["class1"], 171 | }, 172 | { 173 | "_type": "sammo.instructions.Section", 174 | "title": "## Heading 1.2 \n", 175 | "content": [ 176 | { 177 | "_type": "sammo.instructions.Paragraph", 178 | "content": [ 179 | { 180 | "_type": "sammo.base.Template", 181 | "content": "{{{input}}}\n", 182 | "reference_id": None, 183 | "reference_classes": None, 184 | } 185 | ], 186 | "reference_id": None, 187 | "reference_classes": None, 188 | } 189 | ], 190 | "reference_id": "id2", 191 | "reference_classes": ["class2", "class3"], 192 | }, 193 | 
], 194 | "reference_id": None, 195 | "reference_classes": None, 196 | } 197 | ], 198 | "render_as": "raw", 199 | "data_formatter": None, 200 | "reference_id": None, 201 | "seed": 0, 202 | } 203 | ) 204 | assert parser.get_sammo_program() == expected 205 | assert parser.get_sammo_config() == {} 206 | 207 | 208 | def test_express_parser_aux_tree_to_sammo(): 209 | input_text = textwrap.dedent( 210 | """ 211 | # Heading 1 212 | Some content 213 | ```{python} 214 | print("Hello, World!") 215 | ``` 216 | """ 217 | ) 218 | parser = MarkdownParser(input_text) 219 | assert parser.get_sammo_program() is not None 220 | 221 | 222 | def test_express_parser_with_mutators(): 223 | input_text = textwrap.dedent( 224 | """ 225 | # Heading 1 226 | Some content 227 | > Somewhere, something incredible is waiting to be known 228 | 229 | ```{sammo/mutators} 230 | { 231 | "mutators": [ 232 | { 233 | "name": "mutator1", 234 | "type": "type1" 235 | }, 236 | { 237 | "name": "mutator2", 238 | "type": "type2" 239 | } 240 | ] 241 | } 242 | ``` 243 | """ 244 | ) 245 | parser = MarkdownParser(input_text) 246 | assert parser.get_sammo_config() == { 247 | "mutators": [{"name": "mutator1", "type": "type1"}, {"name": "mutator2", "type": "type2"}] 248 | } 249 | -------------------------------------------------------------------------------- /sammo/extractors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
# Test ParseJSON
def test_parse_json_process():
    """A single JSON object string parses into a one-element list of dicts."""
    json_extractor = ParseJSON("")
    assert json_extractor._extract_from_single_value('{"key": "value"}') == [{"key": "value"}]


def test_parse_json_process_fragments():
    """With parse_fragments='all', every embedded JSON fragment is extracted."""
    json_extractor = ParseJSON("", parse_fragments="all")
    extracted = json_extractor._extract_from_single_value('surrounding{"key": "value"}text{"num": [1, 2]}is ignored')
    assert extracted == [{"key": "value"}, {"num": [1, 2]}]


# Test ExtractRegex
def test_extract_regex_process():
    """Every regex match in the value is returned, in document order."""
    regex_extractor = ExtractRegex("", r"\d+")
    assert regex_extractor._extract_from_single_value("There are 12 apples and 13 oranges.") == ["12", "13"]
parser._extract_from_single_value("# Heading1\nThis is content.") == [ 61 | {"name": "Heading1", "type": "section", "content": [{"type": "paragraph", "content": "This is content."}]} 62 | ] 63 | 64 | 65 | # Test YAMLParser 66 | def test_yaml_parser_process(): 67 | parser = YAMLParser("") 68 | assert parser._extract_from_single_value("key: value") == [{"key": "value"}] 69 | 70 | 71 | # Test ParseXML 72 | def test_parse_xml_process(): 73 | parser = ParseXML("") 74 | assert parser._extract_from_single_value("value") == [{"root": {"child": "value"}}] 75 | 76 | 77 | def test_parse_xml_process_fragments(): 78 | parser = ParseXML("", parse_fragments="all") 79 | assert parser._extract_from_single_value("valuevalue") == [ 80 | {"root": {"child": "value"}}, 81 | {"root": {"child": "value"}}, 82 | ] 83 | 84 | 85 | # Test JSONPath 86 | def test_jsonpath_process(): 87 | json_path = JSONPath("", "$.key") 88 | assert json_path._extract_from_single_value({"key": "value"}) == ["value"] 89 | 90 | 91 | # Test ToNum 92 | @pytest.mark.parametrize( 93 | "input_val, dtype, factor, offset, expected", 94 | [("3/2", "fraction", 1, 0, Fraction("3/2")), ("5", "int", 1, 0, 5), ("3.5", "float", 2, 1, 8.0)], 95 | ) 96 | def test_to_num_process(input_val, dtype, factor, offset, expected): 97 | to_num = ToNum(VerbatimText(""), dtype=dtype, factor=factor, offset=offset) 98 | assert to_num._extract_from_single_value(input_val) == [expected] 99 | 100 | # alternative syntax 101 | to_num = ToNum(VerbatimText(""), dtype=dtype) * factor + offset 102 | assert to_num._extract_from_single_value(input_val) == [expected] 103 | -------------------------------------------------------------------------------- /sammo/instructions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import pytest 5 | from frozendict import frozendict 6 | from sammo.dataformatters import PlainFormatter 7 | 8 | from sammo.data import DataTable 9 | from sammo.instructions import MetaPrompt, Section, Paragraph, InputData 10 | from sammo.runners import MockedRunner 11 | 12 | 13 | @pytest.mark.asyncio 14 | @pytest.mark.parametrize( 15 | "render_as,expected", 16 | [ 17 | ( 18 | "xml", 19 | ( 20 | "
\n" 21 | "Section\n" 22 | "
\n" 23 | "SubsectionSubsection text.\n" 24 | "
\n" 25 | "
" 26 | ), 27 | ), 28 | ("markdown", "# Section\n## Subsection\nSubsection text."), 29 | ("markdown-alt", "Section\n=======\nSubsection\n----------\nSubsection text."), 30 | ], 31 | ) 32 | async def test_basic_render(render_as, expected): 33 | runner = MockedRunner() 34 | rendered = await MetaPrompt( 35 | [Section(title="Section", content=[Section(title="Subsection", content="Subsection text.")])], 36 | render_as=render_as, 37 | )(runner, dict()) 38 | assert rendered.value == expected 39 | 40 | 41 | @pytest.mark.asyncio 42 | @pytest.mark.parametrize( 43 | "render_as,expected", 44 | [ 45 | ("markdown", "# A\nSome text.\n\n\n# B\nOther text."), 46 | ], 47 | ) 48 | async def test_basic_render_text(render_as, expected): 49 | runner = MockedRunner() 50 | rendered = await MetaPrompt( 51 | [Section(title="A", content="Some text.\n"), Section(title="B", content="Other text.")], render_as=render_as 52 | )(runner, dict()) 53 | assert rendered.value == expected 54 | 55 | 56 | @pytest.mark.asyncio 57 | @pytest.mark.parametrize( 58 | "render_as,expected", 59 | [ 60 | ( 61 | "xml", 62 | ( 63 | ( 64 | "
\n" 65 | "Section\n" 66 | "Paragraph 1\n" 67 | "\n" 68 | "\n" 69 | "Paragraph 2\n" 70 | "\n" 71 | "
" 72 | ) 73 | ), 74 | ), 75 | ("markdown", "# Section\nParagraph 1\n\n\nParagraph 2"), 76 | ("markdown-alt", "Section\n=======\nParagraph 1\n\n\nParagraph 2"), 77 | ], 78 | ) 79 | async def test_paragraph(render_as, expected): 80 | runner = MockedRunner() 81 | rendered = await MetaPrompt( 82 | [Section(title="Section", content=[Paragraph("Paragraph 1"), Paragraph("Paragraph 2")])], 83 | render_as=render_as, 84 | )(runner, dict()) 85 | assert rendered.value == expected 86 | 87 | 88 | @pytest.mark.asyncio 89 | async def test_raw_render(): 90 | runner = MockedRunner() 91 | rendered = await MetaPrompt( 92 | [Section("My title\n", "Section 1"), Paragraph("Paragraph 1\n"), Paragraph("Paragraph 2")], render_as="raw" 93 | )(runner, dict()) 94 | assert rendered.value == "My title\nSection 1Paragraph 1\nParagraph 2" 95 | 96 | 97 | @pytest.mark.asyncio 98 | @pytest.mark.parametrize( 99 | "render_as,expected", 100 | [ 101 | ( 102 | "xml", 103 | ( 104 | "Paragraph 1\n" 105 | "\n" 106 | "\n" 107 | "Input: {'name': 'Alice', 'age': 25}\n" 108 | "\n" 109 | "Input: {'name': 'Bob', 'age': 30}\n" 110 | "\n" 111 | "" 112 | ), 113 | ), 114 | ( 115 | "markdown", 116 | ( 117 | "Paragraph 1\n" 118 | "\n" 119 | "\n" 120 | "Input: {'name': 'Alice', 'age': 25}\n" 121 | "\n" 122 | "Input: {'name': 'Bob', 'age': 30}" 123 | ), 124 | ), 125 | ("raw", ("Paragraph 1Input: {'name': 'Alice', 'age': 25}\n" "\n" "Input: {'name': 'Bob', 'age': 30}")), 126 | ], 127 | ) 128 | async def test_data_render(render_as, expected): 129 | data = DataTable([{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}], [1, 0]) 130 | context = dict( 131 | data=frozendict( 132 | inputs=data.inputs.values, 133 | constants=data.constants, 134 | ) 135 | ) 136 | data_formatter = PlainFormatter(all_labels=data.outputs.unique(), orient="item") 137 | runner = MockedRunner() 138 | rendered = await MetaPrompt( 139 | [Paragraph("Paragraph 1"), Paragraph(InputData())], render_as=render_as, data_formatter=data_formatter 140 | 
def test_manual_loop():
    """Union over a fixed list of components gathers all results in order."""
    numbers = [VerbatimText(f"{i}") for i in range(5)]
    result = Output(Union(*numbers)).run(MockedRunner())
    assert result.outputs.values[0] == ["0", "1", "2", "3", "4"]


def test_dynamic_loop():
    """ForEach maps a template over values extracted at runtime."""
    # NOTE(review): the two literals below were mangled by markup stripping in
    # the dump (the visible "123" / r"(.*?)<.?item>" pair cannot match at all);
    # reconstructed so the extraction yields ["1", "2", "3"] as the final
    # assertion requires -- confirm against the repository.
    numbers = ExtractRegex(
        VerbatimText("<item>1</item><item>2</item><item>3</item>"),
        r"<item>(.*?)<.?item>",
    )
    fruit_blurbs = ForEach(
        "number",
        numbers,
        Template("{{number}}!"),
    )
    result = Output(fruit_blurbs).run(MockedRunner())
    assert result.outputs.values[0] == ["1!", "2!", "3!"]


def test_custom_extractor():
    """LambdaExtractor applies a user-supplied function to each value."""
    numbers = [VerbatimText(f"{i}") for i in range(5)]
    incremented = LambdaExtractor(Union(*numbers), "lambda x: int(x) + 1")
    result = Output(incremented).run(MockedRunner())
    assert result.outputs.values[0] == [1, 2, 3, 4, 5]


def test_minibatching():
    """Minibatched output splits rendered batches back into per-row results."""
    data = list(range(5))
    result = Output(
        SplitLines(StripWhitespace(Template("{{#each inputs}}{{this}}\n\n{{/each}}"))),
        minibatch_size=2,
    ).run(MockedRunner(), data, progress_callback=False)
    assert result.outputs.values == [str(d) for d in data]


def test_search():
    """EnumerativeSearch enumerates the prompt space and the result pickles."""

    def prompt_space():
        prompt = GenerateText(
            # BUG FIX: the original passed f"{{input}}"; the f-prefix collapses
            # the escaped braces to the literal "{input}", which is not valid
            # {{input}} placeholder syntax for Template.
            Template("{{input}}"),
            randomness=one_of([0.3, 0.7, 1.0], reference_id="randomness"),
        )
        return Output(prompt)

    def metric(y_true, y_pred):
        return EvaluationScore(0)

    train_data = DataTable([{"input": "1"}, {"input": "2"}])
    runner = MockedRunner()
    searcher = EnumerativeSearch(runner, prompt_space, metric)
    searcher.fit_transform(train_data)
    # Patch open() so save() writes to an in-memory buffer instead of disk.
    with patch("builtins.open", lambda x, y: BytesIO()):
        searcher.save("file.pkl")
street", "."] == SyntaxTreeMutator.get_phrases( 48 | "The big ball rolled over the street." 49 | ) 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_paraphrase(): 54 | runner = MockedRunner(["1", "2"]) 55 | mutator = Paraphrase(css_selector="#test") 56 | result = await mutator.mutate(basic_template(), MagicMock(), runner, n_mutations=2, random_state=42) 57 | assert len(result) == 2 58 | assert result[0].candidate.find_first("#test content").node == "1" 59 | assert result[1].candidate.find_first("#test content").node == "2" 60 | 61 | 62 | @pytest.mark.asyncio 63 | async def test_paraphrase_with_duplicates(): 64 | runner = MockedRunner(["1", "2"]) 65 | mutator = Paraphrase(css_selector=".test") 66 | result = await mutator.mutate(duplicate_template(), MagicMock(), runner, n_mutations=2, random_state=42) 67 | assert len(result) == 2 68 | assert [m.node for m in result[0].candidate.find_all(".test content")] == ["1", "1"] 69 | assert [m.node for m in result[1].candidate.find_all(".test content")] == ["2", "2"] 70 | 71 | 72 | @pytest.mark.asyncio 73 | async def test_to_bulletpoints(): 74 | runner = MockedRunner(["1", "2"]) 75 | mutator = SegmentToBulletPoints(css_selector="#test") 76 | result = await mutator.mutate(basic_template(), MagicMock(), runner, n_mutations=2, random_state=42) 77 | assert len(result) == 1 78 | assert result[0].candidate.find_first("#test content").node == "1" 79 | assert runner.prompt_log[0] == ( 80 | "Rewrite the text below as a bullet list with at most 10 words per bullet " 81 | "point. \n" 82 | "\n" 83 | "The big ball rolled over the street." 
84 | ) 85 | 86 | 87 | @pytest.mark.asyncio 88 | async def test_induce(): 89 | runner = MockedRunner(["1", "2"]) 90 | datatable = DataTable.from_records([{"input": "_a_", "output": "_b_"}] * 5) 91 | mutator = InduceInstructions("#test", datatable) 92 | result = await mutator.mutate(basic_template(), MagicMock(), runner, n_mutations=2, random_state=42) 93 | assert len(result) == 2 94 | assert result[0].candidate.find_first("#test content").node == "1" 95 | assert result[1].candidate.find_first("#test content").node == "2" 96 | assert "_a_" in runner.prompt_log[0] 97 | 98 | 99 | @pytest.mark.asyncio 100 | async def test_replace_param(): 101 | runner = MockedRunner(["1", "2"]) 102 | mutator = ReplaceParameter("render_as", ["markdown", "xml"]) 103 | result = await mutator.mutate(basic_template(), MagicMock(), runner, n_mutations=2, random_state=42) 104 | assert len(result) == 1 105 | 106 | 107 | @pytest.mark.asyncio 108 | async def test_rewrite_missing_placeholder(): 109 | with pytest.raises(ValueError): 110 | _ = Rewrite("#test", "content") 111 | 112 | 113 | @pytest.mark.asyncio 114 | async def test_apo(n_data=5, num_gradients=2): 115 | runner = MockedRunner( 116 | { 117 | "semantic.*1": "Rewritten Improved 1", 118 | "semantic.*2": "Rewritten Improved 2", 119 | "reason 1.*improved": "Improved 1", 120 | "reason 2.*improved": "Improved 2", 121 | "reasons": "reason 1reason 2", 122 | } 123 | | {f"a_{i}": f"y_{i}" for i in range(n_data)} 124 | ) 125 | datatable = DataTable.from_records([{"input": f"a_{i}", "output": f"b_{i}"} for i in range(n_data)]) 126 | mutator = APO("#test content", None, num_gradients=2, steps_per_gradient=1, num_rewrites=1) 127 | mutator.objective = MagicMock() 128 | mutator.objective.return_value = MagicMock(mistakes=list(range(n_data))) 129 | result = await mutator.mutate(short_template(), datatable, runner, n_mutations=3, random_state=42) 130 | 131 | assert result[0].candidate.find_first("#test content").node == "Improved 2" 132 | assert 
result[2].candidate.find_first("#test content").node == "Rewritten Improved 1" 133 | assert len(result) == 3 134 | 135 | 136 | @pytest.mark.asyncio 137 | async def test_mutate(): 138 | cache = {"The big ball rolled over the street.": ["The big ball ", "rolled over the street", "."]} 139 | mutator = SyntaxTreeMutator(starting_prompt=basic_template(), cache=cache, css_selector="#test") 140 | runner = MockedRunner("LLM response") 141 | result = await mutator.mutate(basic_template(), MagicMock(), runner, random_state=42) 142 | assert result[0].action == "del" 143 | result = await mutator.mutate(basic_template(), MagicMock(), runner, random_state=43) 144 | assert result[0].action == "par" 145 | assert result[0].candidate.find_first("content").node.strip() == "LLM response rolled over the street ." 146 | result = await mutator.mutate(basic_template(), MagicMock(), runner, random_state=46) 147 | assert result[0].action == "swap" 148 | assert result[0].candidate.find_first("content").node.strip() == "rolled over the street The big ball ." 149 | -------------------------------------------------------------------------------- /sammo/runners_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
def test_costs_addition():
    """Adding two Costs sums the input and output components separately."""
    combined = Costs(input_costs=1, output_costs=2) + Costs(input_costs=3, output_costs=4)
    assert combined.input == 4
    assert combined.output == 6


def test_costs_subtraction():
    """Subtracting Costs works component-wise."""
    delta = Costs(input_costs=5, output_costs=6) - Costs(input_costs=3, output_costs=4)
    assert delta.input == 2
    assert delta.output == 2


def test_costs_to_dict():
    """to_dict exposes the components under 'input' and 'output' keys."""
    assert Costs(input_costs=1, output_costs=2).to_dict() == {"input": 1, "output": 2}


def test_costs_total():
    """total is the sum of input and output costs."""
    assert Costs(input_costs=1, output_costs=2).total == 3
session_mock 69 | 70 | 71 | @pytest.fixture 72 | def basic_embedding(): 73 | mock = MagicMock() 74 | coro = MagicMock() 75 | coro.post.return_value.__aenter__.return_value.status = 200 76 | coro.post.return_value.__aenter__.return_value.json = AsyncMock( 77 | return_value={ 78 | "usage": {"total_tokens": 1, "prompt_tokens": 2, "completion_tokens": 3}, 79 | "data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}], 80 | } 81 | ) 82 | mock.return_value.__aenter__.return_value = coro 83 | return mock 84 | 85 | 86 | @pytest.mark.asyncio 87 | async def test_generate_text(basic): 88 | runner = OpenAIChat(model_id="gpt-4", api_config={"api_key": "test"}, cache=None) 89 | runner._get_session = basic 90 | result = await runner.generate_text(prompt="test prompt") 91 | assert result.value == "test" 92 | 93 | 94 | @pytest.mark.asyncio 95 | async def test_parallel_identical_calls(basic): 96 | runner = OpenAIChat(model_id="gpt-4", api_config={"api_key": "test"}, rate_limit=10, cache=InMemoryDict()) 97 | runner._get_session = basic 98 | async with TaskGroup() as g: 99 | for _ in range(10): 100 | g.create_task(runner.generate_text(prompt="test prompt", seed=0)) 101 | # we expect the backend to be called only once, other values from cache 102 | assert basic.call_count == 1 103 | 104 | 105 | @pytest.mark.asyncio 106 | async def test_system_message(basic): 107 | runner = OpenAIChat(model_id="gpt-4", api_config={"api_key": "test"}, cache=None) 108 | runner._get_session = basic 109 | await runner.generate_text(prompt="test prompt", system_prompt="test system") 110 | assert basic.mock_calls[2].kwargs["json"]["messages"][0] == {"role": "system", "content": "test system"} 111 | 112 | 113 | @pytest.mark.asyncio 114 | async def test_cache(basic): 115 | cache = InMemoryDict() 116 | runner = OpenAIChat(model_id="gpt-4", api_config={"api_key": "test"}, cache=cache) 117 | runner._get_session = basic 118 | await runner.generate_text(prompt="test prompt") 119 | assert len(cache) == 1 120 | 
121 | 122 | @pytest.mark.asyncio 123 | async def test_generate_embedding(basic_embedding): 124 | runner = OpenAIEmbedding(model_id="some_id", api_config={"api_key": "test"}, cache=None) 125 | runner._get_session = basic_embedding 126 | result = await runner.generate_embedding(["text", "text2"]) 127 | assert result.value == [[0.1, 0.2], [0.3, 0.4]] 128 | 129 | 130 | @pytest.mark.asyncio 131 | async def test_cached_embeddings(basic_embedding): 132 | cache = InMemoryDict() 133 | cache[("some_id", "text2")] = [0.3, 0.4] 134 | runner = OpenAIEmbedding(model_id="some_id", api_config={"api_key": "test"}, cache=cache) 135 | runner._get_session = basic_embedding 136 | result = await runner.generate_embedding(["text", "text2"]) 137 | assert len(basic_embedding.mock_calls[2].kwargs["json"]["input"]) == 1 138 | assert result.value == [[0.1, 0.2], [0.3, 0.4]] 139 | print(cache._dict) 140 | 141 | # test that backend will not be called again if cached 142 | result = await runner.generate_embedding(["text", "text2"]) 143 | assert basic_embedding.call_count == 1 144 | 145 | assert result.value == [[0.1, 0.2], [0.3, 0.4]] 146 | 147 | 148 | @pytest.mark.asyncio 149 | async def test_retry_connector_errors(connector_error_in_post): 150 | runner = OpenAIChat(model_id="gpt-4", api_config={"api_key": "test"}, rate_limit=10, cache=InMemoryDict()) 151 | runner._get_session = connector_error_in_post 152 | with pytest.raises(RetriableError) as excinfo: 153 | await runner.generate_text(prompt="test prompt") 154 | assert "Client/server connection error" in str(excinfo) 155 | assert "example.com" in str(excinfo) 156 | 157 | 158 | def test_schema_simple_dict(): 159 | data = {"name": "John", "age": 30, "employed": True} 160 | expected = { 161 | "type": "object", 162 | "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "employed": {"type": "boolean"}}, 163 | "required": ["name", "age", "employed"], 164 | "additionalProperties": False, 165 | } 166 | assert 
def test_infer_schema_empty_dict():
    """An empty dict yields an object schema with no properties."""
    assert JsonSchema.guess_schema({}).schema == {
        "type": "object",
        "properties": {},
        "required": [],
        "additionalProperties": False,
    }


def test_infer_schema_list_of_integers():
    """A homogeneous list yields an array schema typed by its elements."""
    assert JsonSchema._guess_schema([1, 2, 3], top_level=False) == {"type": "array", "items": {"type": "integer"}}


def test_infer_schema_empty_list():
    """An empty list has no element to sample, so inference raises IndexError."""
    with pytest.raises(IndexError):
        JsonSchema._guess_schema([], top_level=False)


def test_infer_schema_set_of_strings():
    """A set of strings is treated as an enum of its members."""
    fruit = {"apple", "banana", "cherry"}
    assert JsonSchema._guess_schema(fruit, top_level=False) == {"type": "string", "enum": list(fruit)}


def test_infer_schema_mixed_type_set():
    """A mixed-type set collapses to the sorted list of member types."""
    assert JsonSchema._guess_schema({1, "two", 3.0}, top_level=False) == {"type": ["integer", "number", "string"]}
name"): "Alice", ("age", "The user's age"): 30} 227 | expected = { 228 | "type": "object", 229 | "properties": { 230 | "name": {"type": "string", "description": "The user's name"}, 231 | "age": {"type": "integer", "description": "The user's age"}, 232 | }, 233 | "required": ["name", "age"], 234 | "additionalProperties": False, 235 | } 236 | assert JsonSchema.guess_schema(data).schema == expected 237 | 238 | 239 | def test_infer_schema_top_level_type_error(): 240 | data = "not a dict" 241 | with pytest.raises(TypeError): 242 | JsonSchema.guess_schema(data) 243 | -------------------------------------------------------------------------------- /sammo/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import asyncio 4 | import quattro 5 | import collections 6 | import json 7 | import pathlib 8 | import webbrowser 9 | from graphlib import TopologicalSorter 10 | import logging 11 | 12 | from sammo.utils import HtmlRenderer, GRAPH_TEMPLATE 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ComputeNode: 18 | __slots__ = ["job", "compute_context", "priority", "needs_scheduling"] 19 | 20 | def __init__(self, job, local_cache, priority): 21 | self.job = job 22 | self.compute_context = local_cache 23 | self.priority = priority 24 | 25 | 26 | class Scheduler: 27 | def __init__(self, runner, jobs, base_priority=0): 28 | # Construct graph 29 | self._graph = dict() 30 | self._runner = runner 31 | 32 | jobs = [jobs] if not isinstance(jobs, collections.abc.Iterable) else jobs 33 | queue = [ComputeNode(x, x._context, i) for i, x in enumerate(jobs)] 34 | 35 | self._graph = dict() 36 | while queue: 37 | x = queue.pop(0) 38 | children = [ComputeNode(c, x.compute_context, x.priority + i) for i, c in enumerate(x.job.dependencies)] 39 | queue = children + queue 40 | self._graph[x] = set(children) 41 | 42 | self.tasks = 
TopologicalSorter(self._graph) 43 | self.tasks.prepare() 44 | self.finalized_tasks_queue = asyncio.Queue() 45 | 46 | @staticmethod 47 | def _generate_id(node, iddict): 48 | if node not in iddict: 49 | iddict[node] = f"{node.job.__class__.__name__}_{node.priority}_{len(iddict)}" 50 | 51 | def plot(self, open_in_browser=False): 52 | elements = self._to_html() 53 | 54 | # write out as utf-8 file 55 | file = pathlib.Path("logs/callgraph.html") 56 | with open(file, "w", encoding="utf-8") as f: 57 | f.write() 58 | if open_in_browser: 59 | webbrowser.open(file.absolute().as_uri(), new=2, autoraise=False) 60 | 61 | def _to_html(self): 62 | # Generate ids 63 | node_ids = dict() 64 | for node, children in self._graph: 65 | self.generate_id(node, node_ids) 66 | for child in children: 67 | self.generate_id(child, node_ids) 68 | # Convert into Cytoscape.js format 69 | nodes = [{"data": {"id": v}} for v in node_ids.values()] 70 | edges = list() 71 | for e1, v in self._graph.items(): 72 | for e2 in v: 73 | e1_id = node_ids[e1] 74 | e2_id = node_ids[e2] 75 | edges.append({"data": {"id": f"{e1_id}_{e2_id}", "source": e1_id, "target": e2_id}}) 76 | elements = {"nodes": nodes, "edges": edges} 77 | return GRAPH_TEMPLATE.replace("ELEMENTS", json.dumps(elements, ensure_ascii=False)) 78 | 79 | def display(self, backend="auto"): 80 | return HtmlRenderer(self._to_html()).render(backend) 81 | 82 | async def run_node(self, node): 83 | await node.job(self._runner, node.compute_context, None) 84 | await self.finalized_tasks_queue.put(node) 85 | 86 | async def arun(self): 87 | async with quattro.TaskGroup() as tg: 88 | while self.tasks.is_active(): 89 | for compute_node in self.tasks.get_ready(): 90 | if compute_node.job.NEEDS_SCHEDULING: 91 | tg.create_task(self.run_node(compute_node)) 92 | else: 93 | await self.finalized_tasks_queue.put(compute_node) 94 | 95 | compute_node = await self.finalized_tasks_queue.get() 96 | self.tasks.done(compute_node) 97 | 98 | def run(self): 99 | 
        """Blocking convenience wrapper around :meth:`arun`."""
        asyncio.run(self.arun())


# ===== file: sammo/search_op.py =====
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module contains a variety of search operators that can be used to define a discrete search space for `GridSearch`
or define a set of initial candidates for other search algorithms.
"""
from __future__ import annotations
import pyglove as pg
from sammo.base import Component
from beartype.typing import Iterable, Any, Callable
from pyglove.core.hyper import OneOf, ManyOf

__all__ = ["one_of", "many_of", "permutate", "optional", "get_points_from_search_space", "get_first_point"]


def get_points_from_search_space(
    search_space: Callable | Component | list | dict,
    n_points: int,
    sample: bool = False,
    seed: int = 42,
    return_names: bool = False,
) -> list[Component]:
    """Materialize a number of points from a search space.

    :param search_space: Search space, either represented as function or a single Output class.
    :param n_points: Number of points to materialize.
    :param sample: Whether to sample from the search space or enumerate and return first `n_points`.
    :param seed: Random seed for sampling.
    :param return_names: Whether to return the names of the points.
    """
    names = list()
    # Normalize plain containers into symbolic pyglove containers.
    if isinstance(search_space, list):
        search_space = pg.list(search_space)
    elif isinstance(search_space, dict):
        search_space = pg.dict(search_space)
    if isinstance(search_space, Callable) and not isinstance(search_space, pg.Object):
        # Functional search space: trace it, then enumerate (or sample) decision points.
        candidates = list()
        for context in pg.iter(
            pg.hyper.trace(search_space),
            num_examples=n_points,
            algorithm=pg.geno.Random(seed) if sample else None,
        ):
            # NOTE(review): reaches into the traced function's closure to recover the
            # decision dict for naming — this depends on pyglove internals; verify on upgrade.
            names.append(context.__closure__[0].cell_contents.to_dict("name_or_id", "literal"))
            with context():
                candidates.append(search_space())
    elif search_space.is_deterministic:
        # No search decisions to make: repeat the single point n_points times.
        # NOTE(review): all entries share ONE clone object — if callers mutate points
        # independently, this should clone per entry instead; confirm.
        candidates = [search_space.clone(deep=True)] * n_points
    elif sample:
        candidates = list(pg.random_sample(search_space, n_points, seed=seed))
    else:
        candidates = list(pg.iter(search_space, n_points))
    if return_names:
        return candidates, names
    else:
        return candidates


def get_first_point(search_space: Callable | Component | list | dict) -> Component:
    """Return the first value of the enumerated search space.

    :param search_space: Search space, either represented as function or a single Output class."""
    return get_points_from_search_space(search_space, 1, sample=False)[0]


class OneOfPatched(OneOf):
    # NOTE(review): item access returns an empty string — presumably so unresolved
    # hyper-primitives render as blank when indexed in templates; confirm intent.
    def __getitem__(self, item):
        return ""


class ManyOfPatched(ManyOf):
    # Same patch as OneOfPatched, for the n-choose-k operator.
    def __getitem__(self, item):
        return ""


def one_of(candidates: Iterable, reference_id: str | None = None) -> Any:
    """Search operator for selecting one of the given candidates.

    :param candidates: The list of candidates to choose from.
    :param reference_id: Identifier for later reference.
80 | """ 81 | return OneOfPatched([(lambda n=x: n)(x) if not callable(x) else x for x in candidates], name=reference_id) 82 | 83 | 84 | def many_of(num_choices: int, candidates: Iterable, reference_id: str | None = None) -> Any: 85 | """Search operator for n choose k. 86 | 87 | :param num_choices: The number of candidates to choose. 88 | :param candidates: The list of candidates to choose from. 89 | :param reference_id: Identifier for later reference. 90 | """ 91 | return ManyOfPatched( 92 | num_choices=num_choices, 93 | choices_sorted=True, 94 | candidates=[(lambda n=x: n)(x) if not callable(x) else x for x in candidates], 95 | name=reference_id, 96 | ) 97 | 98 | 99 | def permutate(candidates: Iterable, reference_id: str | None = None) -> Any: 100 | """Search operator for permutating a list of components. 101 | 102 | :param candidates: The list of components to permute. 103 | :param reference_id: Identifier for later reference. 104 | """ 105 | return ManyOfPatched( 106 | num_choices=len(list(candidates)), 107 | choices_distinct=True, 108 | choices_sorted=False, 109 | candidates=[(lambda n=x: n)(x) if not callable(x) else x for x in candidates], 110 | name=reference_id, 111 | ) 112 | 113 | 114 | def optional(candidate, reference_id=None) -> Any: 115 | """Search operator for making a component optional. 116 | 117 | :param val: The value to include or exclude. 118 | :param reference_id: Identifier for later reference. 119 | """ 120 | return one_of([[], [candidate]], reference_id=reference_id) 121 | -------------------------------------------------------------------------------- /sammo/search_op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | from sammo.search_op import one_of, many_of, permutate, optional, get_points_from_search_space 4 | import pyglove as pg 5 | 6 | 7 | @pg.symbolize(eq=True) 8 | class Demo: 9 | def __init__(self, params): 10 | self.params = params 11 | 12 | 13 | def enumerate_candidates(search_space): 14 | traced_search_space = pg.hyper.trace(search_space) 15 | all_candidates = list() 16 | for search_context in pg.iter(traced_search_space): 17 | with search_context(): 18 | all_candidates.append(search_space()) 19 | return all_candidates 20 | 21 | 22 | def test_one_of(): 23 | space = lambda: one_of([1, 2, 3]) 24 | assert enumerate_candidates(space) == [1, 2, 3] 25 | 26 | 27 | def test_many_of(): 28 | space = lambda: many_of(2, [1, 2, 3]) 29 | assert enumerate_candidates(space) == [[1, 2], [1, 3], [2, 3]] 30 | 31 | 32 | def test_permutate(): 33 | space = lambda: permutate([1, 2, 3]) 34 | assert sorted(enumerate_candidates(space)) == sorted( 35 | [ 36 | [1, 2, 3], 37 | [1, 3, 2], 38 | [2, 1, 3], 39 | [3, 1, 2], 40 | [2, 3, 1], 41 | [3, 2, 1], 42 | ] 43 | ) 44 | 45 | 46 | def test_optional(): 47 | space = lambda: optional(1) 48 | assert enumerate_candidates(space) == [[], [1]] 49 | 50 | 51 | def test_get_points_from_search_space(): 52 | me = Demo(one_of(["a", "b"])) 53 | points = get_points_from_search_space(me, 2, sample=False) 54 | assert points == [Demo("a"), Demo("b")] 55 | 56 | 57 | def test_get_points_from_search_space_with_names(): 58 | me = lambda: Demo(one_of(["a", "b"], reference_id="params")) 59 | points, names = get_points_from_search_space(me, 2, sample=False, return_names=True) 60 | assert points == [Demo("a"), Demo("b")] 61 | assert names == [{"params": "'a'"}, {"params": "'b'"}] 62 | -------------------------------------------------------------------------------- /sammo/store.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Implements two different types of dictionaries that can store either data in memory or on disk. Allows keys to be 4 | arbitrary JSON-serializable objects that get rendered to byte strings for indexing. 5 | Mainly used to cache LLM API calls, but can be used for other purposes as well. 6 | """ 7 | from __future__ import annotations 8 | from collections.abc import MutableMapping 9 | from contextlib import ExitStack 10 | from io import BytesIO 11 | import logging 12 | import os 13 | import threading 14 | import warnings 15 | from pathlib import Path 16 | 17 | import diskcache 18 | from beartype.typing import Callable, Union 19 | import filelock 20 | import orjson 21 | from pyglove import JSONConvertible 22 | 23 | from sammo.utils import CodeTimer 24 | from sammo.utils import serialize_json as serialize_json_obj 25 | 26 | __all__ = ["PersistentDict", "InMemoryDict"] 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def serialize_json(obj): 32 | if isinstance(obj, bytes): 33 | return obj 34 | else: 35 | return serialize_json_obj(obj) 36 | 37 | 38 | class PersistentDict(MutableMapping, JSONConvertible): 39 | """ 40 | Implements a dictionary that is persisted to disk. Entries are appended to the end of the file, with later entries 41 | overwriting earlier ones. The file is read into memory on initialization to allow for fast lookups. 42 | Write and delete operations are thread-safe. 43 | 44 | :param filename: 45 | path for the stored data. Loads the dictionary from the given file, or creates a new one if it doesn't exist. 
46 | """ 47 | 48 | def __init__(self, filename: os.PathLike | str): 49 | self._filename = Path(filename) 50 | if self._filename.exists(): 51 | self._dict = self._load() 52 | else: 53 | self._filename.parent.mkdir(parents=True, exist_ok=True) 54 | self._dict = dict() 55 | self._fp = None 56 | self._lock = threading.Lock() 57 | self._os_lock = filelock.FileLock(self._filename.with_suffix(".lock"), timeout=1) 58 | 59 | def _load(self): 60 | timer = CodeTimer() 61 | keys, vals = list(), list() 62 | for line in open(self._filename, "rb"): 63 | if line[0] == b"#" or line == b"\n" or b"\t" not in line: 64 | continue 65 | splits = line.split(b"\t", 2) 66 | if len(splits) != 2: 67 | continue 68 | key, val = splits 69 | try: 70 | val = orjson.loads(val) 71 | keys.append(key) 72 | vals.append(val) 73 | except orjson.JSONDecodeError: 74 | logger.warning(f"Failed to load line {line}") 75 | logger.info(f"Loaded {len(keys)} entries from {self._filename} in {timer.interval:.2f} s") 76 | return dict(zip(keys, vals)) 77 | 78 | def _append_to_file(self, key, value): 79 | if self._fp is None: 80 | self._filename.parent.mkdir(parents=True, exist_ok=True) 81 | if not self._filename.exists(): 82 | self._fp = open(self._filename, "wb") 83 | else: 84 | self._fp = open(self._filename, "r+b") 85 | self._fp.seek(0, os.SEEK_END) 86 | 87 | # Mark the new line as a comment until fully written out 88 | offset = self._fp.tell() 89 | 90 | if offset > 0: 91 | self._fp.write(b"\n") 92 | 93 | offset = self._fp.tell() 94 | self._fp.write(b"#") 95 | self._fp.write(key[1:]) 96 | self._fp.write(b"\t") 97 | self._fp.write(orjson.dumps(value)) 98 | self._fp.flush() 99 | 100 | # Remove the comment marker 101 | self._fp.seek(offset) 102 | self._fp.write(key[0:1]) 103 | self._fp.seek(0, os.SEEK_END) 104 | self._fp.flush() 105 | if not type(self._fp) == BytesIO: 106 | os.fsync(self._fp.fileno()) 107 | 108 | def vacuum(self) -> None: 109 | """Removes all deleted entries from the file.""" 110 | tmp_fname = 
self._filename.with_suffix(".tmp")
        # Rewrite all live entries into a temp file, then atomically swap it in.
        with self._lock:
            with open(tmp_fname, "wb") as f:
                for key, value in self._dict.items():
                    if value is not None:  # None marks a deleted entry
                        f.write(key)
                        f.write(b"\t")
                        f.write(orjson.dumps(value, option=orjson.OPT_APPEND_NEWLINE))
            if self._fp:
                # Close the append handle so the replace below is safe on all platforms.
                self._fp.close()
                self._fp = None
            os.replace(tmp_fname, self._filename)

    def __contains__(self, key):
        # Entries deleted in a previous session are loaded as None tombstones, so
        # membership also checks for None.
        bkey = self._find(key)
        return bkey in self._dict and self._dict[bkey] is not None

    def __getitem__(self, key):
        # NOTE(review): a key whose None tombstone was loaded from disk returns None here
        # instead of raising KeyError — confirm callers go through __contains__ first.
        return self._dict[self._find(key)]

    def __getstate__(self):
        # Only the in-memory dict is pickled; file handles and locks are not portable.
        return self._dict

    def __setstate__(self, state):
        # NOTE(review): restores only _dict; _fp/_lock/_os_lock/_filename are absent on
        # the unpickled object, so it is effectively read-only until reinitialized.
        self._dict = state

    def _find(self, key):
        # Canonical byte-string key used for all lookups.
        return serialize_json(key)

    def __setitem__(self, key, value):
        bkey = serialize_json(key)
        with self._lock:  # thread-level exclusion
            with self._os_lock:  # process-level exclusion via lock file
                self._append_to_file(bkey, value)
                self._dict[bkey] = value

    def __delitem__(self, key):
        # Convention: deleted items set to None
        bkey = self._find(key)
        if bkey in self._dict:
            with self._lock:
                with self._os_lock:
                    self._append_to_file(bkey, None)
                    del self._dict[bkey]

    def __iter__(self):
        # NOTE(review): iterates raw byte-string keys, including any None tombstones.
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def to_json(self, **kwargs):
        """Serialize as a reference to the backing file (contents are not embedded)."""
        return {"_type": "PersistentDict", "filename": str(self._filename)}

    @classmethod
    def from_json(cls, json_value, **kwargs):
        """Recreate the dict by re-loading the referenced file."""
        return cls(json_value["filename"])


class InMemoryDict(PersistentDict):
    """
    Implements a dictionary that lives only in memory. Entries are not persisted to disk unless `persist` is called.

    """

    def __init__(self):
        self._dict = dict()
        self._lock = threading.Lock()
        # No file to lock; ExitStack is a no-op stand-in for the file lock.
        self._os_lock = ExitStack()

    def _append_to_file(self, key, value):
        # In-memory variant: nothing to persist on write.
        pass

    def persist(self, filename: os.PathLike | str):
        """Persists the dictionary to disk.

        :param filename: path for the stored data.
        """
        with open(filename, "wb") as f:
            for key, value in self._dict.items():
                if value is not None:  # skip deleted entries
                    f.write(key)
                    f.write(b"\t")
                    f.write(orjson.dumps(value, option=orjson.OPT_APPEND_NEWLINE))


class SqlLiteDict(PersistentDict):
    """
    Implements a dictionary that is persisted to disk in a SQLite DB. It is a bit slower for shorter entries than using
    PersistentDict which is all buffered in memory, but preferable when storing larger objects, such as vectors.

    :param directory: path for the stored data. If none, uses a temporary directory.
    """

    def __init__(self, directory: Union[os.PathLike, str, None] = None):
        if directory and Path(directory).exists() and not Path(directory).is_dir():
            raise ValueError(f"{directory} is not a directory.")
        self._filename = directory
        # diskcache handles its own locking; eviction disabled so entries never expire.
        self._dict = diskcache.Cache(directory, eviction_policy="none")

    def vacuum(self) -> None:
        # diskcache manages its own compaction; nothing to do here.
        pass

    def _find(self, key):
        return serialize_json(key)

    def __contains__(self, key):
        return self._find(key) in self._dict

    def __setitem__(self, key, value):
        self._dict[self._find(key)] = value

    def __delitem__(self, key):
        del self._dict[self._find(key)]

    def to_json(self, **kwargs):
        """Serialize as a reference to the backing directory."""
        return {"_type": "SqlLiteDict", "filename": str(self._filename)}


# ===== file: sammo/store_test.py =====
#
Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from unittest.mock import patch, mock_open, Mock 4 | import pytest 5 | from io import BytesIO 6 | from sammo.store import PersistentDict, InMemoryDict, serialize_json, SqlLiteDict 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "expected,data", 11 | [ 12 | ({b'"test"': "Hello"}, b'"test"\t"Hello"'), 13 | ({b'"test"': "Hello", b'"int"': 3}, b'"test"\t"Hello"\n"int"\t3'), 14 | ({b'"test"': "Hello"}, b'"test"\t"Hello"'), 15 | ({b'"test"': "Hello", b'"int"': None}, b'"test"\t"Hello"\n"int"\t3\n"int"\tnull'), 16 | ], 17 | ) 18 | def test_read(data, expected): 19 | # patch file lock 20 | with patch("sammo.store.filelock.FileLock"): 21 | with patch("builtins.open", mock_open(read_data=data)) as mock_file: 22 | store = PersistentDict("dummy.txt") 23 | read = store._load() 24 | assert read == expected 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "data", 29 | [ 30 | [("test", "Hello")], 31 | [("test", "Hello"), ("int", 3)], 32 | ], 33 | ) 34 | def test_delete(data): 35 | with patch("sammo.store.filelock.FileLock"): 36 | with patch("builtins.open", lambda x, y: BytesIO()) as mock_file: 37 | store = PersistentDict("test_file.data") 38 | for k, v in data: 39 | store[k] = v 40 | assert data[0][0] in store 41 | del store[data[0][0]] 42 | assert data[0][0] not in store 43 | 44 | 45 | def test_vaccum(): 46 | with patch("sammo.store.filelock.FileLock"): 47 | with patch("os.replace"): 48 | with patch("builtins.open", lambda x, y: BytesIO()) as mock_file: 49 | store = PersistentDict("test_file.data") 50 | 51 | data = [("test", "Hello"), ("int", 3), ("int", None)] 52 | for k, v in data: 53 | store[k] = v 54 | store.vacuum() 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "data,expected", 59 | [ 60 | ([("test", "Hello")], b'"test"\t"Hello"'), 61 | ([("test", "Hello"), ("int", 3)], b'"test"\t"Hello"\n"int"\t3'), 62 | ([("test", "Hello"), ("int", None)], b'"test"\t"Hello"\n"int"\tnull'), 63 | ([("test", "Hello"), 
("int", 3), ("int", None)], b'"test"\t"Hello"\n"int"\t3\n"int"\tnull'), 64 | ( 65 | [("test", "Hello"), ("int", 3), ("int", None), ("test", True)], 66 | b'"test"\t"Hello"\n"int"\t3\n"int"\tnull\n"test"\ttrue', 67 | ), 68 | ], 69 | ) 70 | def test_write(data, expected): 71 | with patch("sammo.store.filelock.FileLock"): 72 | with patch("builtins.open", lambda x, y: BytesIO()) as mock_file: 73 | store = PersistentDict("test_file.data") 74 | for k, v in data: 75 | store[k] = v 76 | store._fp.seek(0) 77 | content = store._fp.read() 78 | assert content == expected 79 | 80 | 81 | @pytest.mark.parametrize( 82 | "data,expected", 83 | [([("test", "Hello")], b'"test"\t"Hello"'), ([("test", "Hello"), ("int", 3)], b'"test"\t"Hello"\n"int"\t3')], 84 | ) 85 | def test_persist(data, expected): 86 | with patch("builtins.open") as mock_file: 87 | store = InMemoryDict() 88 | for k, v in data: 89 | store[k] = v 90 | store.persist("") 91 | print(mock_file.mock_calls) 92 | assert len(mock_file.mock_calls) == len(data) * 3 + 3 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "data", 97 | [3, {"test": "Hello"}, "some string", b"bytes string"], 98 | ) 99 | def test_fix_point(data): 100 | serialize_json(data) == serialize_json(serialize_json(data)) 101 | 102 | 103 | def test_sqlite(): 104 | store = SqlLiteDict(None) 105 | store["test"] = "Hello" 106 | assert store["test"] == "Hello" 107 | store["test"] = "World" 108 | assert store["test"] == "World" 109 | del store["test"] 110 | assert "test" not in store 111 | -------------------------------------------------------------------------------- /sammo/throttler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """ 4 | Provides a context manager for throttling async jobs. 
The throttling is defined by a list of AtMost instances and
can be used to limit the number of concurrent jobs, the number of jobs per time period, the total cost of jobs per
time period, or the number of failed jobs per time period. The context manager will block until there is capacity
to run the job. The jobs are run in order of priority, breaking ties with creation time.
"""

from __future__ import annotations
import asyncio
import bisect
from collections import deque
from dataclasses import dataclass, field
import enum
import logging
import threading
import time

from beartype import beartype
from beartype.typing import Literal, Union


__all__ = ["Throttler", "AtMost"]

logger = logging.getLogger(__name__)


class JobStatus(enum.Enum):
    # Lifecycle states of a throttled job.
    NEW = 0
    RUNNING = 1
    REJECTED = 2  # failed within the rejection window (see Throttler.update_job_stats)
    FAILED = 3  # failed after the rejection window
    SUCCESSFUL = 4


@dataclass(order=True)
class Job:
    """Class for keeping track of an async job."""

    # Ordering (used by the priority wait list) compares priority, then id (creation order);
    # all other fields are excluded from comparison.
    priority: int
    id: int
    created: float = field(default_factory=time.perf_counter, compare=False)
    start: float = field(default=0, compare=False)  # set when the job leaves the wait list
    end: float = field(default=0, compare=False)  # set by Throttler.update_job_stats
    cost: int = field(default=0, compare=False)
    status: JobStatus = field(default=JobStatus.NEW, compare=False)

    def __str__(self) -> str:
        return f"{self.priority}|{self.id}"

    def get_value(self, property: str) -> float | int | bool:
        """Value this job contributes towards a limit of the given type."""
        if property == "calls":
            return True  # every job counts as one call
        elif property == "cost":
            return self.cost
        else:
            # "running", "failed", "rejected": 1 if the job is in that state, else 0
            return self.status == JobStatus[property.upper()]

    def time_since_start(self) -> float:
        return time.perf_counter() - self.start

    def time_since_end(self) -> float:
        # Jobs that have not finished yet report -1.
        if self.status in [JobStatus.NEW, JobStatus.RUNNING]:
            return -1
        return time.perf_counter() - self.end


@beartype
@dataclass
class AtMost: 72 | """Class for defining a throttling limit.""" 73 | 74 | value: Union[float, int] 75 | type: Literal["calls", "running", "failed", "rejected"] 76 | period: Union[float, int] = 1 77 | pause_for: Union[float, int] = 0 78 | 79 | 80 | @beartype 81 | class Throttler: 82 | """Class that provides flexible throttling for async jobs. 83 | 84 | :param limits: A list of :class:`sammo.throttler.AtMost` instances that define the throttling limits. 85 | :param sleep_interval: The time (in s) between checks for capacity. 86 | :param impute_pending_costs: Whether to estimate the cost of pending jobs with the running average. 87 | :param n_cost_samples: The number of samples to use when calculating the running average. 88 | :param rejection_window: The time (in s) within which a job is considered rejected instead of failed. 89 | """ 90 | 91 | DEBUG_INTERVAL_SECONDS = 3 92 | 93 | def __init__( 94 | self, 95 | limits: list[AtMost], 96 | sleep_interval: float = 0.01, 97 | impute_pending_costs: bool = True, 98 | n_cost_samples: int = 10, 99 | rejection_window: Union[int, float] = 0.5, 100 | ): 101 | self._limits = limits 102 | self._max_history_window = max([x.period for x in limits] + [60]) 103 | self._sleep_interval = sleep_interval 104 | self._lock = threading.Lock() 105 | self._task_logs = deque() 106 | self._wait_list = deque() 107 | self._id_counter = 0 108 | self._cost_samples = list() 109 | self._running_avg = 0 110 | self._n_cost_samples = n_cost_samples 111 | self._impute_pending_costs = impute_pending_costs 112 | self._rejection_limit_start = None 113 | self._last_log = 0 114 | self._daemon_active = None 115 | self._rejection_window = rejection_window 116 | rejected = [x for x in limits if x.type == "rejected"] 117 | 118 | if len(rejected) > 1: 119 | raise ValueError("Only one rejected limit can be specified.") 120 | elif len(rejected) == 1: 121 | self._rejection_limit = rejected[0] 122 | else: 123 | self._rejection_limit = None 124 | 125 | def 
_collect_garbage(self) -> None: 126 | while self._task_logs: 127 | job = self._task_logs[0] 128 | if job.time_since_end() > self._max_history_window: 129 | with self._lock: 130 | self._task_logs.popleft() 131 | else: 132 | break 133 | 134 | @staticmethod 135 | async def sleep(delay: float): 136 | """A more precise sleep function on Windows""" 137 | await asyncio.get_running_loop().run_in_executor(None, time.sleep, delay) 138 | 139 | def update_job_stats(self, job: Job, cost: Union[float, int], failed: bool = False) -> None: 140 | """Update the stats for a job. Needs to be called when a job is finished. 141 | 142 | :param job: Job instance to update. 143 | :param cost: The cost of the job. 144 | :param failed: Whether the job failed, default is False. 145 | """ 146 | job.cost = cost 147 | job.end = time.perf_counter() 148 | with self._lock: 149 | n_running = sum([x.get_value("running") for x in self._task_logs]) 150 | if self._daemon_active is not None and n_running <= 1: 151 | # last job finished, cancel the daemon 152 | self._daemon_active.cancel() 153 | self._daemon_active = None 154 | if failed: 155 | if job.time_since_start() < self._rejection_window: 156 | job.status = JobStatus.REJECTED 157 | if self._rejection_limit is not None: 158 | self._rejection_limit_start = job.end 159 | else: 160 | job.status = JobStatus.FAILED 161 | else: 162 | job.status = JobStatus.SUCCESSFUL 163 | if self._rejection_limit_start and job.start > self._rejection_limit_start: 164 | self._rejection_limit_start = None 165 | 166 | self._cost_samples = self._cost_samples[-self._n_cost_samples :] + [cost] 167 | self._running_avg = sum(self._cost_samples) / (1.0 if self._cost_samples == 0 else len(self._cost_samples)) 168 | 169 | async def wait_in_line(self, priority: int = 0) -> Job: 170 | """Wait async until there is capacity to run a job. The jobs are run in order of priority, 171 | breaking ties with creation time. 172 | 173 | :param priority: The priority of the job. 
        Lower numbers are higher priority.
        """
        try:
            with self._lock:
                # Lazily start the periodic stats logger with the first waiting job.
                if self._daemon_active is None:
                    self._daemon_active = asyncio.get_event_loop().create_task(self._log_stats())
                my_id = self._id_counter
                self._id_counter += 1
            this_job = Job(priority=priority, id=my_id)
            # Wait list stays sorted by (priority, id) via Job's dataclass ordering.
            bisect.insort(self._wait_list, this_job)

            # Poll until this job is at the head of the line AND a slot is free.
            while True:
                if self._has_capacity() and self._wait_list[0].id == my_id:
                    break

                await self.sleep(self._sleep_interval)
            self._wait_list.remove(this_job)
            # Impute the expected cost so in-flight jobs count towards cost limits.
            this_job.cost = self._running_avg if self._impute_pending_costs else 0
            this_job.start = time.perf_counter()
            this_job.status = JobStatus.RUNNING
            self._task_logs.append(this_job)

            self._collect_garbage()
            return this_job

        except asyncio.CancelledError:
            # Caller was cancelled while queued: remove the job and propagate.
            logger.debug(f"Canceling {my_id}")
            self._wait_list.remove(this_job)
            raise

    async def _log_stats(self) -> None:
        # Background task: log a one-line summary of the last minute every few seconds.
        while True:
            with self._lock:
                completed = [x for x in self._task_logs if x.time_since_end() <= 60]
                running = [x for x in self._task_logs if x.status == JobStatus.RUNNING]
                successful = [x for x in completed if x.status == JobStatus.SUCCESSFUL]
                failed = [x for x in completed if x.status == JobStatus.FAILED]
                costs = sum([x.cost for x in completed])
                rejected = [x for x in completed if x.status == JobStatus.REJECTED]

            logger.info(
                f"{len(running)} running, "
                f"last minute: {len(successful)} successful ({costs} total costs), "
                f"{len(rejected)} rejected, {len(failed)} failed later"
            )
            await asyncio.sleep(self.DEBUG_INTERVAL_SECONDS)

    def _active_rejection_limit(self) -> list[AtMost]:
        # While a rejection pause is active, throttle to one call per pause_for window.
        if self._rejection_limit_start:
            return [AtMost(1, "calls", self._rejection_limit.pause_for)]
        else:
            return list()

    def _has_capacity(self) -> bool:
        # True iff every configured limit (plus any active rejection pause) has headroom.
        individual_limits_okay = list()
        with self._lock:
            for limit in self._limits + self._active_rejection_limit():
                if limit.type == "calls":
                    # Call limits are windowed on start time.
                    relevant_jobs = [x for x in self._task_logs if x.time_since_start() < limit.period]
                    individual_limits_okay.append(len(relevant_jobs) < limit.value)
                elif limit.type == "running":
                    individual_limits_okay.append(sum([x.get_value("running") for x in self._task_logs]) < limit.value)
                elif limit.type in ["cost", "failed", "rejected"]:
                    # These limits are windowed on end time.
                    relevant_jobs = [x for x in self._task_logs if x.time_since_end() < limit.period]
                    job_count = sum([x.get_value(limit.type) for x in relevant_jobs])
                    individual_limits_okay.append(job_count < limit.value)

        return all(individual_limits_okay)


# ===== file: sammo/throttler_test.py =====
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import asyncio
import quattro
import threading
import time

import pytest
from pytest import approx

from sammo.throttler import Throttler, AtMost


async def simple_job(job_id, throttler, fail=False, delay=0):
    """Queue one job through the throttler and report its timing breakdown."""
    t_scheduled = time.perf_counter()
    ticket = await throttler.wait_in_line()
    t_started = time.perf_counter()
    if delay > 0:
        await asyncio.sleep(delay)
    t_done = time.perf_counter()
    throttler.update_job_stats(ticket, cost=0, failed=fail)
    return {
        "scheduled": t_scheduled,
        "start": t_started,
        "duration": t_done - t_scheduled,
        "net_duration": t_done - t_started,
        "job_id": job_id,
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("n_jobs,completion_time", [(10, 0.1), (11, 0.2), (20, 0.3)])
async def test_basic_call_limit(n_jobs, completion_time):
    """At most 10 calls per 0.1s: surplus jobs must spill into later windows."""
    throttler = Throttler([AtMost(10, "calls", 0.1)])

    async with quattro.TaskGroup() as g:
        tasks = [g.create_task(simple_job(i, throttler)) for i in range(n_jobs)]
    results = [t.result() for t in tasks]
    longest = max(r["duration"] for r in results)
    # provide a relaxed upper bound for max duration to account for differences
    # in executors across test environments
    assert longest <= (completion_time * 2)


@pytest.mark.asyncio
@pytest.mark.parametrize("n_jobs,completion_time", [(10, 0.06), (12, 0.11)])
async def test_basic_running_limit(n_jobs, completion_time, job_duration=0.05):
    """At most 10 jobs in flight: job 11+ must wait for a slot to free up."""
    throttler = Throttler([AtMost(10, "running")], sleep_interval=0.001)

    async with quattro.TaskGroup() as g:
        tasks = [g.create_task(simple_job(i, throttler, delay=job_duration)) for i in range(n_jobs)]
    results = [t.result() for t in tasks]

    longest = max(r["duration"] for r in results)
    # provide a relaxed upper bound for max duration to account for differences
    # in executors across test environments
    assert longest <= (completion_time * 2)


@pytest.mark.asyncio
@pytest.mark.parametrize("jobs_with_flags,completion_time", [([True] * 2 + [False] * 5, 0.21)])
async def test_basic_failed_limit(jobs_with_flags, completion_time):
    """At most one failure per 0.1s window (rejection handling disabled)."""
    throttler = Throttler([AtMost(1, "failed", 0.1)], rejection_window=-1, sleep_interval=0.001)

    async with quattro.TaskGroup() as g:
        tasks = [g.create_task(simple_job(i, throttler, fail=flag)) for i, flag in enumerate(jobs_with_flags)]
    results = [t.result() for t in tasks]

    longest = max(r["duration"] for r in results)
    # provide a relaxed upper bound for max duration to account for differences
    # in executors across test environments
    assert longest <= (completion_time * 2)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "jobs_with_flags,completion_time",
    [([True, False] * 1, 0.11), ([True] * 2 + [False] * 1, 0.21), ([True] * 1 + [False] * 5, 0.11)],
)
async def test_basic_rejected_limit(jobs_with_flags, completion_time):
    """A rejection triggers a 0.1s pause before further calls are admitted."""
    throttler = Throttler([AtMost(1, "rejected", 0.1, 0.1)], rejection_window=1, sleep_interval=0.001)

    async with quattro.TaskGroup() as g:
        tasks = [g.create_task(simple_job(i, throttler, fail=flag)) for i, flag in enumerate(jobs_with_flags)]
    results = [t.result() for t in tasks]
    longest = max(r["duration"] for r in results)
    # provide a relaxed upper bound for max duration to account for differences
    # in executors across test environments
    assert longest <= (completion_time * 2)


# ===== file: sammo/utils.py =====
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
3 | """Small number of utility functions that are used across SAMMO.""" 4 | import asyncio 5 | import collections 6 | import tempfile 7 | import time 8 | import pathlib 9 | import sys 10 | import webbrowser 11 | from concurrent.futures import ThreadPoolExecutor 12 | from html import escape 13 | 14 | from orjson import orjson 15 | 16 | __all__ = [ 17 | "CodeTimer", 18 | "MAIN_PATH", 19 | "MAIN_NAME", 20 | "DEFAULT_SAVE_PATH", 21 | "sync", 22 | "serialize_json", 23 | ] 24 | 25 | GRAPH_TEMPLATE = """ 26 | 27 | 28 | 29 | 30 | 74 | 75 | Callgraph 76 | 77 | 78 | 79 | 80 | 81 |
82 |
Click on node for details.
83 |
84 |
85 | 142 | 143 | 144 | 145 | """ 146 | 147 | 148 | class CodeTimer: 149 | """Time code with this context manager.""" 150 | 151 | def __init__(self): 152 | self.created = time.perf_counter() 153 | self._interval = None 154 | 155 | @property 156 | def interval(self) -> float: 157 | """Timed interval in s.""" 158 | if self._interval is None: 159 | return time.perf_counter() - self.created 160 | return self._interval 161 | 162 | def __enter__(self): 163 | self.start = time.perf_counter() 164 | return self 165 | 166 | def __exit__(self, *args): 167 | self.end = time.perf_counter() 168 | self._interval = self.end - self.start 169 | 170 | 171 | def is_thread_running_async_loop() -> bool: 172 | try: 173 | asyncio.get_running_loop() 174 | return True 175 | except RuntimeError: 176 | return False 177 | 178 | 179 | def sync(f: collections.abc.Coroutine): 180 | """Execute and return result of an async function. Take special care of already running async loops.""" 181 | if is_thread_running_async_loop(): 182 | # run inside a new thread 183 | with ThreadPoolExecutor(1) as pool: 184 | result = pool.submit(lambda: asyncio.run(f)) 185 | return result.result() 186 | else: 187 | return asyncio.run(f) 188 | 189 | 190 | def is_interactive() -> bool: 191 | """Check if the code is running in an interactive shell.""" 192 | return hasattr(sys, "ps1") 193 | 194 | 195 | def is_jupyter() -> bool: 196 | """Check if code is running in jupyter lab or notebook.""" 197 | try: 198 | if get_ipython().__class__.__name__ == "ZMQInteractiveShell": 199 | return True 200 | else: 201 | return False 202 | except NameError: 203 | return False 204 | 205 | 206 | class HtmlRenderer: 207 | """Render HTML in an IFrame for Jupyter or as temporary file.""" 208 | 209 | def __init__(self, raw_html, width="100%", height="300px"): 210 | self.raw_html = raw_html 211 | self.width = width 212 | self.height = height 213 | 214 | def _repr_html_(self, **kwargs): 215 | iframe = f"""\ 216 | """ 219 | return iframe 220 | 221 | 
def render(self, backend="auto"): 222 | if backend == "auto": 223 | backend = "jupyter" if is_jupyter() else "file" 224 | if backend == "jupyter": 225 | return self 226 | else: 227 | with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8", suffix=".html") as f: 228 | f.write(self.raw_html) 229 | webbrowser.open("file://" + f.name, new=2, autoraise=True) 230 | 231 | 232 | def get_main_script_path() -> pathlib.Path: 233 | """Path of the main script if not interactive, otherwise working dir.""" 234 | if is_interactive(): 235 | return pathlib.Path.cwd().resolve() 236 | else: 237 | return pathlib.Path(sys.argv[0]).resolve().parent 238 | 239 | 240 | def get_main_script_name(if_interactive="tmp") -> str: 241 | """Name of the main script file if not interactive, otherwise 'tmp'.""" 242 | if is_interactive(): 243 | return if_interactive 244 | else: 245 | return pathlib.Path(sys.argv[0]).name 246 | 247 | 248 | def get_default_save_path() -> str: 249 | """Default save path is folder with the same name as main script.""" 250 | if is_interactive(): 251 | return get_main_script_path() / get_main_script_name() 252 | else: 253 | return pathlib.Path(sys.argv[0]).with_suffix("").resolve() 254 | 255 | 256 | def serialize_json(key) -> bytes: 257 | """Serialize json with orjson to invariant byte string.""" 258 | return orjson.dumps(key, option=orjson.OPT_SORT_KEYS) 259 | 260 | 261 | MAIN_PATH = get_main_script_path() 262 | """Path of the main script if not interactive, otherwise working dir.""" 263 | 264 | MAIN_NAME = get_main_script_name() 265 | """Name of the main script file if not interactive, otherwise 'tmp'.""" 266 | 267 | DEFAULT_SAVE_PATH = get_default_save_path() 268 | """Default save path is folder with the same name as main script.""" 269 | --------------------------------------------------------------------------------