├── .hooks ├── .gitkeep ├── linters │ └── yamllint.yaml ├── prettier.sh ├── post_merge.sh ├── generate_pr_description.py ├── check_pinned_hash_dependencies.py └── generate_docs.py ├── rigging ├── py.typed ├── version.py ├── tokenizer │ ├── __init__.pyi │ ├── __init__.py │ └── transformers_.py ├── tools │ ├── __init__.py │ └── robopages.py ├── generator │ ├── __init__.pyi │ ├── __init__.py │ ├── vllm_.py │ └── transformers_.py ├── transform │ ├── base.py │ └── __init__.py ├── caching.py ├── logging.py ├── __init__.py ├── interact.py ├── parsing.py └── error.py ├── tests ├── __init__.py ├── generators.py ├── test_completion_pipeline.py ├── test_generator_ids.py ├── test_model.py ├── test_prompt.py └── test_watchers.py ├── CODEOWNERS ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── ISSUE_TEMPLATE │ ├── config.yaml │ ├── bug_report.yaml │ └── feature_request.yaml ├── workflows │ ├── semantic-prs.yaml │ ├── meta-labeler.yaml │ ├── meta-sync-labels.yaml │ ├── ci.yml │ ├── docs-update.yaml │ ├── publish.yml │ ├── rigging_pr_description.yml │ ├── semgrep.yaml │ ├── template-sync.yaml │ └── renovate.yaml ├── renovate.json5 ├── labeler.yaml ├── CONTRIBUTING.md ├── labels.yaml └── CODE_OF_CONDUCT.md ├── docs ├── assets │ └── tracing_logfire.png ├── topics │ ├── logging.mdx │ ├── completions.mdx │ ├── tracing.mdx │ ├── workflow.mdx │ ├── serialization.mdx │ └── iterating-and-batching.mdx ├── install.mdx ├── docs.json └── api │ ├── logging.mdx │ ├── interact.mdx │ ├── parsing.mdx │ └── error.mdx ├── .vscode └── settings.json ├── examples ├── robopages.py ├── chat.py └── tokenize.ipynb ├── LICENSE ├── .pre-commit-config.yaml ├── .gitignore ├── pyproject.toml ├── .secrets.baseline └── README.md /.hooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rigging/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @dreadnode/team 2 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Notes 2 | 3 | - 4 | 5 | --- 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | blank_issues_enabled: false 3 | -------------------------------------------------------------------------------- /rigging/version.py: -------------------------------------------------------------------------------- 1 | import importlib_metadata 2 | 3 | VERSION = importlib_metadata.version("rigging") 4 | -------------------------------------------------------------------------------- /docs/assets/tracing_logfire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreadnode/rigging/HEAD/docs/assets/tracing_logfire.png -------------------------------------------------------------------------------- /.hooks/linters/yamllint.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | extends: default 3 | 4 | rules: 5 | line-length: 6 | max: 400 7 | level: warning 8 | truthy: false 9 | comments: 10 | min-spaces-from-content: 1 11 | braces: disable 12 | indentation: disable 13 | -------------------------------------------------------------------------------- /rigging/tokenizer/__init__.pyi: -------------------------------------------------------------------------------- 1 | from rigging.tokenizer.base import ( 2 | TokenizedChat, 3 | Tokenizer, 4 | TokenSlice, 5 | get_tokenizer, 6 | register_tokenizer, 7 | ) 8 | from rigging.tokenizer.transformers_ import TransformersTokenizer 9 | 10 | __all__ = [ 11 | "TokenSlice", 12 | "TokenizedChat", 13 | "Tokenizer", 14 | "TransformersTokenizer", 15 | "get_tokenizer", 16 | "register_tokenizer", 17 | ] 18 | -------------------------------------------------------------------------------- /.github/workflows/semantic-prs.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Semantic Lints PR" 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | types: 8 | - opened 9 | - edited 10 | - synchronize 11 | - reopened 12 | 13 | permissions: 14 | pull-requests: read 15 | 16 | jobs: 17 | main: 18 | name: Validate PR title 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 22 | env: 23 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.formatOnSave": true, 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "explicit", 6 | "source.organizeImports": "explicit" 7 | }, 8 | "editor.defaultFormatter": "charliermarsh.ruff" 9 | }, 10 | "python.testing.pytestArgs": [ 11 | "tests" 12 | ], 13 | "python.testing.unittestEnabled": false, 14 | "python.testing.pytestEnabled": true, 15 | "mypy.runUsingActiveInterpreter": true, 16 | "debugpy.debugJustMyCode": false, 17 | "jupyter.debugJustMyCode": false 18 | } 19 | -------------------------------------------------------------------------------- /rigging/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module handles tool interaction with rigging generation.
3 | """ 4 | 5 | from rigging.tools.base import ( 6 | FunctionCall, 7 | FunctionDefinition, 8 | Tool, 9 | ToolCall, 10 | ToolChoice, 11 | ToolDefinition, 12 | ToolMode, 13 | tool, 14 | tool_method, 15 | ) 16 | from rigging.tools.mcp import as_mcp, mcp 17 | from rigging.tools.robopages import robopages 18 | 19 | __all__ = [ 20 | "FunctionCall", 21 | "FunctionDefinition", 22 | "Tool", 23 | "ToolCall", 24 | "ToolChoice", 25 | "ToolDefinition", 26 | "ToolMode", 27 | "as_mcp", 28 | "mcp", 29 | "robopages", 30 | "tool", 31 | "tool_method", 32 | ] 33 | -------------------------------------------------------------------------------- /examples/robopages.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import litellm 5 | import logfire 6 | 7 | import rigging as rg 8 | 9 | # Constants 10 | SYSTEM_PROMPT = "Enumerate all open TCP ports on 127.0.0.1 and provide a vulnerability report" 11 | 12 | # LOGFIRE_PROJECT = "rigging-demos" 13 | logfire.configure() 14 | os.environ.setdefault("LOGFIRE_TOKEN", "") # (1)! 15 | litellm.callbacks = ["logfire"] 16 | 17 | rg.logging.configure_logging("debug") 18 | 19 | 20 | async def main(): 21 | tools = rg.robopages("http://localhost:8000") 22 | chat = await rg.get_generator("gpt-4").chat(SYSTEM_PROMPT).using(*tools).run() 23 | print(chat.conversation) 24 | 25 | 26 | if __name__ == "__main__": 27 | asyncio.run(main()) 28 | -------------------------------------------------------------------------------- /.github/renovate.json5: -------------------------------------------------------------------------------- 1 | { 2 | $schema: "https://docs.renovatebot.com/renovate-schema.json", 3 | extends: [ 4 | "config:base", 5 | ":disableRateLimiting", 6 | ":semanticCommits", 7 | ":enablePreCommit", 8 | ":automergeDigest", 9 | ":automergeBranch", 10 | ], 11 | dependencyDashboardTitle: "Renovate Dashboard 🤖", 12 | suppressNotifications: ["prIgnoreNotification"], 13 | rebaseWhen: "conflicted", 14 | commitBodyTable: true, 15 | "pre-commit": { 16 | enabled: true, 17 | }, 18 | enabledManagers: [ 19 | "github-actions", 20 | "dockerfile", 21 | "docker-compose", 22 | "pre-commit" 23 | ], 24 | ignorePaths: [], 25 | packageRules: [ 26 | { 27 | description: "Auto merge non-major updates", 28 | matchUpdateTypes: ["minor", "patch"], 29 | automerge: true, 30 | automergeType: "pr", 31 | }, 32 | ], 33 | } 34 | -------------------------------------------------------------------------------- /.github/workflows/meta-labeler.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Labeler" 3 | on: 4 | pull_request_target: 5 | branches: ["main"] 6 | types: ["opened", "synchronize"] 7 | 8 | permissions: 9 | actions: read 10 | contents: read 11 | issues: write 12 | pull-requests: write 13 | 14 | jobs: 15 | labeler: 16 | name: Labeler 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Generate Token 20 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 21 | id: app-token 22 | with: 23 | app-id: "${{ secrets.BOT_APP_ID }}" 24 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" 25 | 26 | - name: Labeler 27 | uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 28 | with: 29 | configuration-path: .github/labeler.yaml 30 | repo-token: "${{ steps.app-token.outputs.token }}" 31 | -------------------------------------------------------------------------------- /docs/topics/logging.mdx: 
-------------------------------------------------------------------------------- 1 | --- 2 | title: "Logging" 3 | description: "How to view and configure logging" 4 | public: true 5 | --- 6 | 7 | Rigging uses [loguru](https://loguru.readthedocs.io/) for its logging. By default it disables its logger, allowing users to choose when/how to gather messages. 8 | 9 | If you want to let rigging messages flow into loguru, you should enable it: 10 | 11 | ```python 12 | from loguru import logger 13 | 14 | logger.enable('rigging') 15 | ``` 16 | 17 | If you want sane default handlers with dual console & file logging, you can use the `rigging.logging.configure_logging` function to configure loguru. 18 | 19 | ```python 20 | from rigging.logging import configure_logging 21 | 22 | configure_logging( 23 | 'info', # stderr level 24 | 'out.log', # log file (optional) 25 | 'trace' # log file level 26 | ) 27 | ``` 28 | *(This will remove existing handlers, so you might prefer to configure them yourself)* 29 | -------------------------------------------------------------------------------- /rigging/generator/__init__.pyi: -------------------------------------------------------------------------------- 1 | from rigging.generator.base import ( 2 | GeneratedMessage, 3 | GeneratedText, 4 | GenerateParams, 5 | Generator, 6 | StopReason, 7 | Usage, 8 | chat, 9 | complete, 10 | get_generator, 11 | get_identifier, 12 | register_generator, 13 | ) 14 | from rigging.generator.http import HTTPGenerator 15 | from rigging.generator.litellm_ import LiteLLMGenerator 16 | from rigging.generator.transformers_ import TransformersGenerator 17 | from rigging.generator.vllm_ import VLLMGenerator 18 | 19 | __all__ = [ 20 | "GenerateParams", 21 | "GeneratedMessage", 22 | "GeneratedText", 23 | "Generator", 24 | "HTTPGenerator", 25 | "LiteLLMGenerator", 26 | "StopReason", 27 | "TransformersGenerator", 28 | "Usage", 29 | "VLLMGenerator", 30 | "chat", 31 | "complete", 32 | "get_generator", 33 | "get_identifier", 34 | "register_generator", 35 | ] 36 | -------------------------------------------------------------------------------- /docs/install.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Installation" 3 | description: "Install and set up Rigging" 4 | public: true 5 | --- 6 | 7 | We publish every version to PyPI: 8 | 9 | 10 | ```bash pip 11 | pip install -U rigging 12 | ``` 13 | 14 | ```bash uv 15 | uv pip install -U rigging 16 | ``` 17 | 18 | ```bash poetry 19 | poetry add rigging 20 | ``` 21 | 22 | 23 | If you want all the extras (vLLM, transformers, examples), specify the `all` extra: 24 | 25 | 26 | ```bash pip 27 | pip install -U rigging[all] 28 | ``` 29 | 30 | ```bash uv 31 | uv pip install -U rigging[all] 32 | ``` 33 | 34 | ```bash poetry 35 | poetry add rigging[all] 36 | ``` 37 | 38 | 39 | If you want to build from source: 40 | 41 | ```bash 42 | cd rigging/ 43 | poetry install 44 | ``` 45 | 46 | ## Migration Guides 47 | 48 | - **[Migrating from v2 -> v3](/topics/migrations#migrating-from-v2x-to-v3x)** 49 | - **[Migrating from v1 -> v2](/topics/migrations#migrating-from-v1x-to-v2x)** 50 | -------------------------------------------------------------------------------- /examples/chat.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import click 4 | 5 | import rigging as rg 6 | 7 | 8 | async def main(generator_id: str, system_prompt: str) -> None: 9 | base_pipeline =
rg.get_generator(generator_id).chat({"role": "system", "content": system_prompt}) 10 | await rg.interact(base_pipeline) 11 | 12 | 13 | @click.command() 14 | @click.option( 15 | "-g", 16 | "--generator-id", 17 | type=str, 18 | required=True, 19 | help="Rigging generator identifier (gpt-4, mistral/mistral-medium, etc.)", 20 | ) 21 | @click.option( 22 | "-s", 23 | "--system-prompt", 24 | type=str, 25 | default="You are a helpful assistant.", 26 | help="System prompt to use for the generator", 27 | ) 28 | def cli( 29 | generator_id: str, 30 | system_prompt: str, 31 | ) -> None: 32 | """ 33 | Rigging example of a basic terminal chat interaction. 34 | """ 35 | 36 | asyncio.run(main(generator_id, system_prompt)) 37 | 38 | 39 | if __name__ == "__main__": 40 | cli() 41 | -------------------------------------------------------------------------------- /rigging/transform/base.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from rigging.generator import GenerateParams 4 | from rigging.message import ( 5 | Message, 6 | ) 7 | 8 | if t.TYPE_CHECKING: 9 | from rigging.chat import Chat 10 | 11 | 12 | @t.runtime_checkable 13 | class PostTransform(t.Protocol): 14 | def __call__( 15 | self, 16 | chat: "Chat", 17 | /, 18 | ) -> "t.Awaitable[Chat]": 19 | """ 20 | Passed a chat to transform. 21 | """ 22 | ... 23 | 24 | 25 | @t.runtime_checkable 26 | class Transform(t.Protocol): 27 | def __call__( 28 | self, 29 | messages: list[Message], 30 | params: GenerateParams, 31 | /, 32 | ) -> t.Awaitable[tuple[list[Message], GenerateParams, PostTransform | None]]: 33 | """ 34 | Passed messages and params to transform. 35 | 36 | May return an optional post-transform callback to be executed to unwind the transformation. 37 | """ 38 | ... 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 dreadnode 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rigging/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tokenizers encode chats and associated message data into tokens for training and inference.
3 | """ 4 | 5 | import typing as t 6 | 7 | from rigging.tokenizer.base import ( 8 | TokenizedChat, 9 | Tokenizer, 10 | TokenSlice, 11 | get_tokenizer, 12 | register_tokenizer, 13 | ) 14 | 15 | 16 | def get_transformers_lazy() -> type[Tokenizer]: 17 | try: 18 | from rigging.tokenizer.transformers_ import TransformersTokenizer 19 | except ImportError as e: 20 | raise ImportError( 21 | "TransformersTokenizer is not available. Please install `transformers` or use `rigging[llm]`.", 22 | ) from e 23 | 24 | return TransformersTokenizer 25 | 26 | 27 | register_tokenizer("transformers", get_transformers_lazy) 28 | 29 | __all__ = [ 30 | "TokenSlice", 31 | "TokenizedChat", 32 | "Tokenizer", 33 | "get_tokenizer", 34 | "register_tokenizer", 35 | ] 36 | 37 | 38 | def __getattr__(name: str) -> t.Any: 39 | if name == "TransformersTokenizer": 40 | return get_transformers_lazy() 41 | raise AttributeError(f"module {__name__} has no attribute {name}") 42 | -------------------------------------------------------------------------------- /.github/workflows/meta-sync-labels.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Meta Sync labels" 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: ["main"] 7 | paths: [".github/labels.yaml"] 8 | 9 | permissions: 10 | actions: read 11 | contents: read 12 | issues: write 13 | pull-requests: write 14 | 15 | jobs: 16 | labels: 17 | name: Sync Labels 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: Generate Token 21 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 22 | id: app-token 23 | with: 24 | app-id: "${{ secrets.BOT_APP_ID }}" 25 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" 26 | 27 | - name: Set up git repository 28 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 29 | with: 30 | token: "${{ steps.app-token.outputs.token }}" 31 | 32 | - name: Sync Labels 33 | uses: EndBug/label-sync@52074158190acb45f3077f9099fea818aa43f97a # v2.3.3 34 | with: 35 | config-file: .github/labels.yaml 36 | token: "${{ steps.app-token.outputs.token }}" 37 | delete-other-labels: true 38 | -------------------------------------------------------------------------------- /.hooks/prettier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | # Check if npm is installed 5 | if ! command -v npm &> /dev/null; then 6 | echo 'Error: npm is not installed.' >&2 7 | exit 1 8 | fi 9 | 10 | # Check if Prettier is installed, install it if missing 11 | if ! command -v prettier &> /dev/null; then 12 | echo 'Error: Prettier is not installed.' >&2 13 | echo 'Installing Prettier...' 14 | npm install -g prettier 15 | fi 16 | 17 | # Verify Prettier is installed 18 | if ! command -v prettier &> /dev/null; then 19 | echo 'Error: Prettier installation failed.' >&2 20 | exit 1 21 | fi 22 | 23 | # Run Prettier on staged .json, .yaml, and .yml files 24 | echo "Running Prettier on staged files..." 25 | 26 | # List all staged files, filter for the desired extensions, and run Prettier 27 | git diff --cached --name-only --diff-filter=d | 28 | grep -E '\.(json|ya?ml)$' | 29 | xargs -I {} prettier --write {} 30 | 31 | # Add the files back to staging area as Prettier may have modified them 32 | git diff --name-only --diff-filter=d | 33 | grep -E '\.(json|ya?ml)$' | 34 | xargs git add 35 | 36 | echo "Prettier formatting completed." 
37 | exit 0 38 | -------------------------------------------------------------------------------- /rigging/caching.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from loguru import logger 4 | 5 | if t.TYPE_CHECKING: 6 | from rigging.message import Message 7 | 8 | CacheMode = t.Literal["latest"] 9 | """ 10 | How to handle cache_control entries on messages. 11 | 12 | - latest: Assign cache_control to the latest 2 non-assistant messages in the pipeline before inference. 13 | """ 14 | 15 | 16 | def apply_cache_mode_to_messages( 17 | mode: CacheMode | None, 18 | messages: "list[list[Message]]", 19 | ) -> "list[list[Message]]": 20 | if mode is None: 21 | return messages 22 | 23 | if mode != "latest": 24 | logger.warning( 25 | f"Unknown caching mode '{mode}', defaulting to 'latest'", 26 | ) 27 | mode = "latest" 28 | 29 | # first remove existing cache settings 30 | updated: list[list[Message]] = [] 31 | for _messages in messages: 32 | updated = [ 33 | *updated, 34 | [m.clone().cache(cache_control=False) for m in _messages], 35 | ] 36 | 37 | # then apply the latest cache settings 38 | for _messages in updated: 39 | for message in [m for m in _messages if m.role != "assistant"][-2:]: 40 | message.cache(cache_control=True) 41 | 42 | return updated 43 | -------------------------------------------------------------------------------- /.hooks/post_merge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get pre-merge hash from the target branch 4 | old_hash=$(git show ORIG_HEAD:poetry.lock | md5sum 2> /dev/null || echo "") 5 | 6 | # Get current hash 7 | new_hash=$(md5sum poetry.lock 2> /dev/null || echo "") 8 | 9 | # Compare and run poetry install if changed 10 | if [ "$old_hash" != "$new_hash" ]; then 11 | echo "📦 Root dependencies changed. Running poetry install..." 12 | poetry install || { 13 | echo "❌ Failed to update dependencies" 14 | exit 1 15 | } 16 | echo "✅ Root dependencies updated!" 17 | else 18 | echo "📦 No root dependency changes" 19 | fi 20 | 21 | # Get pre-merge hash from the target branch 22 | old_hash=$(git show ORIG_HEAD:components/api/poetry.lock | md5sum 2> /dev/null || echo "") 23 | 24 | # Get current hash 25 | new_hash=$(md5sum components/api/poetry.lock 2> /dev/null || echo "") 26 | 27 | # Compare and run poetry install if changed 28 | if [ "$old_hash" != "$new_hash" ]; then 29 | echo "📦 API dependencies changed. Running poetry install..." 30 | cd components/api || exit 31 | if ! poetry install --with dev; then 32 | echo "❌ Failed to update dependencies" 33 | exit 1 34 | fi 35 | echo "✅ API dependencies updated!" 
36 | else 37 | echo "📦 No API dependency changes" 38 | fi 39 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Lint, Typecheck, and Test 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | branches: [main] 9 | 10 | jobs: 11 | ci: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12"] 16 | 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Checkout code 21 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 22 | 23 | - name: Install Poetry 24 | uses: abatilo/actions-poetry@b8f6fe29ba2eb78e0d45ccbf41cd14154c4e25b2 25 | 26 | - name: Configure Poetry 27 | run: | 28 | poetry config virtualenvs.create true --local 29 | poetry config virtualenvs.in-project true --local 30 | 31 | - name: Setup Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | cache: "poetry" 36 | 37 | - name: Install package 38 | run: poetry install --all-extras 39 | 40 | - name: Lint 41 | run: poetry run ruff check --output-format=github rigging 42 | 43 | - name: Typecheck 44 | if: always() 45 | run: poetry run mypy rigging 46 | 47 | - name: Test 48 | if: always() 49 | run: poetry run pytest 50 | -------------------------------------------------------------------------------- /.github/workflows/docs-update.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Trigger Docs Update 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | paths: 8 | - "docs/**" 9 | - ".hooks/generate_docs.py" 10 | - ".github/workflows/docs-update.yaml" 11 | workflow_dispatch: 12 | 13 | jobs: 14 | notify-docs: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 18 | id: app-token 19 | with: 20 | app-id: ${{ vars.UPDATE_DOCS_APP_ID }} 21 | private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }} 22 | owner: "${{ github.repository_owner }}" 23 | repositories: | 24 | docs 25 | 26 | - name: Trigger docs repository workflow 27 | uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0 28 | with: 29 | token: ${{ steps.app-token.outputs.token }} 30 | repository: dreadnode/docs 31 | event-type: docs-update 32 | client-payload: | 33 | { 34 | "repository": "${{ github.repository }}", 35 | "ref": "${{ github.ref }}", 36 | "sha": "${{ github.sha }}", 37 | "source_dir": "docs", 38 | "target_dir": "open-source/rigging", 39 | "nav_target": "Open Source/Rigging" 40 | } 41 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Build and Publish 3 | 4 | on: 5 | push: 6 | tags: ["v*"] 7 | 8 | jobs: 9 | build-and-publish: 10 | name: Build and Publish 11 | environment: protected 12 | permissions: 13 | contents: read 14 | id-token: write 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 19 | 20 | - name: Setup Python 21 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 22 | with: 23 | python-version: "3.14" 24 | 25 | - name: Install Poetry 26 | uses: abatilo/actions-poetry@b8f6fe29ba2eb78e0d45ccbf41cd14154c4e25b2 
27 | 28 | - name: Configure Poetry 29 | run: | 30 | poetry config virtualenvs.create true --local 31 | poetry config virtualenvs.in-project true --local 32 | 33 | - name: Install package 34 | run: poetry install 35 | 36 | - name: Validate version 37 | run: | 38 | TAG_VERSION=${GITHUB_REF#refs/tags/v} 39 | POETRY_VERSION=$(poetry version -s) 40 | 41 | if [ "$TAG_VERSION" != "$POETRY_VERSION" ]; then 42 | echo "Tag ($TAG_VERSION) doesn't match pyproject.toml ($POETRY_VERSION)" 43 | exit 1 44 | fi 45 | 46 | - name: Build package 47 | run: poetry build 48 | 49 | - name: Publish to PyPI 50 | uses: pypa/gh-action-pypi-publish@3317ede93a4981d0fc490510c6fcf8bf0e92ed05 51 | -------------------------------------------------------------------------------- /rigging/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from rigging.transform.base import PostTransform, Transform 2 | from rigging.transform.json_tools import ( 3 | JsonToolMode, 4 | make_tools_to_json_transform, 5 | tools_to_json_in_xml_transform, 6 | tools_to_json_transform, 7 | tools_to_json_with_tag_transform, 8 | ) 9 | from rigging.transform.pythonic_tools import ( 10 | make_tools_to_pythonic_transform, 11 | tools_to_pythonic_transform, 12 | ) 13 | from rigging.transform.xml_tools import make_tools_to_xml_transform 14 | 15 | 16 | def get_transform(identifier: str) -> Transform: 17 | """ 18 | Get a well-known transform by its identifier. 19 | 20 | Args: 21 | identifier: The identifier of the transform to retrieve. 22 | 23 | Returns: 24 | The corresponding transform callable. 25 | """ 26 | match identifier: 27 | case "json": 28 | return tools_to_json_transform 29 | case "json-in-xml": 30 | return tools_to_json_in_xml_transform 31 | case "json-with-tag": 32 | return tools_to_json_with_tag_transform 33 | case "pythonic": 34 | return tools_to_pythonic_transform 35 | case _: 36 | raise ValueError(f"Unknown transform identifier: {identifier}") 37 | 38 | 39 | __all__ = [ 40 | "JsonToolMode", 41 | "PostTransform", 42 | "Transform", 43 | "make_tools_to_json_transform", 44 | "make_tools_to_pythonic_transform", 45 | "make_tools_to_xml_transform", 46 | "tools_to_json_in_xml_transform", 47 | "tools_to_json_transform", 48 | "tools_to_json_with_tag_transform", 49 | "tools_to_pythonic_transform", 50 | ] 51 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "🚨 Bug Report" 3 | description: File a bug report 4 | title: "🚨 [BUG] - " 5 | labels: ["bug", "triage"] 6 | assignees: 7 | - octocat 8 | body: 9 | - type: markdown 10 | attributes: 11 | value: | 12 | Thanks for taking the time to fill out this bug report! 13 | 14 | - type: textarea 15 | id: what-happened 16 | attributes: 17 | label: What happened? 18 | description: Also tell us, what did you expect to happen? 19 | placeholder: | 20 | Steps to reproduce the behavior: 21 | 1. 22 | 2. 23 | 3. 24 | 25 | Expected behavior: 26 | ... 27 | 28 | Actual behavior: 29 | ... 30 | validations: 31 | required: true 32 | 33 | - type: textarea 34 | id: possible-fix 35 | attributes: 36 | label: Any suggestions for fixing this bug? 37 | description: If you have an idea to fix this bug, we'd love to hear it! 38 | validations: 39 | required: false 40 | 41 | - type: textarea 42 | id: logs 43 | attributes: 44 | label: Relevant log output 45 | description: Please copy and paste any relevant log output. 
46 | render: shell 47 | 48 | - type: textarea 49 | id: environment 50 | attributes: 51 | label: Details about your environment 52 | description: Please provide the following information about your environment. 53 | placeholder: | 54 | ## Your Environment 55 | - Python Version: 56 | - Operating System: 57 | - Browser (if applicable): 58 | - Relevant env vars 59 | 60 | Tell us what you see! 61 | validations: 62 | required: false 63 | -------------------------------------------------------------------------------- /.github/workflows/rigging_pr_description.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Update PR Description with Rigging 3 | 4 | on: 5 | pull_request: 6 | types: [opened, synchronize] 7 | 8 | jobs: 9 | update-description: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | pull-requests: write 13 | contents: read 14 | 15 | steps: 16 | - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 17 | with: 18 | fetch-depth: 0 # full history for proper diffing 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 22 | with: 23 | python-version: "3.14" 24 | 25 | - name: Install uv 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install uv 29 | 30 | - name: Generate PR Description 31 | id: description 32 | env: 33 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 34 | run: | 35 | DESCRIPTION="$(uv run --no-project .hooks/generate_pr_description.py --base-ref "origin/${{ github.base_ref }}" --exclude "./*.lock")" 36 | { 37 | echo "description<<EOF" 38 | echo "${DESCRIPTION}" 39 | echo "EOF" 40 | } >> "$GITHUB_OUTPUT" 41 | 42 | - name: Update PR Description 43 | uses: nefrob/pr-description@4dcc9f3ad5ec06b2a197c5f8f93db5e69d2fdca7 # v1.2.0 44 | with: 45 | token: ${{ secrets.GITHUB_TOKEN }} 46 | content: | 47 | 48 | --- 49 | 50 | ## Generated Summary 51 | 52 | ${{ steps.description.outputs.description }} 53 | 54 | This summary was generated with ❤️ by [rigging](https://docs.dreadnode.io/rigging/) 55 | -------------------------------------------------------------------------------- /.github/workflows/semgrep.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Semgrep Analysis 3 | on: 4 | merge_group: 5 | pull_request: 6 | branches: 7 | - main 8 | types: 9 | - opened 10 | - synchronize 11 | - reopened 12 | push: 13 | branches: 14 | - main 15 | schedule: 16 | - cron: "0 0 * * *" # Run daily at midnight UTC 17 | 18 | concurrency: 19 | group: ${{ github.workflow }}-${{ github.ref }} 20 | cancel-in-progress: true 21 | 22 | permissions: 23 | actions: read 24 | checks: write 25 | contents: read 26 | pull-requests: write # Allows merge queue updates 27 | security-events: write # Required for GitHub Security tab 28 | 29 | jobs: 30 | semgrep: 31 | name: Semgrep Analysis 32 | runs-on: ubuntu-latest 33 | container: 34 | image: returntocorp/semgrep 35 | 36 | # Skip any PR created by dependabot to avoid permission issues: 37 | if: (github.actor != 'dependabot[bot]') 38 | 39 | steps: 40 | - name: Set up git repository 41 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 42 | with: 43 | token: ${{ secrets.GITHUB_TOKEN }} 44 | 45 | - name: Configure Git Safe Directory 46 | run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" 47 | 48 | - name: Semgrep Analysis 49 | env: 50 | SEMGREP_RULES: >- 51 | p/python 52 | p/security-audit 53 | p/secrets 54 | p/owasp-top-ten 55 | p/supply-chain 56 |
SEMGREP_TIMEOUT: 300 # 5-minute timeout per rule 57 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 58 | run: | 59 | semgrep ci \ 60 | --config="${SEMGREP_RULES}" \ 61 | --timeout="${SEMGREP_TIMEOUT}" \ 62 | --sarif --output=semgrep-results.sarif 63 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "💡 Feature Request" 3 | description: Create a new ticket for a new feature request 4 | title: "💡 [REQUEST] - <title>" 5 | labels: ["question"] 6 | body: 7 | - type: textarea 8 | id: implementation_pr 9 | attributes: 10 | label: "Implementation PR" 11 | description: Associated pull request 12 | placeholder: "# Pull Request ID" 13 | validations: 14 | required: false 15 | - type: textarea 16 | id: reference_issues 17 | attributes: 18 | label: "Reference Issues" 19 | description: Related issues 20 | placeholder: "# Issue ID(s)" 21 | validations: 22 | required: false 23 | - type: textarea 24 | id: summary 25 | attributes: 26 | label: "Summary" 27 | description: Provide a brief explanation of the feature 28 | placeholder: Describe your feature request 29 | validations: 30 | required: true 31 | - type: textarea 32 | id: basic_example 33 | attributes: 34 | label: "Basic Example" 35 | description: Provide some basic examples of your feature 36 | placeholder: A few specific details about your feature request 37 | validations: 38 | required: true 39 | - type: textarea 40 | id: drawbacks 41 | attributes: 42 | label: "Drawbacks" 43 | description: What are the drawbacks/impacts of your feature request? 44 | placeholder: Identify the drawbacks and impacts while remaining neutral on your feature request 45 | validations: 46 | required: true 47 | - type: textarea 48 | id: unresolved_question 49 | attributes: 50 | label: "Unresolved questions" 51 | description: What questions remain unresolved? 
52 | placeholder: Identify any unresolved issues 53 | validations: 54 | required: false 55 | -------------------------------------------------------------------------------- /.github/labeler.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Area Labels 3 | area/docs: 4 | - changed-files: 5 | - any-glob-to-any-file: "docs/**/*" 6 | 7 | area/examples: 8 | - changed-files: 9 | - any-glob-to-any-file: "examples/**/*" 10 | 11 | area/github: 12 | - changed-files: 13 | - any-glob-to-any-file: ".github/**/*" 14 | 15 | area/pre-commit: 16 | - changed-files: 17 | - any-glob-to-any-file: ".pre-commit-config.yaml" 18 | - any-glob-to-any-file: ".hooks/**/*" 19 | 20 | area/python: 21 | - changed-files: 22 | - any-glob-to-any-file: "pyproject.toml" 23 | - any-glob-to-any-file: "requirements.txt" 24 | - any-glob-to-any-file: "*.py" 25 | 26 | area/security: 27 | - changed-files: 28 | - any-glob-to-any-file: "SECURITY.md" 29 | - any-glob-to-any-file: "secrets.baseline" 30 | 31 | area/taskfiles: 32 | - changed-files: 33 | - any-glob-to-any-file: "Taskfile.yaml" 34 | 35 | area/tests: 36 | - changed-files: 37 | - any-glob-to-any-file: "tests/**/*" 38 | 39 | area/workspace: 40 | - changed-files: 41 | - any-glob-to-any-file: "python.code-workspace" 42 | 43 | # Development Labels 44 | area/dev: 45 | - changed-files: 46 | - any-glob-to-any-file: "dev/**/*" 47 | 48 | # Semantic Type Labels 49 | type/digest: 50 | - head-branch: ["^renovate/"] 51 | - head-branch: ["^deps/"] 52 | 53 | type/patch: 54 | - any: ["title:/^(?:Fix|Patch|Update)/"] 55 | 56 | type/minor: 57 | - any: ["title:/^(?:Add|Feature|Improve)/"] 58 | 59 | type/major: 60 | - any: ["title:/^(?:BREAKING)/"] 61 | 62 | type/break: 63 | - any: ["body:/BREAKING CHANGE:/"] 64 | 65 | # Documentation Labels 66 | type/docs: 67 | - changed-files: 68 | - any-glob-to-any-file: "docs/**/*" 69 | - any-glob-to-any-file: "*.md" 70 | 71 | # Core Files Labels 72 | type/core: 73 | - changed-files: 74 | - any-glob-to-any-file: "CODEOWNERS" 75 | - any-glob-to-any-file: "LICENSE" 76 | - any-glob-to-any-file: "README.md" 77 | -------------------------------------------------------------------------------- /rigging/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | We use loguru for logging. This module provides a function to configure logging handlers. 3 | 4 | To just enable rigging logs to flow, call `logger.enable("rigging")` after importing the module. 5 | """ 6 | 7 | import pathlib 8 | import sys 9 | import typing as t 10 | 11 | from loguru import logger 12 | 13 | g_configured: bool = False 14 | 15 | LogLevelList = ["trace", "debug", "info", "success", "warning", "error", "critical"] 16 | LogLevelLiteral = t.Literal["trace", "debug", "info", "success", "warning", "error", "critical"] 17 | """Valid logging levels.""" 18 | 19 | 20 | def configure_logging( 21 | log_level: LogLevelLiteral = "info", 22 | log_file: pathlib.Path | None = None, 23 | log_file_level: LogLevelLiteral = "debug", 24 | ) -> None: 25 | """ 26 | Configures common loguru handlers. 27 | 28 | Args: 29 | log_level: The desired log level. 30 | log_file: The path to the log file. If None, logging 31 | will only be done to the console. 32 | log_file_level: The log level for the log file. 
33 | """ 34 | global g_configured # noqa: PLW0603 35 | 36 | if g_configured: 37 | return 38 | 39 | logger.enable("rigging") 40 | 41 | logger.level("TRACE", color="<magenta>", icon="[T]") 42 | logger.level("DEBUG", color="<blue>", icon="[_]") 43 | logger.level("INFO", color="<cyan>", icon="[=]") 44 | logger.level("SUCCESS", color="<green>", icon="[+]") 45 | logger.level("WARNING", color="<yellow>", icon="[-]") 46 | logger.level("ERROR", color="<red>", icon="[!]") 47 | logger.level("CRITICAL", color="<RED>", icon="[x]") 48 | 49 | custom_format = "<green>{time:HH:mm:ss.SSS}</green> | <level>{level.icon}</level> {message}" 50 | 51 | logger.remove() 52 | logger.add(sys.stderr, format=custom_format, level=log_level.upper()) 53 | 54 | if log_file is not None: 55 | logger.add(log_file, format=custom_format, level=log_file_level.upper()) 56 | logger.info(f"Logging to {log_file}") 57 | 58 | g_configured = True 59 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this project 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Request Guidelines 7 | 8 | We actively welcome your pull requests. 9 | 10 | 1. Fork the repo and create your branch from `main`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. Ensure the test suite passes. 14 | 5. Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ### PR Description Format 18 | 19 | We use a standardized format for pull request descriptions to ensure 20 | consistency and clarity: 21 | 22 | 1. **Title**: Use a clear, concise title that summarizes the changes 23 | 2. **Key Changes**: List the most important updates 24 | 3. **Added**: Document new features or files 25 | 4. **Changed**: Highlight modifications to existing code 26 | 5. **Removed**: Note any deletions or removals 27 | 28 | Example: 29 | 30 | ```markdown 31 | ### Add device configuration automation 32 | 33 | **Key Changes:** 34 | 35 | - Implement dynamic device configuration 36 | - Add automated setup scripts 37 | - Update documentation 38 | 39 | **Added:** 40 | 41 | - New device setup module 42 | - Configuration templates 43 | - Setup guide 44 | 45 | **Changed:** 46 | 47 | - Refactored device initialization 48 | - Updated configuration format 49 | - Modified setup process 50 | 51 | **Removed:** 52 | 53 | - Legacy device configs 54 | - Deprecated setup scripts 55 | ``` 56 | 57 | ## Contributor License Agreement ("CLA") 58 | 59 | In order to accept your pull request, we need you to submit a CLA. You only need 60 | to do this once to work on any of Facebook's open source projects. 61 | 62 | Complete your CLA here: <https://code.facebook.com/cla> 63 | 64 | ## Issues 65 | 66 | We use GitHub issues to track public bugs. Please ensure your description is 67 | clear and has sufficient instructions to be able to reproduce the issue. 68 | 69 | ## License 70 | 71 | By contributing to this project, you agree that your contributions will be licensed 72 | under the LICENSE file in the root directory of this source tree. 
73 | -------------------------------------------------------------------------------- /rigging/generator/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generators produce completions for a given set of messages or text. 3 | """ 4 | 5 | import typing as t 6 | 7 | from rigging.generator.base import ( 8 | GeneratedMessage, 9 | GeneratedText, 10 | GenerateParams, 11 | Generator, 12 | StopReason, 13 | Usage, 14 | chat, 15 | complete, 16 | get_generator, 17 | get_identifier, 18 | register_generator, 19 | ) 20 | from rigging.generator.http import HTTPGenerator, HttpHook, HttpHookAction, HTTPSpec 21 | from rigging.generator.litellm_ import LiteLLMGenerator 22 | 23 | register_generator("litellm", LiteLLMGenerator) 24 | register_generator("http", HTTPGenerator) 25 | register_generator( 26 | "base", 27 | Generator, 28 | ) # TODO: Helper while we sort out generators being required in so many places. 29 | 30 | 31 | def get_vllm_lazy() -> type[Generator]: 32 | try: 33 | from rigging.generator.vllm_ import VLLMGenerator 34 | except ImportError as e: 35 | raise ImportError( 36 | "VLLMGenerator is not available. Please install `vllm` or use `rigging[llm]`.", 37 | ) from e 38 | 39 | return VLLMGenerator 40 | 41 | 42 | register_generator("vllm", get_vllm_lazy) 43 | 44 | 45 | def get_transformers_lazy() -> type[Generator]: 46 | try: 47 | from rigging.generator.transformers_ import TransformersGenerator 48 | except ImportError as e: 49 | raise ImportError( 50 | "TransformersGenerator is not available. Please install `transformers` or use `rigging[llm]`.", 51 | ) from e 52 | 53 | return TransformersGenerator 54 | 55 | 56 | register_generator("transformers", get_transformers_lazy) 57 | 58 | __all__ = [ 59 | "GenerateParams", 60 | "GeneratedMessage", 61 | "GeneratedText", 62 | "Generator", 63 | "HTTPGenerator", 64 | "HTTPSpec", 65 | "HttpHook", 66 | "HttpHookAction", 67 | "LiteLLMGenerator", 68 | "StopReason", 69 | "Usage", 70 | "chat", 71 | "complete", 72 | "get_generator", 73 | "get_identifier", 74 | "register_generator", 75 | ] 76 | 77 | 78 | def __getattr__(name: str) -> t.Any: 79 | if name == "VLLMGenerator": 80 | return get_vllm_lazy() 81 | if name == "TransformersGenerator": 82 | return get_transformers_lazy() 83 | raise AttributeError(f"module {__name__} has no attribute {name}") 84 | -------------------------------------------------------------------------------- /docs/docs.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://mintlify.com/docs.json", 3 | "theme": "mint", 4 | "name": "Docs", 5 | "colors": { 6 | "primary": "#ea580c", 7 | "light": "#F47150", 8 | "dark": "#333333" 9 | }, 10 | "background": { 11 | "color": { 12 | "light": "#e3e3e8", 13 | "dark": "#09090b" 14 | } 15 | }, 16 | "navigation": { 17 | "groups": [ 18 | { 19 | "group": "Getting Started", 20 | "pages": ["intro", "install"] 21 | }, 22 | { 23 | "group": "Usage", 24 | "pages": [ 25 | "topics/workflow", 26 | "topics/pipelines", 27 | "topics/prompt-functions", 28 | "topics/chats-and-messages", 29 | "topics/generators", 30 | "topics/data-models", 31 | "topics/tools", 32 | "topics/transforms", 33 | "topics/message-slicing", 34 | "topics/tokenization", 35 | "topics/iterating-and-batching", 36 | "topics/tracing", 37 | "topics/completions", 38 | "topics/migrations", 39 | "topics/serialization", 40 | "topics/logging" 41 | ] 42 | }, 43 | { 44 | "group": "API", 45 | "pages": [ 46 |
"api/chat", 47 | "api/completion", 48 | "api/generator", 49 | "api/model", 50 | "api/message", 51 | "api/prompt", 52 | "api/tools", 53 | "api/transform", 54 | "api/tokenize", 55 | "api/data", 56 | "api/watchers", 57 | "api/parsing", 58 | "api/interact", 59 | "api/logging", 60 | "api/error", 61 | "api/util" 62 | ] 63 | } 64 | ] 65 | }, 66 | "navbar": { 67 | "links": [ 68 | { 69 | "label": "Home", 70 | "href": "https://docs.dreadnode.io" 71 | }, 72 | { 73 | "label": "Support", 74 | "href": "mailto:support@dreadnode.io" 75 | }, 76 | { 77 | "label": "Blog", 78 | "href": "https://dreadnode.io/blog" 79 | } 80 | ], 81 | "primary": { 82 | "type": "button", 83 | "label": "Platform", 84 | "href": "https://platform.dreadnode.io" 85 | } 86 | }, 87 | "footer": { 88 | "socials": { 89 | "x": "https://x.com/dreadnode", 90 | "github": "https://github.com/dreadnode", 91 | "linkedin": "https://linkedin.com/company/dreadnode" 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /.github/workflows/template-sync.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Template Sync 3 | on: 4 | # checkov:skip=CKV_GHA_7: "Workflow dispatch inputs are required for manual debugging and configuration" 5 | workflow_dispatch: 6 | inputs: 7 | dryRun: 8 | description: Dry Run 9 | default: "false" 10 | required: false 11 | logLevel: 12 | description: Log Level 13 | default: "debug" 14 | required: false 15 | 16 | schedule: 17 | # Run on the 1st of every month at 00:00 UTC 18 | - cron: "0 0 1 * *" 19 | 20 | push: 21 | branches: ["main"] 22 | paths: 23 | - ".github/**" 24 | - ".hooks/**" 25 | - ".pre-commit-config.yaml" 26 | - ".mdlrc" 27 | - ".editorconfig" 28 | - "Taskfile.yaml" 29 | - ".task/**" 30 | 31 | permissions: 32 | contents: write 33 | pull-requests: write 34 | 35 | concurrency: 36 | group: ${{ github.workflow }}-${{ github.run_number || github.ref }} 37 | cancel-in-progress: true 38 | 39 | jobs: 40 | template-sync: 41 | name: Template Sync 42 | runs-on: ubuntu-latest 43 | steps: 44 | - name: Generate Token 45 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 46 | id: app-token 47 | with: 48 | app-id: "${{ secrets.BOT_APP_ID }}" 49 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" 50 | owner: "${{ github.repository_owner }}" 51 | 52 | - name: Checkout 53 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 54 | with: 55 | token: "${{ steps.app-token.outputs.token }}" 56 | 57 | - name: Template Sync 58 | uses: AndreasAugustin/actions-template-sync@v2 59 | with: 60 | source_gh_token: ${{ steps.app-token.outputs.token }} 61 | git_user_name: github-actions[bot] 62 | git_user_email: github-actions[bot]@users.noreply.github.com 63 | pr_title: "chore: sync infrastructure files with template" 64 | pr_labels: sync,template 65 | pr_body: | 66 | 🤖 A new version of the python template files is available. 67 | 68 | This PR was automatically created to sync the following: 69 | - GitHub Actions workflows 70 | - Pre-commit hooks and configs 71 | - Task definitions 72 | - Editor configs and linter rules 73 | 74 | Please review the changes carefully before merging. 
75 | source_repo_path: dreadnode/python-template 76 | steps: "prechecks,pull,commit,push,pr" 77 | upstream_branch: main 78 | -------------------------------------------------------------------------------- /.hooks/generate_pr_description.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.10" 3 | # dependencies = [ 4 | # "rigging", 5 | # "typer", 6 | # ] 7 | # /// 8 | 9 | import asyncio 10 | import subprocess 11 | import typing as t 12 | 13 | import typer 14 | 15 | import rigging as rg 16 | 17 | TRUNCATION_WARNING = "\n---\n**Note**: Due to the large size of this diff, some content has been truncated." 18 | 19 | 20 | @rg.prompt 21 | def generate_pr_description(diff: str) -> t.Annotated[str, rg.Ctx("markdown")]: # type: ignore[empty-body] 22 | """ 23 | Analyze the provided git diff and create a PR description in markdown format. 24 | 25 | <guidance> 26 | - Keep the summary concise and informative. 27 | - Use bullet points to structure important statements. 28 | - Focus on key modifications and potential impact - if any. 29 | - Do not add in general advice or best-practice information. 30 | - Write like a developer who authored the changes. 31 | - Prefer flat bullet lists over nested. 32 | - Do not include any title structure. 33 | - If there are no changes, just provide "No relevant changes." 34 | - Order your bullet points by importance. 35 | </guidance> 36 | """ 37 | 38 | 39 | def get_diff(base_ref: str, source_ref: str, *, exclude: list[str] | None = None) -> str: 40 | """ 41 | Get the git diff between two branches. 42 | """ 43 | 44 | merge_base = subprocess.run( 45 | ["git", "merge-base", source_ref, base_ref], 46 | capture_output=True, 47 | text=True, 48 | check=True, 49 | ).stdout.strip() 50 | 51 | diff_command = ["git", "diff", "--no-color", merge_base, source_ref] 52 | if exclude: 53 | diff_command.extend(["--", ".", *[f":(exclude){path}" for path in exclude]]) 54 | 55 | diff_text = subprocess.run( 56 | diff_command, 57 | capture_output=True, 58 | text=True, 59 | check=True, 60 | ).stdout 61 | 62 | return diff_text 63 | 64 | 65 | def main( 66 | base_ref: str = "origin/main", 67 | source_ref: str = "HEAD", 68 | generator_id: str = "openai/o3-mini", 69 | max_diff_lines: int = 10_000, 70 | exclude: list[str] | None = None, 71 | ) -> None: 72 | """ 73 | Use rigging to generate a PR description from a git diff. 74 | """ 75 | 76 | diff = get_diff(base_ref, source_ref, exclude=exclude) 77 | diff_lines = diff.split("\n") 78 | if len(diff_lines) > max_diff_lines: 79 | diff = "\n".join(diff_lines[:max_diff_lines]) + TRUNCATION_WARNING 80 | 81 | description = asyncio.run(generate_pr_description.bind(generator_id)(diff)) 82 | 83 | print(description) 84 | 85 | 86 | if __name__ == "__main__": 87 | typer.run(main) 88 | -------------------------------------------------------------------------------- /docs/api/logging.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: rigging.logging 3 | --- 4 | 5 | {/* 6 | ::: rigging.logging 7 | */} 8 | 9 | We use loguru for logging. This module provides a function to configure logging handlers. 10 | 11 | To just enable rigging logs to flow, call `logger.enable("rigging")` after importing the module. 
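For example, a minimal sketch (assuming only that `rigging` and `loguru` are installed in your environment):

```python
from loguru import logger

import rigging  # noqa: F401 - rigging disables its own logger on import

# Opt back in so rigging's messages flow through your loguru handlers:
logger.enable("rigging")
```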
12 | 13 | LogLevelLiteral 14 | --------------- 15 | 16 | ```python 17 | LogLevelLiteral = Literal[ 18 | "trace", 19 | "debug", 20 | "info", 21 | "success", 22 | "warning", 23 | "error", 24 | "critical", 25 | ] 26 | ``` 27 | 28 | Valid logging levels. 29 | 30 | configure\_logging 31 | ------------------ 32 | 33 | ```python 34 | configure_logging( 35 | log_level: LogLevelLiteral = "info", 36 | log_file: Path | None = None, 37 | log_file_level: LogLevelLiteral = "debug", 38 | ) -> None 39 | ``` 40 | 41 | Configures common loguru handlers. 42 | 43 | **Parameters:** 44 | 45 | * **`log_level`** 46 | (`LogLevelLiteral`, default: 47 | `'info'` 48 | ) 49 | –The desired log level. 50 | * **`log_file`** 51 | (`Path | None`, default: 52 | `None` 53 | ) 54 | –The path to the log file. If None, logging 55 | will only be done to the console. 56 | * **`log_file_level`** 57 | (`LogLevelLiteral`, default: 58 | `'debug'` 59 | ) 60 | –The log level for the log file. 61 | 62 | <Accordion title="Source code in rigging/logging.py" icon="code"> 63 | ```python 64 | def configure_logging( 65 | log_level: LogLevelLiteral = "info", 66 | log_file: pathlib.Path | None = None, 67 | log_file_level: LogLevelLiteral = "debug", 68 | ) -> None: 69 | """ 70 | Configures common loguru handlers. 71 | 72 | Args: 73 | log_level: The desired log level. 74 | log_file: The path to the log file. If None, logging 75 | will only be done to the console. 76 | log_file_level: The log level for the log file. 77 | """ 78 | global g_configured # noqa: PLW0603 79 | 80 | if g_configured: 81 | return 82 | 83 | logger.enable("rigging") 84 | 85 | logger.level("TRACE", color="<magenta>", icon="[T]") 86 | logger.level("DEBUG", color="<blue>", icon="[_]") 87 | logger.level("INFO", color="<cyan>", icon="[=]") 88 | logger.level("SUCCESS", color="<green>", icon="[+]") 89 | logger.level("WARNING", color="<yellow>", icon="[-]") 90 | logger.level("ERROR", color="<red>", icon="[!]") 91 | logger.level("CRITICAL", color="<RED>", icon="[x]") 92 | 93 | custom_format = "<green>{time:HH:mm:ss.SSS}</green> | <level>{level.icon}</level> {message}" 94 | 95 | logger.remove() 96 | logger.add(sys.stderr, format=custom_format, level=log_level.upper()) 97 | 98 | if log_file is not None: 99 | logger.add(log_file, format=custom_format, level=log_file_level.upper()) 100 | logger.info(f"Logging to {log_file}") 101 | 102 | g_configured = True 103 | ``` 104 | 105 | 106 | </Accordion> -------------------------------------------------------------------------------- /tests/generators.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from rigging import Message 4 | from rigging.generator import GenerateParams, Generator 5 | from rigging.generator.base import GeneratedMessage, GeneratedText 6 | 7 | # ruff: noqa: S101, ARG002 8 | 9 | 10 | class FixedGenerator(Generator): 11 | text: str 12 | 13 | async def generate_messages( 14 | self, 15 | messages: t.Sequence[t.Sequence[Message]], 16 | params: t.Sequence[GenerateParams], 17 | ) -> t.Sequence[GeneratedMessage]: 18 | return [GeneratedMessage.from_text(self.text, stop_reason="stop") for _ in messages] 19 | 20 | async def generate_texts( 21 | self, 22 | texts: t.Sequence[str], 23 | params: t.Sequence[GenerateParams], 24 | ) -> t.Sequence[GeneratedText]: 25 | return [GeneratedText.from_text(self.text, stop_reason="stop") for _ in texts] 26 | 27 | 28 | class EchoGenerator(Generator): 29 | async def generate_messages( 30 | self, 31 | messages: 
t.Sequence[t.Sequence[Message]], 32 | params: t.Sequence[GenerateParams], 33 | ) -> t.Sequence[GeneratedMessage]: 34 | return [GeneratedMessage.from_text(m[-1].content, stop_reason="stop") for m in messages] 35 | 36 | async def generate_texts( 37 | self, 38 | texts: t.Sequence[str], 39 | params: t.Sequence[GenerateParams], 40 | ) -> t.Sequence[GeneratedText]: 41 | return [GeneratedText.from_text(text, stop_reason="stop") for text in texts] 42 | 43 | 44 | class CallbackGenerator(Generator): 45 | message_callback: t.Callable[["CallbackGenerator", t.Sequence[Message]], str] | None = None 46 | text_callback: t.Callable[["CallbackGenerator", str], str] | None = None 47 | 48 | async def generate_messages( 49 | self, 50 | messages: t.Sequence[t.Sequence[Message]], 51 | params: t.Sequence[GenerateParams], 52 | ) -> t.Sequence[GeneratedMessage]: 53 | assert self.message_callback is not None 54 | return [ 55 | GeneratedMessage.from_text(self.message_callback(self, m), stop_reason="stop") 56 | for m in messages 57 | ] 58 | 59 | async def generate_texts( 60 | self, 61 | texts: t.Sequence[str], 62 | params: t.Sequence[GenerateParams], 63 | ) -> t.Sequence[GeneratedText]: 64 | assert len(texts) == 1 65 | assert self.text_callback is not None 66 | return [ 67 | GeneratedText.from_text(self.text_callback(self, text), stop_reason="stop") 68 | for text in texts 69 | ] 70 | 71 | 72 | class FailingGenerator(Generator): 73 | _exception: Exception = RuntimeError("Intentional failure") 74 | 75 | async def generate_messages( 76 | self, 77 | messages: t.Sequence[t.Sequence[Message]], 78 | params: t.Sequence[GenerateParams], 79 | ) -> t.Sequence[GeneratedMessage]: 80 | raise self._exception 81 | 82 | async def generate_texts( 83 | self, 84 | texts: t.Sequence[str], 85 | params: t.Sequence[GenerateParams], 86 | ) -> t.Sequence[GeneratedText]: 87 | raise self._exception 88 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v6.0.0 5 | hooks: 6 | - id: check-added-large-files 7 | args: [--maxkb=10240] 8 | - id: check-case-conflict 9 | - id: check-merge-conflict 10 | - id: check-executables-have-shebangs 11 | - id: check-json 12 | - id: check-shebang-scripts-are-executable 13 | - id: check-symlinks 14 | - id: check-yaml 15 | - id: detect-private-key 16 | - id: end-of-file-fixer 17 | exclude: ^docs/ 18 | - id: trailing-whitespace 19 | 20 | - repo: https://github.com/rhysd/actionlint 21 | rev: v1.7.9 22 | hooks: 23 | - id: actionlint 24 | 25 | - repo: https://github.com/adrienverge/yamllint.git 26 | rev: v1.37.1 27 | hooks: 28 | - id: yamllint 29 | entry: yamllint --strict -c .hooks/linters/yamllint.yaml 30 | 31 | - repo: https://github.com/codespell-project/codespell 32 | rev: v2.4.1 33 | hooks: 34 | - id: codespell 35 | entry: codespell -q 3 -f --skip=".git,.github" -L te,Sie,braket,astroid 36 | 37 | # Python code security 38 | - repo: https://github.com/PyCQA/bandit 39 | rev: 1.9.2 40 | hooks: 41 | - id: bandit 42 | name: Code security checks 43 | args: ["-c", "pyproject.toml"] 44 | additional_dependencies: ["bandit[toml]"] 45 | exclude: ^tests/ 46 | 47 | - repo: https://github.com/Yelp/detect-secrets 48 | rev: v1.5.0 49 | hooks: 50 | - id: detect-secrets 51 | args: ["--baseline", ".secrets.baseline", "--exclude-files", "examples/*"] 52 | exclude: .secrets.baseline 53 | 54 | # Clean jupyter notebook
outputs 55 | - repo: https://github.com/kynan/nbstripout 56 | rev: 0.8.2 57 | hooks: 58 | - id: nbstripout 59 | args: [--keep-id] 60 | 61 | # - repo: https://github.com/astral-sh/ruff-pre-commit 62 | # rev: v0.11.7 63 | # hooks: 64 | # - id: ruff 65 | # args: [--fix] 66 | # - id: ruff-format 67 | 68 | # - repo: https://github.com/pre-commit/mirrors-mypy 69 | # rev: v1.15.0 70 | # hooks: 71 | # - id: mypy 72 | # additional_dependencies: 73 | # - "pydantic" 74 | # - "types-PyYAML" 75 | # - "types-requests" 76 | # - "types-setuptools" 77 | 78 | - repo: local 79 | hooks: 80 | # Ensure our GH actions are pinned to a specific hash 81 | - id: check-github-actions 82 | name: Check GitHub Actions for Pinned Dependencies 83 | entry: .hooks/check_pinned_hash_dependencies.py 84 | language: python 85 | files: \.github/.*\.yml$ 86 | 87 | - id: prettier 88 | name: Run prettier 89 | entry: .hooks/prettier.sh 90 | language: script 91 | types: [json, yaml] 92 | 93 | # Generate documentation 94 | - id: generate-docs 95 | name: Generate docs 96 | entry: poetry run python .hooks/generate_docs.py 97 | language: system 98 | pass_filenames: false 99 | always_run: true 100 | -------------------------------------------------------------------------------- /rigging/__init__.py: -------------------------------------------------------------------------------- 1 | from rigging import ( 2 | caching, 3 | data, 4 | error, 5 | generator, 6 | logging, 7 | model, 8 | parsing, 9 | tokenizer, 10 | tools, 11 | transform, 12 | watchers, 13 | ) 14 | from rigging.chat import ( 15 | Chat, 16 | ChatPipeline, 17 | MapChatCallback, 18 | PipelineStep, 19 | PipelineStepContextManager, 20 | PipelineStepGenerator, 21 | ThenChatCallback, 22 | ) 23 | from rigging.completion import ( 24 | Completion, 25 | CompletionPipeline, 26 | MapCompletionCallback, 27 | ThenCompletionCallback, 28 | ) 29 | from rigging.error import Stop 30 | from rigging.generator import ( 31 | GeneratedMessage, 32 | GeneratedText, 33 | GenerateParams, 34 | Generator, 35 | chat, 36 | complete, 37 | get_generator, 38 | register_generator, 39 | ) 40 | from rigging.interact import interact 41 | from rigging.message import ( 42 | ContentAudioInput, 43 | ContentImageUrl, 44 | ContentText, 45 | Message, 46 | MessageDict, 47 | Messages, 48 | MessageSlice, 49 | ) 50 | from rigging.model import Model, attr, element, wrapped 51 | from rigging.prompt import Ctx, Prompt, prompt 52 | from rigging.tokenizer import TokenizedChat, Tokenizer, get_tokenizer, register_tokenizer 53 | from rigging.tools import Tool, as_mcp, mcp, robopages, tool, tool_method 54 | from rigging.transform import PostTransform, Transform 55 | from rigging.util import await_ 56 | from rigging.version import VERSION 57 | 58 | __version__ = VERSION 59 | 60 | __all__ = [ 61 | "Chat", 62 | "ChatFormatter", 63 | "ChatPipeline", 64 | "Completion", 65 | "CompletionPipeline", 66 | "ContentAudioInput", 67 | "ContentImageUrl", 68 | "ContentText", 69 | "Ctx", 70 | "Decoder", 71 | "Encoder", 72 | "GenerateParams", 73 | "GeneratedMessage", 74 | "GeneratedText", 75 | "Generator", 76 | "MapChatCallback", 77 | "MapCompletionCallback", 78 | "Message", 79 | "MessageDict", 80 | "MessageSlice", 81 | "Messages", 82 | "Model", 83 | "PipelineStep", 84 | "PipelineStepContextManager", 85 | "PipelineStepGenerator", 86 | "PostTransform", 87 | "Prompt", 88 | "Stop", 89 | "ThenChatCallback", 90 | "ThenCompletionCallback", 91 | "TokenSlice", 92 | "TokenizedChat", 93 | "Tokenizer", 94 | "Tool", 95 | "Transform", 96 | "as_mcp", 97 | "attr", 98 | 
"await_", 99 | "caching", 100 | "chat", 101 | "complete", 102 | "data", 103 | "element", 104 | "error", 105 | "find_in_tokens", 106 | "generator", 107 | "get_generator", 108 | "get_tokenizer", 109 | "interact", 110 | "logging", 111 | "mcp", 112 | "model", 113 | "parsing", 114 | "prompt", 115 | "register_generator", 116 | "register_tokenizer", 117 | "robopages", 118 | "tokenizer", 119 | "tool", 120 | "tool_method", 121 | "tools", 122 | "transform", 123 | "watchers", 124 | "wrapped", 125 | ] 126 | 127 | from loguru import logger 128 | 129 | logger.disable("rigging") 130 | -------------------------------------------------------------------------------- /.github/workflows/renovate.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Renovate 3 | on: 4 | # checkov:skip=CKV_GHA_7: "Workflow dispatch inputs are required for manual debugging and configuration" 5 | workflow_dispatch: 6 | inputs: 7 | dryRun: 8 | description: Dry Run 9 | default: "false" 10 | required: false 11 | logLevel: 12 | description: Log Level 13 | default: "debug" 14 | required: false 15 | version: 16 | description: Renovate version 17 | default: latest 18 | required: false 19 | schedule: 20 | # Run every evening at 20:00 UTC (8:00 PM UTC) 21 | - cron: "0 20 * * *" 22 | push: 23 | branches: ["main"] 24 | paths: 25 | - .github/renovate.json5 26 | - .github/renovate/**.json5 27 | 28 | permissions: 29 | contents: read 30 | pull-requests: write 31 | issues: write 32 | 33 | concurrency: 34 | group: ${{ github.workflow }}-${{ github.run_number || github.ref }} 35 | cancel-in-progress: true 36 | 37 | # Retrieve BOT_USER_ID via `curl -s "https://api.github.com/users/${BOT_USERNAME}%5Bbot%5D" | jq .id` 38 | env: 39 | WORKFLOW_DRY_RUN: false 40 | WORKFLOW_LOG_LEVEL: debug 41 | WORKFLOW_VERSION: latest # 37.59.8 42 | RENOVATE_PLATFORM: github 43 | RENOVATE_PLATFORM_COMMIT: true 44 | RENOVATE_ONBOARDING_CONFIG_FILE_NAME: .github/renovate.json5 45 | RENOVATE_AUTODISCOVER: true 46 | RENOVATE_AUTODISCOVER_FILTER: "${{ github.repository }}" 47 | RENOVATE_GIT_AUTHOR: "${{ secrets.BOT_USERNAME }} <${{ secrets.BOT_USER_ID }}+${{ secrets.BOT_USERNAME }}[bot]@users.noreply.github.com>" 48 | 49 | jobs: 50 | renovate: 51 | name: Renovate 52 | runs-on: ubuntu-latest 53 | steps: 54 | - name: Generate Token 55 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 56 | id: app-token 57 | with: 58 | app-id: "${{ secrets.BOT_APP_ID }}" 59 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" 60 | 61 | - name: Checkout 62 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 63 | with: 64 | token: "${{ steps.app-token.outputs.token }}" 65 | 66 | - name: Override default config from dispatch variables 67 | run: | 68 | echo "RENOVATE_DRY_RUN=${{ github.event.inputs.dryRun || env.WORKFLOW_DRY_RUN }}" >> "${GITHUB_ENV}" 69 | echo "LOG_LEVEL=${{ github.event.inputs.logLevel || env.WORKFLOW_LOG_LEVEL }}" >> "${GITHUB_ENV}" 70 | 71 | - name: Delete old dashboard 72 | run: | 73 | ISSUE_NUMBER=$(gh issue list -S 'Renovate Dashboard 🤖' --json number -q '.[0].number') 74 | if [ "$ISSUE_NUMBER" != "null" ] && [ -n "$ISSUE_NUMBER" ]; then 75 | gh issue close "$ISSUE_NUMBER" 76 | else 77 | echo "No issue found to close." 
78 | fi 79 | env: 80 | GITHUB_TOKEN: "${{ steps.app-token.outputs.token }}" 81 | 82 | - name: Renovate 83 | uses: renovatebot/github-action@822441559e94f98b67b82d97ab89fe3003b0a247 # v44.2.0 84 | with: 85 | configurationFile: "${{ env.RENOVATE_ONBOARDING_CONFIG_FILE_NAME }}" 86 | token: "${{ steps.app-token.outputs.token }}" 87 | renovate-version: "${{ github.event.inputs.version || env.WORKFLOW_VERSION }}" 88 | -------------------------------------------------------------------------------- /rigging/tools/robopages.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for integrating tools from a Robopages server. 3 | """ 4 | 5 | import re 6 | import typing as t 7 | 8 | import httpx 9 | import requests 10 | from loguru import logger 11 | from pydantic import TypeAdapter 12 | 13 | from rigging.tools.base import Tool, ToolDefinition 14 | 15 | DEFAULT_HTTP_TIMEOUT = 10 16 | 17 | 18 | def make_execute_on_server(url: str, tool_name: str) -> t.Callable[..., t.Any]: 19 | async def execute_on_server(**kwargs: t.Any) -> t.Any: 20 | async with httpx.AsyncClient() as client: 21 | response = await client.post( 22 | f"{url}/process", 23 | json=[ 24 | { 25 | "type": "function", 26 | "function": { 27 | "name": tool_name, 28 | "arguments": kwargs, 29 | }, 30 | }, 31 | ], 32 | ) 33 | if response.status_code not in [200, 400]: 34 | response.raise_for_status() 35 | 36 | if response.status_code == 400: # noqa: PLR2004 37 | result = response.content.decode() 38 | else: 39 | result = response.json()[0]["content"] 40 | 41 | return result 42 | 43 | return execute_on_server 44 | 45 | 46 | def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.Any]]: 47 | """ 48 | Create a list of tools from a Robopages server. 49 | 50 | Args: 51 | url: The URL of the Robopages server. 52 | name_filter: A regular expression to filter the tools by name. 53 | 54 | Returns: 55 | A list of integrated tools which leverage the Robopages server. 
56 | 57 |     Example: 58 |         ``` 59 |         import rigging as rg 60 | 61 |         tools = rg.robopages("http://localhost:8080") 62 | 63 |         chat = ( 64 |             await rg.get_generator('gpt-4o') 65 |             .chat('Please use tools') 66 |             .using(*tools) 67 |             .run() 68 |         ) 69 | 70 |         print(chat.conversation) 71 |         ``` 72 |     """ 73 | 74 |     filter_regex = re.compile(name_filter) if name_filter else None 75 | 76 |     response = requests.get(url, params={"flavor": "openai"}, timeout=DEFAULT_HTTP_TIMEOUT) 77 |     response.raise_for_status() 78 |     tools_data = response.json() 79 | 80 |     adapter = TypeAdapter(list[ToolDefinition]) 81 |     tool_definitions = adapter.validate_python(tools_data) 82 | 83 |     logger.info(f"Fetched {len(tool_definitions)} functions from Robopages ({url})") 84 | 85 |     tools: list[Tool[..., t.Any]] = [] 86 |     for definition in tool_definitions: 87 |         function = definition.function 88 | 89 |         if filter_regex and not filter_regex.search(function.name): 90 |             logger.debug(f"Skipping function {function.name}") 91 |             continue 92 | 93 |         tools.append( 94 |             Tool( 95 |                 name=function.name, 96 |                 description=function.description or "", 97 |                 parameters_schema=function.parameters or {}, 98 |                 fn=make_execute_on_server(url, function.name), 99 |             ), 100 |         ) 101 | 102 |     return tools 103 | -------------------------------------------------------------------------------- /tests/test_completion_pipeline.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rigging.completion import Completion, CompletionPipeline 4 | from rigging.generator import GenerateParams, get_generator 5 | 6 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001 7 | 8 | 9 | def test_completion_generator_id() -> None: 10 |     generator = get_generator("gpt-3.5") 11 |     completion = Completion("foo", "bar", generator) 12 |     assert completion.generator_id == "gpt-3.5" 13 | 14 |     completion.generator = None 15 |     assert completion.generator_id is None 16 | 17 | 18 | def test_completion_properties() -> None: 19 |     generator = get_generator("gpt-3.5") 20 |     completion = Completion("foo", "bar", generator) 21 |     assert completion.text == "foo" 22 |     assert completion.generated == "bar" 23 |     assert completion.generator == generator 24 |     assert len(completion) == len("foo") + len("bar") 25 |     assert completion.all == "foobar" 26 | 27 | 28 | def test_completion_restart() -> None: 29 |     generator = get_generator("gpt-3.5") 30 |     completion = Completion("foo", "bar", generator) 31 |     assert len(completion.restart()) == 3 32 |     assert len(completion.restart(include_all=True)) == 6 33 | 34 |     assert len(completion.fork("baz")) == 6 35 |     assert len(completion.continue_("baz")) == 9 36 | 37 |     completion.generator = None 38 |     with pytest.raises(ValueError): 39 |         completion.restart() 40 | 41 | 42 | def test_completion_clone() -> None: 43 |     generator = get_generator("gpt-3.5") 44 |     original = Completion("foo", "bar", generator).meta(key="value") 45 |     clone = original.clone() 46 |     assert clone.text == original.text 47 |     assert clone.generated == original.generated 48 |     assert clone.metadata == original.metadata 49 | 50 |     clone_2 = original.clone(only_messages=True) 51 |     assert clone.metadata != clone_2.metadata 52 | 53 | 54 | def test_completion_pipeline_with() -> None: 55 |     pipeline = CompletionPipeline(get_generator("gpt-3.5"), "foo") 56 |     with_pipeline = pipeline.with_(GenerateParams(max_tokens=123)) 57 |     assert with_pipeline == pipeline 58 |     assert with_pipeline.params is not None 59 |     assert with_pipeline.params.max_tokens == 123 60 | 61 | 
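    # A second with_() call on a pipeline whose params are already set is expected to produce a new pipeline and merge the overrides into the existing values, as asserted below.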
with_pipeline_2 = with_pipeline.with_(top_p=0.5) 62 |     assert with_pipeline_2 != with_pipeline 63 |     assert with_pipeline_2.params is not None 64 |     assert with_pipeline_2.params.max_tokens == 123 65 |     assert with_pipeline_2.params.top_p == 0.5 66 | 67 | 68 | def test_completion_pipeline_fork() -> None: 69 |     pipeline = CompletionPipeline(get_generator("gpt-3.5"), "foo") 70 |     forked_1 = pipeline.fork("bar") 71 |     forked_2 = pipeline.fork("baz") 72 | 73 |     assert pipeline != forked_1 != forked_2 74 |     assert pipeline.text == "foo" 75 |     assert forked_1.text == "foobar" 76 |     assert forked_2.text == "foobaz" 77 | 78 | 79 | def test_completion_pipeline_meta() -> None: 80 |     pipeline = CompletionPipeline(get_generator("gpt-3.5"), "foo") 81 |     with_meta = pipeline.meta(key="value") 82 |     assert with_meta == pipeline 83 |     assert with_meta.metadata == {"key": "value"} 84 | 85 | 86 | def test_completion_pipeline_apply() -> None: 87 |     pipeline = CompletionPipeline(get_generator("gpt-3.5"), "Hello $name") 88 |     applied = pipeline.apply(name="World", noexist="123") 89 |     assert pipeline != applied 90 |     assert pipeline.text == "Hello $name" 91 |     assert applied.text == "Hello World" 92 | -------------------------------------------------------------------------------- /docs/topics/completions.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Completions" 3 | description: "Using Rigging for text completions" 4 | public: true 5 | --- 6 | 7 | The majority of Rigging was built around "instruct" or "chat" LLM interfaces where a base model has been tuned to work with a structured layer on top of raw text completion. We typically find that base models are more unpredictable with their outputs, tend to be more sensitive to small changes in their context windows, and require frequent use of stop tokens to prevent unnecessary generation. 8 | 9 | However, there are some places where completing raw text and working with base models might be desirable: 10 | 11 | - Fewer restrictions on the types of content they will generate 12 | - Speeding up generation and lowering token usage by discouraging verbose responses 13 | - Leveraging prompts from popular libraries like [LangChain](https://python.langchain.com/) which assume 14 |   a completions-style interface 15 | 16 | ## Interface Parity 17 | 18 | While we try to maintain parity between the "Chat" and "Completions" interfaces in Rigging, you'll 19 | find some deviations here and there. Completions should be a simple transition if you are familiar 20 | with the other code in rigging. Here are the highlights: 21 | 22 | - `chat` -> `complete` 23 | - `Chat` -> `Completion` 24 | - `ChatPipeline` -> `CompletionPipeline` 25 | - `generate_messages` -> `generate_texts` 26 | 27 | On all of these interfaces, you'll note that sequences of `Message` objects have been 28 | replaced with basic `str` objects for both inputs and outputs. 29 | 30 | ## Translator Example 31 | 32 | Let's build a simple translator object that we can store as a `CompletionPipeline` 33 | and use it to quickly translate a phrase to 3 different languages. 34 | 35 | ```python {13, 15} 36 | PROMPT = """\ 37 | As an expert translator, you accept english text and translate it to $language.
38 | 39 | # Format 40 | 41 | Input: [english text] 42 | Output: [translated text] 43 | --- 44 | 45 | Input: $input 46 | Output: """ 47 | 48 | translator = ( 49 |     rg.get_generator('gpt-3.5-turbo') 50 |     .complete(PROMPT) 51 |     .with_(stop=["---", "Input:", "\n\n"]) 52 | ) 53 | 54 | text = "Could you please tell me where the nearest train station is?" 55 | 56 | for language in ["spanish", "french", "german"]: 57 |     completion = await translator.apply( 58 |         language=language, 59 |         input=text 60 |     ).run() 61 |     print(f"[{language}]: {completion.generated}") 62 | 63 | # [spanish]: ¿Podría decirme por favor dónde está la estación de tren más cercana? 64 | # [french]: Pouvez-vous me dire où se trouve la gare la plus proche, s'il vous plaît ? 65 | # [german]: Könnten Sie mir bitte sagen, wo sich der nächste Bahnhof befindet? 66 | ``` 67 | 68 | 1. OpenAI supports the same model IDs for both completions and chats, but other providers might require you to specify a specific model ID used for text completions. 69 | 2. We use `.with_()` to set stop tokens and prevent the generation from simply continuing until our max tokens are reached. This is a very common and often required pattern when doing completions over chats. Here, we aren't totally sure what the model might generate after our translation, so we use a few different token sequences to be safe. 70 | 71 | <Tip> 72 | **Using .apply()** 73 | 74 | Text completion is a great place to use the `.apply` method as we can easily slot in our inputs without using `.add` and following it with our output section of the prompt. 75 | </Tip> 76 | -------------------------------------------------------------------------------- /rigging/tokenizer/transformers_.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from pydantic import Field 4 | from transformers import AutoTokenizer  # type: ignore [import-not-found, unused-ignore] 5 | 6 | from rigging.tokenizer.base import Tokenizer 7 | 8 | if t.TYPE_CHECKING: 9 |     from transformers.tokenization_utils import (  # type: ignore [import-not-found, unused-ignore] 10 |         PreTrainedTokenizer, 11 |     ) 12 | 13 |     from rigging.chat import Chat 14 | 15 | 16 | class TransformersTokenizer(Tokenizer): 17 |     """ 18 |     A tokenizer implementation using Hugging Face Transformers. 19 | 20 |     This class provides tokenization capabilities for chat conversations 21 |     using transformers models and their associated tokenizers. 22 |     """ 23 | 24 |     apply_chat_template_kwargs: dict[str, t.Any] = Field(default_factory=dict) 25 |     """Additional keyword arguments for applying the chat template.""" 26 | 27 |     encode_kwargs: dict[str, t.Any] = Field(default_factory=dict) 28 |     """Additional keyword arguments for encoding text.""" 29 | 30 |     decode_kwargs: dict[str, t.Any] = Field(default_factory=dict) 31 |     """Additional keyword arguments for decoding tokens.""" 32 | 33 |     _tokenizer: "PreTrainedTokenizer | None" = None 34 | 35 |     @property 36 |     def tokenizer(self) -> "PreTrainedTokenizer": 37 |         """The underlying `PreTrainedTokenizer` instance.""" 38 |         if self._tokenizer is None: 39 |             self._tokenizer = AutoTokenizer.from_pretrained(self.model)  # type: ignore[no-untyped-call] # nosec 40 |         return self._tokenizer 41 | 42 |     @classmethod 43 |     def from_obj(cls, tokenizer: "PreTrainedTokenizer") -> "TransformersTokenizer": 44 |         """ 45 |         Create a new instance of TransformersTokenizer from an already loaded tokenizer. 46 | 47 |         Args: 48 |             tokenizer: The tokenizer associated with the model.
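
        Example:
            A rough sketch, assuming a locally available model ("gpt2" here is purely illustrative):

                tokenizer = AutoTokenizer.from_pretrained("gpt2")
                rg_tokenizer = TransformersTokenizer.from_obj(tokenizer)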
49 | 50 | Returns: 51 | The TransformersTokenizer instance. 52 | """ 53 | return cls(model=str(tokenizer), _tokenizer=tokenizer) 54 | 55 | def encode(self, text: str) -> list[int]: 56 | """ 57 | Encodes the given text into a list of tokens. 58 | 59 | Args: 60 | text: The text to encode. 61 | 62 | Returns: 63 | A list of tokens representing the encoded text. 64 | """ 65 | return self.tokenizer.encode(text, **self.encode_kwargs) # type: ignore [no-any-return] 66 | 67 | def decode(self, tokens: list[int]) -> str: 68 | decode_kwargs = { 69 | "clean_up_tokenization_spaces": False, 70 | **self.decode_kwargs, 71 | } 72 | return self.tokenizer.decode(tokens, **decode_kwargs) # type: ignore [no-any-return, unused-ignore] 73 | 74 | def format_chat(self, chat: "Chat") -> str: 75 | messages = [m.to_openai(compatibility_flags={"content_as_str"}) for m in chat.all] 76 | tools = ( 77 | [tool.model_dump() for tool in chat.params.tools] 78 | if chat.params and chat.params.tools 79 | else None 80 | ) 81 | 82 | apply_chat_template_kwargs = { 83 | "tokenize": False, 84 | **self.apply_chat_template_kwargs, 85 | } 86 | 87 | return str( 88 | self.tokenizer.apply_chat_template( 89 | messages, 90 | tools=tools, # type: ignore [arg-type, unused-ignore] 91 | **apply_chat_template_kwargs, 92 | ), 93 | ) 94 | -------------------------------------------------------------------------------- /.github/labels.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Area Labels 3 | - name: area/docs 4 | color: "72CCF3" # Light Blue 5 | description: >- 6 | Changes to documentation and guides 7 | 8 | - name: area/examples 9 | color: "BC9BE3" # Lavender 10 | description: >- 11 | Changes to example code and demonstrations 12 | 13 | - name: area/github 14 | color: "F4D1B7" # Peach 15 | description: >- 16 | Changes made to GitHub Actions 17 | 18 | - name: area/pre-commit 19 | color: "84B6EB" # Steel Blue 20 | description: >- 21 | Changes made to pre-commit hooks 22 | 23 | - name: area/python 24 | color: "7BD7E0" # Turquoise 25 | description: >- 26 | Changes to Python package configuration and dependencies 27 | 28 | - name: area/security 29 | color: "FF6600" # Orange 30 | description: >- 31 | Changes to security policies and configurations 32 | 33 | - name: area/taskfiles 34 | color: "66CCFF" # Sky Blue 35 | description: >- 36 | Changes made to Taskfiles 37 | 38 | - name: area/tests 39 | color: "99CC00" # Lime Green 40 | description: >- 41 | Changes to test files and testing infrastructure 42 | 43 | - name: area/workspace 44 | color: "FF99CC" # Pink 45 | description: >- 46 | Changes to VSCode workspace configuration 47 | 48 | - name: area/assets 49 | color: "FFA07A" # Light Salmon 50 | description: >- 51 | Changes to asset files 52 | 53 | - name: area/templates 54 | color: "DA70D6" # Orchid 55 | description: >- 56 | Changes to templates 57 | 58 | - name: area/scripts 59 | color: "40E0D0" # Turquoise 60 | description: >- 61 | Changes to script files 62 | 63 | - name: area/src 64 | color: "4682B4" # Steel Blue 65 | description: >- 66 | Changes to source code 67 | 68 | - name: area/ci 69 | color: "FF4500" # Orange Red 70 | description: >- 71 | Changes related to CI/CD configurations 72 | 73 | - name: area/shell 74 | color: "556B2F" # Dark Olive Green 75 | description: >- 76 | Changes to shell scripts 77 | 78 | - name: area/dev 79 | color: "CC6699" # Dusty Rose 80 | description: >- 81 | Changes to development tools and assets 82 | 83 | # Renovate Labels 84 | - name: renovate/container 85 | 
color: "9933CC" # Purple 86 | description: >- 87 | Docker container updates via Renovate 88 | 89 | - name: renovate/github-action 90 | color: "FF3366" # Hot Pink 91 | description: >- 92 | GitHub Action updates via Renovate 93 | 94 | - name: renovate/github-release 95 | color: "3399FF" # Bright Blue 96 | description: >- 97 | GitHub Release updates via Renovate 98 | 99 | # Semantic Type Labels 100 | - name: type/digest 101 | color: "FF66CC" # Bright Pink 102 | description: >- 103 | Dependency digest updates 104 | 105 | - name: type/patch 106 | color: "FFC300" # Golden Yellow 107 | description: >- 108 | Patch changes (fixes, updates) 109 | 110 | - name: type/minor 111 | color: "FFD700" # Gold 112 | description: >- 113 | Minor changes (features, improvements) 114 | 115 | - name: type/major 116 | color: "F6412D" # Red Orange 117 | description: >- 118 | Major changes 119 | 120 | - name: type/break 121 | color: "FF0000" # Bright Red 122 | description: >- 123 | Breaking changes 124 | 125 | # Documentation Labels 126 | - name: type/docs 127 | color: "0075CA" # Documentation Blue 128 | description: >- 129 | Documentation updates and improvements 130 | 131 | - name: type/core 132 | color: "A2EEEF" # Light Blue 133 | description: >- 134 | Changes to core repository files and configurations 135 | -------------------------------------------------------------------------------- /tests/test_generator_ids.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rigging.error import InvalidGeneratorError 4 | from rigging.generator import ( 5 | GenerateParams, 6 | LiteLLMGenerator, 7 | get_generator, 8 | get_identifier, 9 | register_generator, 10 | ) 11 | 12 | from .generators import EchoGenerator 13 | 14 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001 15 | 16 | 17 | @pytest.mark.parametrize("identifier", ["test_model", "litellm!test_model"]) 18 | def test_get_generator_default_is_litellm(identifier: str) -> None: 19 | generator = get_generator(identifier) 20 | assert isinstance(generator, LiteLLMGenerator) 21 | assert generator.model == "test_model" 22 | 23 | 24 | @pytest.mark.parametrize("identifier", ["invalid!testing", "no_exist!stuff,args=123"]) 25 | def test_get_generator_invalid_provider(identifier: str) -> None: 26 | with pytest.raises(InvalidGeneratorError): 27 | get_generator(identifier) 28 | 29 | 30 | @pytest.mark.parametrize( 31 | ("identifier", "valid_params"), 32 | [ 33 | ("litellm!test_model,max_tokens=123,top_p=10", GenerateParams(max_tokens=123, top_p=10)), 34 | ("litellm!test_model,temperature=0.5", GenerateParams(temperature=0.5)), 35 | ( 36 | "test_model,temperature=1.0,max_tokens=100", 37 | GenerateParams(max_tokens=100, temperature=1.0), 38 | ), 39 | ], 40 | ) 41 | def test_get_generator_with_params(identifier: str, valid_params: GenerateParams) -> None: 42 | generator = get_generator(identifier) 43 | assert isinstance(generator, LiteLLMGenerator) 44 | assert generator.model == "test_model" 45 | assert generator.params == valid_params 46 | 47 | 48 | @pytest.mark.parametrize( 49 | "identifier", 50 | [ 51 | ("test_model,max_tokens=1024,top_p=0.1"), 52 | ("custom,temperature=1.0,max_tokens=100,api_base=https://localhost:8000"), 53 | ("many/model/slashes,stop=a;b;c;"), 54 | ("http!with_cls_args"), 55 | ], 56 | ) 57 | def test_identifier_roundtrip(identifier: str) -> None: 58 | generator = get_generator(identifier) 59 | assert generator.to_identifier() == identifier 60 | 61 | 62 | def test_get_identifier_no_extra() -> None: 63 | 
generator = get_generator("testing_model,temperature=0.5") 64 | generator.params.extra = {"abc": 123} 65 | identifier = get_identifier(generator) 66 | assert "extra" not in identifier 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "identifier", 71 | ["litellm:invalid,stuff:test,t1/123", "bad:invalid,stuff:test,t1//;;123:"], 72 | ) 73 | def test_get_generator_invalid_structure_format(identifier: str) -> None: 74 | with pytest.raises(InvalidGeneratorError): 75 | get_generator(identifier) 76 | 77 | 78 | @pytest.mark.parametrize( 79 | "identifier", 80 | ["litellm:model,bad_param=123,temperature=1.0", "litellm:model,temperature=True"], 81 | ) 82 | def test_get_generator_invalid_params(identifier: str) -> None: 83 | with pytest.raises(InvalidGeneratorError): 84 | get_generator(identifier) 85 | 86 | 87 | def test_register_generator() -> None: 88 | with pytest.raises(InvalidGeneratorError): 89 | get_generator("echo!test") 90 | 91 | register_generator("echo", EchoGenerator) 92 | generator = get_generator("echo!test") 93 | assert isinstance(generator, EchoGenerator) 94 | 95 | 96 | def test_get_generator_b64() -> None: 97 | generator = get_generator("litellm!test_model,api_key=ZXhhbXBsZXRleHQ=") 98 | assert isinstance(generator, LiteLLMGenerator) 99 | assert generator.model == "test_model" 100 | -------------------------------------------------------------------------------- /docs/topics/tracing.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Tracing" 3 | description: "Trace all rigging internal behaviors with OpenTelemetry" 4 | public: true 5 | --- 6 | 7 | Rigging integrates with the [Logfire](https://logfire.pydantic.dev/docs/) library for exposing tracing information about execution. Specifically we use the logfire-API no-op package, making it optional for users with no overhead if you don't need it. 8 | 9 | Logfire is capable of reporting trace information to any Open Telemetry compatible system, and provides convenient abstractions on top of the standard open-telemetry-SDK. If the `logfire` package is installed and configured, details about pipelines, prompts, and tools will be traced when you use rigging. 10 | 11 | You can configure Logfire to use [alternative backends](https://logfire.pydantic.dev/docs/how-to-guides/alternative-backends/) as needed to integrate with your preferred tracing stack. 12 | 13 | ```python 14 | import rigging as rg 15 | import logfire 16 | 17 | logfire.configure() 18 | 19 | @rg.prompt(generator_id="gpt-4o") 20 | async def summarize(content: str) -> str: 21 | """ 22 | Summarize the content into 1-2 sentences then save it 23 | """ 24 | 25 | summarize.watch(rg.watchers.write_chats_to_jsonl("chats.jsonl")) 26 | 27 | text = """ 28 | Revachol is located on the island of Le Caillou, also called "The Pebble" on the northeast side of the 29 | Insulindian Isola, on the world's largest body of water: the Insulindic. The city itself has a radius 30 | of 80 kilometres and is split by the River Esperance into Revachol East and Revachol West. The north 31 | side of the island is shattered by the delta of the Esperance, and is named La Delta. 
32 | """ 33 | 34 | await summarize.run_many(3, text) 35 | 36 | # 23:46:31.484 Prompt summarize() (x3) 37 | # 23:46:31.485 Chat with litellm!gpt-4o (x3) 38 | # 23:46:32.874 Watch with rigging.watchers.write_chats_to_jsonl() 39 | ``` 40 | 41 | <Note> 42 | Rigging will attach call parameters and results for both tools and prompt functions, as well as finalized chat objects at the end of a pipeline. Logfire will serialize these items as JSON values inside attributes, and include a dynamic JSON schema for reference. When using their platform, these items deserialize directly into the web view. 43 | </Note> 44 | 45 | ![Logfire trace](../assets/tracing_logfire.png) 46 | 47 | ## Inference Tracing 48 | 49 | We've opted to exclude tracing at the generator level in Rigging (for now) and focus on instrumenting higher-order abstractions like pipelines and tools, which are specific to the framework. 50 | 51 | There are a suite of powerful instrumentation libraries (including Logfire) which will add tracing to underlying libraries like LiteLLM (recommended), OpenAI, Anthropic, VertexAI, and others. These snap right into the tracing spans from rigging, and provide insight into the raw inference traffic before it's sent to API endpoints and inference libraries. 52 | 53 | - [LiteLLM - Logfire](https://docs.litellm.ai/docs/observability/logfire_integration) 54 | - [LiteLLM - OpenTelemetry](https://docs.litellm.ai/docs/observability/opentelemetry_integration) 55 | - [Logfire - Integrations](https://logfire.pydantic.dev/docs/integrations/) 56 | - [TraceLoop - openllmetry](https://github.com/traceloop/openllmetry) 57 | 58 | Here is an example of adding LiteLLM tracing on top of rigging: 59 | 60 | ```python 61 | import rigging as rg 62 | import logfire 63 | import litellm 64 | 65 | logfire.configure() 66 | 67 | os.environ.setdefault("LOGFIRE_TOKEN", "") # (1)! 68 | litellm.callbacks = ["logfire"] 69 | 70 | # ... 71 | ``` 72 | 73 | *1. As of publication, LiteLLM requires this environment variable, even if it's empty and Logfire is managing tokens for you.* 74 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | - Using welcoming and inclusive language 18 | - Being respectful of differing viewpoints and experiences 19 | - Gracefully accepting constructive criticism 20 | - Focusing on what is best for the community 21 | - Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | - The use of sexualized language or imagery and unwelcome sexual attention or 26 |   advances 27 | - Trolling, insulting/derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or electronic 30 |   address, without explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 |   professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at <opensource-conduct@fb.com>. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership.
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 75 | version 1.4, available at <https://www.contributor-covenant.org/version/1/4/code-of-conduct.html> 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | <https://www.contributor-covenant.org/faq> 81 | -------------------------------------------------------------------------------- /rigging/interact.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for interactive chat sessions. 3 | """ 4 | 5 | import asyncio 6 | import itertools 7 | import typing as t 8 | 9 | from colorama import Fore, Style 10 | 11 | from rigging.chat import Chat, ChatPipeline 12 | from rigging.generator.base import Generator, get_generator 13 | 14 | # ruff: noqa: T201 (we use print() here for simplicity) 15 | 16 | 17 | async def _animate( 18 | *, 19 | delay: float = 0.5, 20 | chars: list[str] | None = None, 21 | color: str = Fore.BLUE, 22 | ) -> None: 23 | cycle = itertools.cycle(chars or [" ", ". ", ".. ", "..."]) 24 | while True: 25 | print(f"{color}{next(cycle)}{Style.RESET_ALL}", end="\r", flush=True) 26 | await asyncio.sleep(delay) 27 | 28 | 29 | async def interact( 30 | entrypoint: ChatPipeline | Generator | str, 31 | *, 32 | reset_callback: t.Callable[[Chat | None], None] | None = None, 33 | ) -> Chat | None: 34 | """ 35 | Start an interactive chat session using the given pipeline, generator, or generator id. 36 | 37 | This function allows the user to have a conversation with an assistant by providing input 38 | and receiving responses. The chat session can be controlled using specific commands. 39 | 40 | Args: 41 | entrypoint: A ChatPipeline, Generator, or generator id to use for the chat session. 42 | reset_callback: A callback function to execute when the chat is reset. 43 | 44 | Returns: 45 | The final Chat object, or None if the chat was interrupted before any generation. 
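
    Example:
        A minimal session sketch (the generator id below is illustrative):

            import rigging as rg

            chat = await rg.interact("gpt-4o")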
46 |     """ 47 | 48 |     print(f"\n{Fore.YELLOW}") 49 |     print("Starting interactive chat.") 50 |     print() 51 |     print(" - Type 'exit' to quit or use Ctrl+C.") 52 |     print(" - Type 'reset' or 'restart' to restart the chat.") 53 |     print(" - Type 'again' or 'retry' to re-run the last generation.") 54 |     print(Style.RESET_ALL) 55 | 56 |     base_pipeline = ( 57 |         entrypoint 58 |         if isinstance(entrypoint, ChatPipeline) 59 |         else entrypoint.chat() 60 |         if isinstance(entrypoint, Generator) 61 |         else get_generator(entrypoint).chat() 62 |     ) 63 | 64 |     pipeline = base_pipeline.clone() 65 |     chat: Chat | None = None 66 | 67 |     while True: 68 |         try: 69 |             user_input = input(f"\n{Fore.GREEN}User: {Style.RESET_ALL}") 70 |             if not user_input: 71 |                 continue 72 | 73 |             if user_input.lower() == "exit": 74 |                 print(f"\n\n{Fore.YELLOW}Exiting chat.{Style.RESET_ALL}") 75 |                 break 76 | 77 |             if user_input.lower() in ["reset", "restart"]: 78 |                 print(f"\n{Fore.YELLOW}--- Reset ---{Style.RESET_ALL}") 79 |                 pipeline = base_pipeline.clone() 80 |                 if reset_callback: 81 |                     reset_callback(chat) 82 |                 continue 83 | 84 |             if user_input.lower() in ["again", "retry"]: 85 |                 print(f"\n{Fore.YELLOW}--- Retry ---{Style.RESET_ALL}") 86 |                 pipeline.chat.messages = pipeline.chat.messages[:-1] 87 |             else: 88 |                 pipeline.add(user_input) 89 | 90 |             print() 91 | 92 |             animation_task = asyncio.create_task(_animate()) 93 |             chat = await pipeline.run() 94 |             animation_task.cancel() 95 | 96 |             print(f"\r{Fore.BLUE}Assistant: {Style.RESET_ALL}{chat.last.content}") 97 | 98 |             pipeline.add(chat.last) 99 | 100 |         except KeyboardInterrupt: 101 |             print(f"\n\n{Fore.YELLOW}Chat interrupted. Exiting.{Style.RESET_ALL}") 102 |             break 103 |         except Exception as e:  # noqa: BLE001 104 |             print(f"\n\n{Fore.RED}An error occurred: {e!s}{Style.RESET_ALL}") 105 |             break 106 | 107 |     return chat 108 | -------------------------------------------------------------------------------- /docs/topics/workflow.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Workflow" 3 | description: "How to use Rigging to generate messages." 4 | public: true 5 | --- 6 | 7 | There are two main ways to use Rigging: prompts and pipelines. 8 | 9 | Prompts offer a simple entry point and cover a lot of ground, but pipelines are the core of Rigging’s power—prompts ultimately serve as a way to leverage pipelines. 10 | 11 | ## Using Prompts 12 | 13 | 1. Establish a function decorated with `@prompt` with the inference model you want to use. 14 | 2. Call the function as you normally would and receive structured data back. 15 | 16 | ```python 17 | import rigging as rg 18 | 19 | @rg.prompt(generator_id="claude-3-5-sonnet-latest") 20 | async def get_authors(count: int = 3) -> list[str]: 21 |     """Provide famous authors.""" 22 | 23 | print(await get_authors()) 24 | 25 | # ['William Shakespeare', 'J.K. Rowling', 'Jane Austen'] 26 | ``` 27 | 28 | Underneath, Rigging will produce a `Generator` with `get_generator("claude-3-5-sonnet-latest")`, prepare a small template that will establish the required context and output structure, pass it into a new `ChatPipeline`, run the generation process, and parse the output into our structured list with `ChatPipeline.then()`. 29 | 30 | If you want to see the resulting `Chat` object, you can set that as your return value, and no output parsing will be performed: 31 | 32 | ```python 33 | @rg.prompt(generator_id="claude-3-5-sonnet-latest") 34 | async def get_authors(count: int = 3) -> rg.Chat: 35 |     ...
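# With rg.Chat as the return type, no parsing step is attached; the full
# conversation object comes back for inspection, e.g. (illustrative):
#   chat = await get_authors()
#   print(chat.last.content)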
36 | ``` 37 | 38 | Now the prompt is only responsible for abstracting the generator, pipeline, and content for you. You can also use a nested object like a `tuple` and include both your structured data and the `Chat` object. 39 | 40 | ```python 41 | @rg.prompt(generator_id="claude-3-5-sonnet-latest") 42 | async def get_authors(count: int = 3) -> tuple[list[str], rg.Chat]: 43 |     """Provide famous authors.""" 44 | ``` 45 | 46 | This will return a tuple with the parsed output as the first element and the raw `Chat` object as the second element. 47 | 48 | You can learn more about the `@prompt` decorator in the [Prompt Functions](/topics/prompt-functions) section. 49 | 50 | ## Using Pipelines 51 | 52 | 1. Get a `Generator` object - usually with `get_generator()`. 53 | 2. Call `generator.chat()` to produce a `ChatPipeline` and ready it for generation. 54 | 3. Call `pipeline.run()` to kick off generation and get your final `Chat` object. 55 | 56 | `ChatPipeline` objects hold any messages waiting to be delivered to an LLM in exchange for a new response message. These objects are also where most of the power in rigging comes from. You'll build a generation pipeline with options, parsing, callbacks, etc. After preparation, this pipeline is used to make a final `Chat` which holds all messages prior to generation (`.prev`) and after generation (`.next`). 57 | 58 | You should think of `ChatPipeline` objects as the configurable pre-generation step with calls like `.with_()`, `.apply()`, `.until()`, `.using()`, etc. Once you call one of the many `.run()` functions, the generator is used to produce the next message (or many messages) based on the prior context and any constraints you have in place. Once you have a `Chat` object, the interaction is complete and you can inspect and operate on the messages. 59 | 60 | <Tip> 61 |   Rigging supports both Chat objects (messages with roles in a conversation format), as well as raw text completions. While we use Chat objects in most of our examples, you can check out the [Completions](/topics/completions) section to learn more about their feature parity. 62 | </Tip> 63 | 64 | We often use functional-style chaining, as most of our utility functions return the object back to you. 65 | 66 | ```python 67 | chat = ( 68 |     await 69 |     generator.chat(...) 70 |     .using(...)  # tools 71 |     .then(...)  # follow up functions 72 |     .with_(...)  # generation params 73 |     .run() 74 | ) 75 | ``` 76 | 77 | Learn more about the `ChatPipeline` object in the [Pipelines](/topics/pipelines) section. 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Testing code 2 | notebooks/ 3 | 4 | # Custom parquet storage 5 | *.parquet 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | 168 | # macos 169 | .DS_Store 170 | .AppleDouble 171 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "rigging" 3 | version = "3.3.5" 4 | description = "LLM Interaction Framework" 5 | authors = ["Nick Landers <monoxgas@gmail.com>"] 6 | license = "MIT" 7 | repository = "https://github.com/dreadnode/rigging" 8 | readme = "README.md" 9 | packages = [{ include = "rigging" }] 10 | 11 | # Dependencies 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.10,<3.14" 15 | pydantic = "^2.7.3" 16 | pydantic-xml = "<=2.17.0" 17 | loguru = "^0.7.2" 18 | litellm = "^1.67.2" 19 | xmltodict = "^0.13.0" 20 | colorama = "^0.4.6" 21 | jsonpath-ng = "^1.7.0" 22 | ruamel-yaml = "^0.18.10" 23 | jsonref = "^1.1.0" 24 | mcp = "^1.5.0" 25 | dreadnode = ">=1.12.0" 26 | 27 | vllm = { version = "^0.5.0", optional = true } 28 | transformers = { version = "^4.41.0", optional = true } 29 | accelerate = { version = "^0.30.1", optional = true } 30 | 31 | asyncssh = { version = "^2.14.2", optional = true } 32 | click = { version = "^8.1.7", optional = true } 33 | httpx = { version = "^0.28.0", optional = true } 34 | aiodocker = { version = "^0.22.2", optional = true } 35 | websockets = { version = "^13.0", optional = true } 36 | 37 | elasticsearch = { version = "^8.13.2", optional = true } 38 | pandas = { version = "^2.2.2", optional = true } 39 | 40 | [tool.poetry.extras] 41 | data = ["pandas", "elasticsearch"] 42 | examples = ["asyncssh", "click", "httpx", "aiodocker", "websockets"] 43 | llm = ["vllm", "transformers", "accelerate"] 44 | all = [ 45 | "vllm", 46 | "transformers", 47 | "accelerate", 48 | "asyncssh", 49 | "click", 50 | "httpx", 51 | "aiodocker", 52 | "websockets", 53 | "logfire", 54 | "elasticsearch", 55 | "pandas", 56 | ] 57 | 58 | [tool.poetry.group.dev.dependencies] 59 | ipykernel = "^6.27.1" 60 | mypy = "^1.15.0" 61 | ruff = "^0.10.0" 62 | pytest = "^8.0.0" 63 | pandas-stubs = "^2.2.1.240316" 64 | coverage = "^7.5.1" 65 | ipywidgets = "^8.1.3" 66 | pytest-asyncio = "^1.0.0" 67 | types-colorama = "^0.4.15.20240311" 68 | types-requests = "2.32.4.20250611" 69 | beautifulsoup4 = "^4.13.4" 70 | mkdocstrings = { extras = ["python"], version = "^0.29.1" } 71 | markdown = "^3.8" 72 | markdownify = "^1.1.0" 73 | boto3-stubs = { extras = ["s3"], version = "^1.35.0" } 74 | 75 | # Build 76 | 77 | [build-system] 78 | requires = ["poetry-core"] 79 | build-backend = "poetry.core.masonry.api" 80 | 81 | # Tests / Coverage 82 | 83 | [tool.pytest.ini_options] 84 | asyncio_mode = "auto" 85 | asyncio_default_fixture_loop_scope = "function" 86 | filterwarnings = ["ignore::DeprecationWarning"] 87 | 88 | [tool.coverage.run] 89 | command_line = "-m pytest" 90 | 91 | [tool.coverage.report] 92 | include = ["rigging/*.py"] 93 | show_missing = true 94 | 95 | [tool.coverage.lcov] 96 | output = "lcov.info" 97 | 98 | # Tracing 99 | 100 | [tool.logfire] 101 | ignore_no_config = true 102 | 103 | # Security 104 | 105 | [tool.bandit] 106 | exclude_dirs = ["examples/*", ".github/*", ".hooks/*"] 107 | 108 | # Type Checking 109 | 110 | [tool.mypy] 111 | plugins = "pydantic.mypy" 112 | strict = true 113 | 114 | # Formatting / Linting 115 | 116 | [tool.ruff] 117 | target-version = "py310" 118 | line-length = 100 119 | extend-exclude = [ 120 | "*.ipynb", # jupyter notebooks 121 | "examples/*", # example files 122 | ".github/*", # github files 123 | ".hooks/*", # git 
hooks 124 | ] 125 | 126 | [tool.ruff.lint] 127 | select = ["ALL"] 128 | ignore = [ 129 | "E501", # line too long (we make best effort) 130 | "TRY003", # long messages in exception classes 131 | "EM", # picky message construction for exceptions 132 | "C90", # mccabe complexity 133 | "A002", # shadowing built-in 134 | "D", # docstrings 135 | "ANN", # annotations (handled by mypy) 136 | "PLR0913", # too many arguments 137 | "ERA001", # commented out code 138 | "FIX002", # contains todo, consider fixing 139 | "TD002", # TODO 140 | "TD003", # TODO 141 | "PLR0911", # too many return statements 142 | "FBT003", # boolean positional in function call 143 | "COM812", # missing trailing comma in function call 144 | ] 145 | 146 | [tool.ruff.format] 147 | skip-magic-trailing-comma = false 148 | -------------------------------------------------------------------------------- /.secrets.baseline: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.5.0", 3 | "plugins_used": [ 4 | { 5 | "name": "ArtifactoryDetector" 6 | }, 7 | { 8 | "name": "AWSKeyDetector" 9 | }, 10 | { 11 | "name": "AzureStorageKeyDetector" 12 | }, 13 | { 14 | "name": "Base64HighEntropyString", 15 | "limit": 4.5 16 | }, 17 | { 18 | "name": "BasicAuthDetector" 19 | }, 20 | { 21 | "name": "CloudantDetector" 22 | }, 23 | { 24 | "name": "DiscordBotTokenDetector" 25 | }, 26 | { 27 | "name": "GitHubTokenDetector" 28 | }, 29 | { 30 | "name": "GitLabTokenDetector" 31 | }, 32 | { 33 | "name": "HexHighEntropyString", 34 | "limit": 3.0 35 | }, 36 | { 37 | "name": "IbmCloudIamDetector" 38 | }, 39 | { 40 | "name": "IbmCosHmacDetector" 41 | }, 42 | { 43 | "name": "IPPublicDetector" 44 | }, 45 | { 46 | "name": "JwtTokenDetector" 47 | }, 48 | { 49 | "name": "KeywordDetector", 50 | "keyword_exclude": "" 51 | }, 52 | { 53 | "name": "MailchimpDetector" 54 | }, 55 | { 56 | "name": "NpmDetector" 57 | }, 58 | { 59 | "name": "OpenAIDetector" 60 | }, 61 | { 62 | "name": "PrivateKeyDetector" 63 | }, 64 | { 65 | "name": "PypiTokenDetector" 66 | }, 67 | { 68 | "name": "SendGridDetector" 69 | }, 70 | { 71 | "name": "SlackDetector" 72 | }, 73 | { 74 | "name": "SoftlayerDetector" 75 | }, 76 | { 77 | "name": "SquareOAuthDetector" 78 | }, 79 | { 80 | "name": "StripeDetector" 81 | }, 82 | { 83 | "name": "TelegramBotTokenDetector" 84 | }, 85 | { 86 | "name": "TwilioKeyDetector" 87 | } 88 | ], 89 | "filters_used": [ 90 | { 91 | "path": "detect_secrets.filters.allowlist.is_line_allowlisted" 92 | }, 93 | { 94 | "path": "detect_secrets.filters.common.is_baseline_file", 95 | "filename": ".secrets.baseline" 96 | }, 97 | { 98 | "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", 99 | "min_level": 2 100 | }, 101 | { 102 | "path": "detect_secrets.filters.heuristic.is_indirect_reference" 103 | }, 104 | { 105 | "path": "detect_secrets.filters.heuristic.is_likely_id_string" 106 | }, 107 | { 108 | "path": "detect_secrets.filters.heuristic.is_lock_file" 109 | }, 110 | { 111 | "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string" 112 | }, 113 | { 114 | "path": "detect_secrets.filters.heuristic.is_potential_uuid" 115 | }, 116 | { 117 | "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign" 118 | }, 119 | { 120 | "path": "detect_secrets.filters.heuristic.is_sequential_string" 121 | }, 122 | { 123 | "path": "detect_secrets.filters.heuristic.is_swagger_file" 124 | }, 125 | { 126 | "path": "detect_secrets.filters.heuristic.is_templated_secret" 127 | }, 128 | { 129 | "path": 
"detect_secrets.filters.regex.should_exclude_file", 130 | "pattern": [ 131 | "examples/*" 132 | ] 133 | } 134 | ], 135 | "results": { 136 | "docs/topics/generators.mdx": [ 137 | { 138 | "type": "Secret Keyword", 139 | "filename": "docs/topics/generators.mdx", 140 | "hashed_secret": "ef5225a03e4f9cc953ab3c4dd41f5c4db7dc2e5b", 141 | "is_verified": false, 142 | "line_number": 360 143 | }, 144 | { 145 | "type": "Secret Keyword", 146 | "filename": "docs/topics/generators.mdx", 147 | "hashed_secret": "eb6256c862c356b375aafa760fa1851e33aa62a9", 148 | "is_verified": false, 149 | "line_number": 384 150 | } 151 | ], 152 | "tests/test_http_spec.py": [ 153 | { 154 | "type": "Secret Keyword", 155 | "filename": "tests/test_http_spec.py", 156 | "hashed_secret": "3acfb2c2b433c0ea7ff107e33df91b18e52f960f", 157 | "is_verified": false, 158 | "line_number": 19 159 | } 160 | ] 161 | }, 162 | "generated_at": "2025-09-03T21:56:13Z" 163 | } 164 | -------------------------------------------------------------------------------- /docs/api/interact.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: rigging.interact 3 | --- 4 | 5 | {/* 6 | ::: rigging.interact 7 | */} 8 | 9 | Utility functions for interactive chat sessions. 10 | 11 | interact 12 | -------- 13 | 14 | ```python 15 | interact( 16 | entrypoint: ChatPipeline | Generator | str, 17 | *, 18 | reset_callback: Callable[[Chat | None], None] 19 | | None = None, 20 | ) -> Chat | None 21 | ``` 22 | 23 | Start an interactive chat session using the given pipeline, generator, or generator id. 24 | 25 | This function allows the user to have a conversation with an assistant by providing input 26 | and receiving responses. The chat session can be controlled using specific commands. 27 | 28 | **Parameters:** 29 | 30 | * **`entrypoint`** 31 | (`ChatPipeline | Generator | str`) 32 | –A ChatPipeline, Generator, or generator id to use for the chat session. 33 | * **`reset_callback`** 34 | (`Callable[[Chat | None], None] | None`, default: 35 | `None` 36 | ) 37 | –A callback function to execute when the chat is reset. 38 | 39 | **Returns:** 40 | 41 | * `Chat | None` 42 | –The final Chat object, or None if the chat was interrupted before any generation. 43 | 44 | <Accordion title="Source code in rigging/interact.py" icon="code"> 45 | ```python 46 | async def interact( 47 | entrypoint: ChatPipeline | Generator | str, 48 | *, 49 | reset_callback: t.Callable[[Chat | None], None] | None = None, 50 | ) -> Chat | None: 51 | """ 52 | Start an interactive chat session using the given pipeline, generator, or generator id. 53 | 54 | This function allows the user to have a conversation with an assistant by providing input 55 | and receiving responses. The chat session can be controlled using specific commands. 56 | 57 | Args: 58 | entrypoint: A ChatPipeline, Generator, or generator id to use for the chat session. 59 | reset_callback: A callback function to execute when the chat is reset. 60 | 61 | Returns: 62 | The final Chat object, or None if the chat was interrupted before any generation. 
63 | """ 64 | 65 | print(f"\n{Fore.YELLOW}") 66 | print("Starting interactive chat.") 67 | print() 68 | print(" - Type 'exit' to quit or use Ctrl+C.") 69 | print(" - Type 'reset' or 'restart' to restart the chat.") 70 | print(" - Type 'again' or 'retry' to re-run the last generation.") 71 | print(Style.RESET_ALL) 72 | 73 | base_pipeline = ( 74 | entrypoint 75 | if isinstance(entrypoint, ChatPipeline) 76 | else entrypoint.chat() 77 | if isinstance(entrypoint, Generator) 78 | else get_generator(entrypoint).chat() 79 | ) 80 | 81 | pipeline = base_pipeline.clone() 82 | chat: Chat | None = None 83 | 84 | while True: 85 | try: 86 | user_input = input(f"\n{Fore.GREEN}User: {Style.RESET_ALL}") 87 | if not user_input: 88 | continue 89 | 90 | if user_input.lower() == "exit": 91 | print(f"\n\n{Fore.YELLOW}Exiting chat.{Style.RESET_ALL}") 92 | break 93 | 94 | if user_input.lower() in ["reset", "restart"]: 95 | print(f"\n{Fore.YELLOW}--- Reset ---{Style.RESET_ALL}") 96 | pipeline = base_pipeline.clone() 97 | if reset_callback: 98 | reset_callback(chat) 99 | continue 100 | 101 | if user_input.lower() in ["again", "retry"]: 102 | print(f"\n{Fore.YELLOW}--- Retry ---{Style.RESET_ALL}") 103 | pipeline.chat.messages = pipeline.chat.messages[:-1] 104 | else: 105 | pipeline.add(user_input) 106 | 107 | print() 108 | 109 | animation_task = asyncio.create_task(_animate()) 110 | chat = await pipeline.run() 111 | animation_task.cancel() 112 | 113 | print(f"\r{Fore.BLUE}Assistant: {Style.RESET_ALL}{chat.last.content}") 114 | 115 | pipeline.add(chat.last) 116 | 117 | except KeyboardInterrupt: 118 | print(f"\n\n{Fore.YELLOW}Chat interrupted. Exiting.{Style.RESET_ALL}") 119 | break 120 | except Exception as e: # noqa: BLE001 121 | print(f"\n\n{Fore.RED}An error occurred: {e!s}{Style.RESET_ALL}") 122 | break 123 | 124 | return chat 125 | ``` 126 | 127 | 128 | </Accordion> -------------------------------------------------------------------------------- /rigging/parsing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parsing helpers for extracting rigging models from text 3 | """ 4 | 5 | import typing as t 6 | 7 | from rigging.error import MissingModelError 8 | 9 | if t.TYPE_CHECKING: 10 | from rigging.model import ModelT 11 | 12 | 13 | def parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice]: 14 | """ 15 | Parses a single model from text. 16 | 17 | Args: 18 | text: The content to parse. 19 | model_type: The type of model to parse. 20 | 21 | Returns: 22 | The parsed model. 23 | 24 | Raises: 25 | ValueError: If no models of the given type are found and `fail_on_missing` is set to `True`. 26 | """ 27 | return try_parse_many(text, model_type, fail_on_missing=True)[0] 28 | 29 | 30 | def try_parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice] | None: 31 | """ 32 | Tries to parse a model from text. 33 | 34 | Args: 35 | text: The content to parse. 36 | model_type: The type of model to search for. 37 | 38 | Returns: 39 | The first model that matches the given model type, or None if no match is found. 40 | """ 41 | return next(iter(try_parse_many(text, model_type)), None) 42 | 43 | 44 | def parse_set( 45 | text: str, 46 | model_type: type["ModelT"], 47 | *, 48 | minimum: int | None = None, 49 | ) -> list[tuple["ModelT", slice]]: 50 | """ 51 | Parses a set of models with the specified identical type from text. 52 | 53 | Args: 54 | text: The content to parse. 55 | model_type: The type of models to parse. 
56 | minimum: The minimum number of models required. 57 | 58 | Returns: 59 | A list of parsed models. 60 | 61 | Raises: 62 | MissingModelError: If the minimum number of models is not met. 63 | """ 64 | return try_parse_set(text, model_type, minimum=minimum, fail_on_missing=True) 65 | 66 | 67 | def try_parse_set( 68 | text: str, 69 | model_type: type["ModelT"], 70 | *, 71 | minimum: int | None = None, 72 | fail_on_missing: bool = False, 73 | ) -> list[tuple["ModelT", slice]]: 74 | """ 75 | Tries to parse a set of models with the specified identical type from text. 76 | 77 | Args: 78 | text: The content to parse. 79 | model_type: The type of model to parse. 80 | minimum: The minimum number of models expected. 81 | fail_on_missing: Whether to raise an exception if models are missing. 82 | 83 | Returns: 84 | The parsed models. 85 | 86 | Raises: 87 | MissingModelError: If the number of parsed models is less than the minimum required. 88 | """ 89 | models = try_parse_many(text, model_type, fail_on_missing=fail_on_missing) 90 | if minimum is not None and len(models) < minimum: 91 | raise MissingModelError(f"Expected at least {minimum} {model_type.__name__} in message") 92 | return models 93 | 94 | 95 | def parse_many(text: str, *types: type["ModelT"]) -> list[tuple["ModelT", slice]]: 96 | """ 97 | Parses multiple models of the specified non-identical types from text. 98 | 99 | Args: 100 | text: The content to parse. 101 | *types: The types of models to parse. 102 | 103 | Returns: 104 | A list of parsed models. 105 | 106 | Raises: 107 | MissingModelError: If any of the models are missing. 108 | """ 109 | return try_parse_many(text, *types, fail_on_missing=True) 110 | 111 | 112 | def try_parse_many( 113 | text: str, 114 | *types: type["ModelT"], 115 | fail_on_missing: bool = False, 116 | ) -> list[tuple["ModelT", slice]]: 117 | """ 118 | Tries to parse multiple models of the specified non-identical types from text. 119 | 120 | Args: 121 | text: The content to parse. 122 | *types: The types of models to parse. 123 | fail_on_missing: Whether to raise an exception if a model type is missing. 124 | 125 | Returns: 126 | A list of parsed models. 127 | 128 | Raises: 129 | MissingModelError: If a model type is missing and `fail_on_missing` is True. 130 | Exception: If the model is malformed and `fail_on_missing` is True.
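Example:
    A minimal usage sketch; `Joke` and `Fact` stand in for any rigging
    models defined elsewhere:

    ```
    parsed = try_parse_many("<joke>...</joke> <fact>...</fact>", Joke, Fact)
    for model, slice_ in parsed:
        print(type(model).__name__, slice_)
    ```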
131 | """ 132 | model: ModelT 133 | parsed: list[tuple[ModelT, slice]] = [] 134 | 135 | try: 136 | for model_class in types: 137 | for model, slice_ in model_class.from_text(text): 138 | parsed.append((model, slice_)) 139 | except Exception: 140 | if fail_on_missing: 141 | raise 142 | 143 | return sorted(parsed, key=lambda x: x[1].start) 144 | -------------------------------------------------------------------------------- /.hooks/check_pinned_hash_dependencies.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import re 3 | import sys 4 | from pathlib import Path 5 | 6 | 7 | class GitHubActionChecker: 8 | def __init__(self) -> None: 9 | # Pattern for actions with SHA-1 hashes (pinned) 10 | self.pinned_pattern = re.compile(r"uses:\s+([^@\s]+)@([a-f0-9]{40})") 11 | 12 | # Pattern for actions with version tags (unpinned) 13 | self.unpinned_pattern = re.compile( 14 | r"uses:\s+([^@\s]+)@(v\d+(?:\.\d+)*(?:-[a-zA-Z0-9]+(?:\.\d+)*)?)", 15 | ) 16 | 17 | # Pattern for all uses statements 18 | self.all_uses_pattern = re.compile(r"uses:\s+([^@\s]+)@([^\s\n]+)") 19 | 20 | def format_terminal_link(self, file_path: str, line_number: int) -> str: 21 | """Format a terminal link to a file and line number. 22 | 23 | Args: 24 | file_path: Path to the file 25 | line_number: Line number in the file 26 | 27 | Returns: 28 | str: Formatted string with file path and line number 29 | """ 30 | return f"{file_path}:{line_number}" 31 | 32 | def get_line_numbers(self, content: str, pattern: re.Pattern[str]) -> list[tuple[str, int]]: 33 | """Find matches with their line numbers.""" 34 | matches = [] 35 | matches.extend( 36 | (match.group(0), i) 37 | for i, line in enumerate(content.splitlines(), 1) 38 | for match in pattern.finditer(line) 39 | ) 40 | return matches 41 | 42 | def check_file(self, file_path: str) -> bool: 43 | """Check a single file for unpinned dependencies.""" 44 | try: 45 | content = Path(file_path).read_text() 46 | except (FileNotFoundError, PermissionError, IsADirectoryError, OSError) as e: 47 | print(f"\033[91mError reading file {file_path}: {e}\033[0m") 48 | return False 49 | 50 | # Get matches with line numbers 51 | pinned_matches = self.get_line_numbers(content, self.pinned_pattern) 52 | unpinned_matches = self.get_line_numbers(content, self.unpinned_pattern) 53 | all_matches = self.get_line_numbers(content, self.all_uses_pattern) 54 | 55 | print(f"\n\033[1m[=] Checking file: {file_path}\033[0m") 56 | 57 | # Print pinned dependencies 58 | if pinned_matches: 59 | print("\033[92m[+] Pinned:\033[0m") 60 | for match, line_num in pinned_matches: 61 | print(f" |- {match} \033[90m({file_path}:{line_num})\033[0m") 62 | 63 | # Track all found actions for validation 64 | found_actions = set() 65 | for match, _ in pinned_matches + unpinned_matches: 66 | action_name = self.pinned_pattern.match(match) or self.unpinned_pattern.match(match) 67 | if action_name: 68 | found_actions.add(action_name.group(1)) 69 | 70 | has_errors = False 71 | 72 | # Check for unpinned dependencies 73 | if unpinned_matches: 74 | has_errors = True 75 | print("\033[93m[!] 
Unpinned (using version tags):\033[0m") 76 | for match, line_num in unpinned_matches: 77 | print(f" |- {match} \033[90m({file_path}:{line_num})\033[0m") 78 | 79 | # Check for completely unpinned dependencies (no SHA or version) 80 | unpinned_without_hash = [ 81 | (match, line_num) 82 | for match, line_num in all_matches 83 | if not any(match in pinned[0] for pinned in pinned_matches) 84 | and not any(match in unpinned[0] for unpinned in unpinned_matches) 85 | ] 86 | 87 | if unpinned_without_hash: 88 | has_errors = True 89 | print("\033[91m[!] Completely unpinned (no SHA or version):\033[0m") 90 | for match, line_num in unpinned_without_hash: 91 | print( 92 | f" |- {match} \033[90m({self.format_terminal_link(file_path, line_num)})\033[0m", 93 | ) 94 | 95 | # Print summary 96 | total_actions = len(pinned_matches) + len(unpinned_matches) + len(unpinned_without_hash) 97 | if total_actions == 0: 98 | print("\033[93m[!] No GitHub Actions found in this file\033[0m") 99 | else: 100 | print("\n\033[1mSummary:\033[0m") 101 | print(f"Total actions: {total_actions}") 102 | print(f"Pinned: {len(pinned_matches)}") 103 | print(f"Unpinned with version: {len(unpinned_matches)}") 104 | print(f"Completely unpinned: {len(unpinned_without_hash)}") 105 | 106 | return not has_errors 107 | 108 | 109 | def main() -> None: 110 | checker = GitHubActionChecker() 111 | files_to_check = sys.argv[1:] 112 | 113 | if not files_to_check: 114 | print("\033[91mError: No files provided to check\033[0m") 115 | print("Usage: python script.py <file1> <file2> ...") 116 | sys.exit(1) 117 | 118 | results = {file: checker.check_file(file) for file in files_to_check} 119 | 120 | # Print final summary 121 | print("\n\033[1mFinal Results:\033[0m") 122 | for file, passed in results.items(): 123 | status = "\033[92m✓ Passed\033[0m" if passed else "\033[91m✗ Failed\033[0m" 124 | print(f"{status} {file}") 125 | 126 | if not all(results.values()): 127 | sys.exit(1) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /docs/topics/serialization.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Serialization" 3 | description: "Serialization and Deserialization of Rigging objects" 4 | public: true 5 | --- 6 | 7 | The following objects in Rigging have great serialization support for storage and retrieval: 8 | 9 | - `Chat` 10 | - `Completion` 11 | - `Generator` 12 | - `Model` 13 | 14 | Most of this stems from our use of Pydantic for core models, and we've included some helpful fields for reconstructing Chats and Completions. 15 | 16 | ## JSON Serialization 17 | 18 | Let's build a joke pipeline and serialize the final chat into JSON. 19 | 20 | <CodeGroup> 21 | ```python Serialization 22 | import rigging as rg 23 | 24 | class Joke(rg.Model): 25 | content: str 26 | 27 | chat = ( 28 | await 29 | rg.get_generator("gpt-3.5-turbo") 30 | .chat(f"Provide 3 jokes each between {Joke.xml_tags()} tags.") 31 | .meta(tags=['joke']) 32 | .with_(temperature=1.25) 33 | .run() 34 | ) 35 | 36 | chat.last.parse_set(Joke) 37 | 38 | serialized = chat.model_dump_json(indent=2) 39 | print(serialized) 40 | ``` 41 | 42 | ```json Serialized JSON 43 | { 44 | "uuid": "891c3834-2588-4652-8371-e9746086fd46", 45 | "timestamp": "2024-05-10T11:44:15.501326", 46 | "messages": [ 47 | { 48 | "role": "user", 49 | "parts": [], 50 | "content": "Provide 3 jokes each between <joke></joke> tags." 
51 | } 52 | ], 53 | "generated": [ 54 | { 55 | "role": "assistant", 56 | "parts": [ 57 | { 58 | "model": { 59 | "content": " Why was the math book sad? Because it had too many problems. " 60 | }, 61 | "slice_": [ 62 | 0, 63 | 75 64 | ] 65 | }, 66 | { 67 | "model": { 68 | "content": " I told my wife she should embrace her mistakes. She gave me a hug. " 69 | }, 70 | "slice_": [ 71 | 76, 72 | 157 73 | ] 74 | }, 75 | { 76 | "model": { 77 | "content": " Why did the scarecrow win an award? Because he was outstanding in his field. " 78 | }, 79 | "slice_": [ 80 | 158, 81 | 249 82 | ] 83 | } 84 | ], 85 | "content": "<joke> Why was the math book sad? Because it had too many problems. </joke>\n<joke> I told my wife she should embrace her mistakes. She gave me a hug. </joke>\n<joke> Why did the scarecrow win an award? Because he was outstanding in his field. </joke>" 86 | } 87 | ], 88 | "metadata": { 89 | "tags": [ 90 | "joke" 91 | ] 92 | }, 93 | "generator_id": "litellm!gpt-3.5-turbo,temperature=1.25" 94 | } 95 | ``` 96 | </CodeGroup> 97 | 98 | You'll notice that every Chat gets a unique `uuid` field to help track them in a datastore like Elastic or Pandas. We also assign a `timestamp` to understand when the generation took place, and we take advantage of the `.meta()` method (`rigging.chat.ChatPipeline.meta`) to add a tracking tag for filtering later. 99 | 100 | ## JSON Deserialization 101 | 102 | The JSON has everything required to reconstruct a Chat, including a `generator_id` dynamically constructed to preserve the parameters used to create the generated message(s). We can now deserialize a chat from a datastore, and immediately step back into a `ChatPipeline` for exploration. 103 | 104 | ```python 105 | chat = rg.Chat.model_validate_json(serialized) 106 | print(chat.conversation) 107 | # [user]: Provide 3 jokes each between <joke></joke> tags. 108 | 109 | # [assistant]: 110 | # <joke> Why was the math book sad? Because it had too many problems. </joke> 111 | # <joke> I told my wife she should embrace her mistakes. She gave me a hug. </joke> 112 | # <joke> Why did the scarecrow win an award? Because he was outstanding in his field. </joke> 113 | 114 | continued = await chat.continue_("Please explain the first joke to me.").run() 115 | print(continued.last) 116 | # [assistant]: In the first joke, the pun is based on the double meaning of the word "problems." 117 | # The math book is described as being sad because it has "too many problems," which could be 118 | # interpreted as having both mathematical problems (equations to solve) and emotional difficulties. 119 | # This play on words adds humor to the joke. 120 | ``` 121 | 122 | ## Pandas DataFrames 123 | 124 | Rigging also has helpers in the `rigging.data` module for performing conversions between Chat objects and other storage formats like Pandas. In `chats_to_df` the messages are flattened and stored with a `chat_id` column for grouping. `df_to_chats` allows you to reconstruct a list of Chat objects back from a DataFrame.
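Because the frame is flat, it drops cleanly into any tabular store. The snippet below is a minimal, hypothetical sketch of a round trip through Parquet; it assumes `pyarrow` is installed for pandas' Parquet support and that the frame's dtypes survive the round trip:

```python
import pandas as pd
import rigging as rg

chats = (
    await
    rg.get_generator("gpt-3.5-turbo")
    .chat("Tell me a joke.")
    .run_many(2)
)

# Flatten to a DataFrame and persist it as a Parquet file
df = rg.data.chats_to_df(chats)
df.to_parquet("chats.parquet")

# Later, reload the frame and reconstruct the Chat objects
restored = rg.data.df_to_chats(pd.read_parquet("chats.parquet"))
```

The full conversion workflow looks like this: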
125 | 126 | ```python 127 | import rigging as rg 128 | 129 | chats = ( 130 | await 131 | rg.get_generator("claude-3-haiku-20240307") 132 | .chat("Write me a haiku.") 133 | .run_many(3) 134 | ) 135 | 136 | df = rg.data.chats_to_df(chats) 137 | # or 138 | df = chats.to_df() 139 | 140 | print(df.info()) 141 | 142 | # RangeIndex: 6 entries, 0 to 5 143 | # Data columns (total 9 columns): 144 | # # Column Non-Null Count Dtype 145 | # --- ------ -------------- ----- 146 | # 0 chat_id 6 non-null string 147 | # 1 chat_metadata 6 non-null string 148 | # 2 chat_generator_id 6 non-null string 149 | # 3 chat_timestamp 6 non-null datetime64[ms] 150 | # 4 generated 6 non-null bool 151 | # 5 role 6 non-null category 152 | # 6 parts 6 non-null string 153 | # 7 content 6 non-null string 154 | # 8 message_id 6 non-null string 155 | # dtypes: bool(1), category(1), datetime64[ms](1), string(6) 156 | 157 | df.content.apply(lambda x: len(x)).mean() 158 | 159 | # 60.166666666666664 160 | 161 | back = rg.data.df_to_chats(df) 162 | print(back[0].conversation) 163 | 164 | # [user]: Write me a haiku. 165 | # 166 | # [assistant]: Here's a haiku for you: 167 | # 168 | # Gentle breeze whispers, 169 | # Flowers bloom in vibrant hues, 170 | # Nature's simple bliss. 171 | ``` 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | <p align="center"> 2 | <img 3 | src="https://d1lppblt9t2x15.cloudfront.net/logos/5714928f3cdc09503751580cffbe8d02.png" 4 | alt="Logo" 5 | align="center" 6 | width="144px" 7 | height="144px" 8 | /> 9 | </p> 10 | 11 | <h3 align="center"> 12 | Flexible LLM library for code and agents 13 | </h3> 14 | 15 | <h4 align="center"> 16 | <img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/rigging"> 17 | <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/rigging"> 18 | <img alt="GitHub License" src="https://img.shields.io/github/license/dreadnode/rigging"> 19 | <img alt="GitHub Actions Workflow Status" src="https://img.shields.io/github/actions/workflow/status/dreadnode/rigging/ci.yml"> 20 | </h4> 21 | 22 | </br> 23 | 24 | Rigging is a lightweight LLM framework to make using language models in production code as simple and effective as possible. Here are the highlights: 25 | 26 | - **Structured Pydantic models** can be used interchangeably with unstructured text output. 27 | - LiteLLM as the default generator, giving you **instant access to a huge array of models**. 28 | - Define prompts as python functions with **type hints and docstrings**. 29 | - Simple **tool use**, even for models which don't support them at the API level. 30 | - Store different models and configs as **simple connection strings** just like databases. 31 | - Integrated tracing support with [Logfire](https://logfire.pydantic.dev/docs/). 32 | - Chat templating, forking, continuations, generation parameter overloads, stripping segments, etc. 33 | - Async batching and fast iterations for **large scale generation**. 34 | - Metadata, callbacks, and data format conversions. 35 | - Modern python with type hints, async support, pydantic validation, serialization, etc. 36 | 37 | ```py 38 | import rigging as rg 39 | 40 | @rg.prompt(generator_id="gpt-4") 41 | async def get_authors(count: int = 3) -> list[str]: 42 | """Provide famous authors.""" 43 | 44 | print(await get_authors()) 45 | 46 | # ['William Shakespeare', 'J.K.
Rowling', 'Jane Austen'] 47 | ``` 48 | 49 | Rigging is built by [**dreadnode**](https://dreadnode.io) where we use it daily. 50 | 51 | ## Installation 52 | 53 | We publish every version to PyPI: 54 | ```bash 55 | pip install rigging 56 | ``` 57 | 58 | If you want to build from source: 59 | ```bash 60 | cd rigging/ 61 | poetry install 62 | ``` 63 | 64 | ## Supported LLMs 65 | 66 | Rigging will run just about any language model: 67 | 68 | - Any model from [**LiteLLM**](https://litellm.vercel.app/docs/providers) 69 | - Any model from [**vLLM**](https://docs.vllm.ai/en/latest/models/supported_models.html) 70 | - Any model from [**transformers**](https://huggingface.co/docs/transformers/) 71 | 72 | ### API Keys 73 | 74 | Pass the `api_key` in a generator id or use standard environment variables. 75 | 76 | ```py 77 | rg.get_generator("gpt-4-turbo,api_key=...") 78 | ``` 79 | 80 | ```bash 81 | export OPENAI_API_KEY=... 82 | export MISTRAL_API_KEY=... 83 | export ANTHROPIC_API_KEY=... 84 | ... 85 | ``` 86 | 87 | Check out [the docs](https://docs.dreadnode.io/open-source/rigging/topics/generators#api-keys) for more. 88 | 89 | ## Getting Started 90 | 91 | **Check out the guide [in the docs](https://docs.dreadnode.io/open-source/rigging/intro#getting-started)** 92 | 93 | 1. **Get a generator** using a connection string. 94 | 2. Build a **chat** or **completion** pipeline. 95 | 3. **Run** the pipeline and get the output. 96 | 97 | ```py 98 | import rigging as rg 99 | import asyncio 100 | 101 | async def main(): 102 | # 1 - Get a generator 103 | generator = rg.get_generator("claude-3-sonnet-20240229") 104 | 105 | # 2 - Build a chat pipeline 106 | pipeline = generator.chat( 107 | [ 108 | {"role": "system", "content": "Talk like a pirate."}, 109 | {"role": "user", "content": "Say hello!"}, 110 | ] 111 | ) 112 | 113 | # 3 - Run the pipeline 114 | chat = await pipeline.run() 115 | print(chat.conversation) 116 | 117 | # Run the main function 118 | asyncio.run(main()) 119 | 120 | # [system]: Talk like a pirate. 121 | # [user]: Say hello! 122 | # [assistant]: Ahoy, matey! Here be the salty sea dog ready to trade greetings wit' ye. Arrr! 123 | ``` 124 | 125 | Want more?
126 | 127 | - Use [structured pydantic parsing](https://docs.dreadnode.io/open-source/rigging/topics/chats-and-messages#parsed-parts) 128 | - Check out [raw completions](https://docs.dreadnode.io/open-source/rigging/topics/completions/) 129 | - Give the LLM [access to tools](https://docs.dreadnode.io/open-source/rigging/topics/tools/) 130 | - Track behavior with [tracing](https://docs.dreadnode.io/open-source/rigging/topics/tracing/) 131 | - Play with [generation params](https://docs.dreadnode.io/open-source/rigging/topics/generators/#overload-generation-params) 132 | - Use [callbacks in the pipeline](https://docs.dreadnode.io/open-source/rigging/topics/callbacks-and-mapping/) 133 | - Scale up with [iterating and batching](https://docs.dreadnode.io/open-source/rigging/topics/iterating-and-batching/) 134 | - Save your work with [serialization](https://docs.dreadnode.io/open-source/rigging/topics/serialization/) 135 | 136 | ## Examples 137 | 138 | - Basic interactive chat: [**chat.py**](examples/chat.py) 139 | - Jupyter code interpreter: [**jupyter.py**](examples/jupyter.py) 140 | - OverTheWire Bandit Agent: [**bandit.py**](examples/bandit.py) 141 | - Damn Vulnerable Restaurant Agent: [**dvra.py**](examples/dvra.py) 142 | - RAG Pipeline: [**rag.py**](examples/rag.py) (from [kyleavery](https://github.com/kyleavery/)) 143 | - Integrating dreadnode-owned [robopages](https://github.com/dreadnode/robopages-cli) as a tool server (basic nmap scan example): [**rigging_example.py**](https://github.com/dreadnode/robopages-cli/blob/main/examples/rigging_example.py) 144 | 145 | ## Documentation 146 | 147 | **[docs.dreadnode.io](https://docs.dreadnode.io/open-source/rigging/intro)** has everything you need. 148 | 149 | ## Star History 150 | 151 | [![Star History Chart](https://api.star-history.com/svg?repos=dreadnode/rigging&type=Date)](https://star-history.com/#dreadnode/rigging&Date) 152 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from textwrap import dedent 3 | from xml.sax.saxutils import escape 4 | 5 | import pytest 6 | 7 | from rigging.model import Model, attr, element 8 | 9 | # mypy: disable-error-code=empty-body 10 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001 11 | 12 | 13 | class SimpleModel(Model): 14 | """A simple model to test basic element generation.""" 15 | 16 | content: str = element(examples=["Hello, World!"]) 17 | """The main content of the model.""" 18 | 19 | 20 | class NoExampleModel(Model): 21 | """A model to test fallback to an empty element when no example is given.""" 22 | 23 | name: str 24 | """The name of the entity.""" 25 | 26 | 27 | class AttrAndElementModel(Model): 28 | """Tests a model with both an attribute and a child element.""" 29 | 30 | id: int = attr(examples=[123]) 31 | """The unique identifier (attribute).""" 32 | value: str = element(examples=["Some value"]) 33 | """The primary value (element).""" 34 | 35 | 36 | class DocstringDescriptionModel(Model): 37 | """Tests that field docstrings are correctly used as descriptions.""" 38 | 39 | field1: str = element(examples=["val1"]) 40 | """This is the description for field1.""" 41 | field2: bool = element(examples=[True]) 42 | """This is the description for field2.""" 43 | 44 | 45 | class ParameterDescriptionModel(Model): 46 | """Tests that the `description` parameter overrides a field's docstring.""" 47 | 48 | param: str = element( 49 | 
examples=["override"], description="This description is from the `description` parameter." 50 | ) 51 | """This docstring should be ignored in the XML example.""" 52 | 53 | 54 | class SpecialCharsModel(Model): 55 | """Tests proper escaping of special XML characters in examples and comments.""" 56 | 57 | comment: str = element(examples=["ok"]) 58 | """This comment contains < and > & special characters.""" 59 | data: str = element(examples=["<tag>&'"]) 60 | """This element's example contains special XML characters.""" 61 | 62 | 63 | # This class definition is based on the one you provided in the prompt. 64 | class Analysis(Model, tag="analysis"): 65 | """A model to validate the exact output requested in the prompt.""" 66 | 67 | priority: t.Literal["low", "medium", "high", "critical"] = element(examples=["medium"]) 68 | """Triage priority for human follow-up.""" 69 | tags: str = element("tags", examples=["admin panel, error message, legacy"]) 70 | """A list of specific areas within the screenshot that are noteworthy or require further examination.""" 71 | summary: str = element() 72 | """A markdown summary explaining *why* the screenshot is interesting and what a human should investigate next.""" 73 | 74 | 75 | @pytest.mark.parametrize( 76 | ("model_cls", "expected_xml"), 77 | [ 78 | pytest.param( 79 | SimpleModel, 80 | """ 81 | <simple-model> 82 | <!-- The main content of the model. --> 83 | <content>Hello, World!</content> 84 | </simple-model> 85 | """, 86 | id="simple_model", 87 | ), 88 | pytest.param( 89 | NoExampleModel, 90 | """ 91 | <no-example-model></no-example-model> 92 | """, 93 | id="model_with_no_example", 94 | ), 95 | pytest.param( 96 | AttrAndElementModel, 97 | """ 98 | <attr-and-element-model id="123"> 99 | <!-- The primary value (element). --> 100 | <value>Some value</value> 101 | </attr-and-element-model> 102 | """, 103 | id="model_with_attribute_and_element", 104 | ), 105 | pytest.param( 106 | DocstringDescriptionModel, 107 | """ 108 | <docstring-description-model> 109 | <!-- This is the description for field1. --> 110 | <field1>val1</field1> 111 | <!-- This is the description for field2. --> 112 | <field2>True</field2> 113 | </docstring-description-model> 114 | """, 115 | id="descriptions_from_docstrings", 116 | ), 117 | pytest.param( 118 | ParameterDescriptionModel, 119 | """ 120 | <parameter-description-model> 121 | <!-- This description is from the `description` parameter. --> 122 | <param>override</param> 123 | </parameter-description-model> 124 | """, 125 | id="description_from_parameter_overrides_docstring", 126 | ), 127 | pytest.param( 128 | SpecialCharsModel, 129 | f""" 130 | <special-chars-model> 131 | <!-- {escape("This comment contains < and > & special characters.")} --> 132 | <comment>ok</comment> 133 | <!-- {escape("This element's example contains special XML characters.")} --> 134 | <data>{escape("<tag>&'")}</data> 135 | </special-chars-model> 136 | """, 137 | id="escaping_of_special_characters", 138 | ), 139 | pytest.param( 140 | Analysis, 141 | """ 142 | <analysis> 143 | <!-- Triage priority for human follow-up. --> 144 | <priority>medium</priority> 145 | <!-- A list of specific areas within the screenshot that are noteworthy or require further examination. --> 146 | <tags>admin panel, error message, legacy</tags> 147 | <!-- A markdown summary explaining *why* the screenshot is interesting and what a human should investigate next. 
--> 148 | <summary/> 149 | </analysis> 150 | """, 151 | id="user_provided_analysis_model", 152 | ), 153 | ], 154 | ) 155 | def test_xml_example_generation(model_cls: type[Model], expected_xml: str) -> None: 156 | """ 157 | Validates that the `xml_example()` class method produces the correct 158 | pretty-printed XML with examples and descriptions as comments. 159 | """ 160 | actual_xml = model_cls.xml_example() 161 | assert dedent(actual_xml).strip() == dedent(expected_xml).strip() 162 | -------------------------------------------------------------------------------- /tests/test_prompt.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from textwrap import dedent 3 | from typing import Annotated 4 | 5 | import pytest 6 | 7 | import rigging as rg 8 | from rigging.chat import Chat 9 | 10 | # mypy: disable-error-code=empty-body 11 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001 12 | 13 | 14 | def test_prompt_render_docstring_parse() -> None: 15 | @rg.prompt 16 | async def foo(name: str) -> str: 17 | """Say hello.""" 18 | 19 | assert foo.docstring == "Say hello." 20 | 21 | @rg.prompt 22 | async def bar(name: str) -> str: 23 | """ 24 | Say hello.""" 25 | 26 | assert bar.docstring == "Say hello." 27 | 28 | @rg.prompt 29 | async def baz(name: str) -> str: 30 | """ 31 | Say \ 32 | hello. 33 | 34 | """ 35 | 36 | assert baz.docstring == "Say hello." 37 | 38 | 39 | def test_basic_prompt_render() -> None: 40 | @rg.prompt 41 | async def hello(name: str) -> str: 42 | """Say hello.""" 43 | 44 | rendered = hello.render("Alice") 45 | assert rendered == dedent( 46 | """\ 47 | Say hello. 48 | 49 | <name>Alice</name> 50 | 51 | Produce the following output (use xml tags): 52 | 53 | <str></str> 54 | """, 55 | ) 56 | 57 | 58 | def test_prompt_render_with_docstring_variables() -> None: 59 | @rg.prompt 60 | async def greet(name: str, greeting: str = "Hello") -> str: 61 | """Say '{{ greeting }}' to {{ name }}.""" 62 | 63 | rendered = greet.render("Bob") 64 | assert rendered == dedent( 65 | """\ 66 | Say 'Hello' to Bob. 67 | 68 | Produce the following output (use xml tags): 69 | 70 | <str></str> 71 | """, 72 | ) 73 | 74 | 75 | def test_prompt_render_with_model_output() -> None: 76 | class Person(rg.Model): 77 | name: str = rg.element() 78 | age: int = rg.element() 79 | 80 | @rg.prompt 81 | async def create_person(name: str, age: int) -> Person: 82 | """Create a person.""" 83 | 84 | rendered = create_person.render("Alice", 30) 85 | assert rendered == dedent( 86 | """\ 87 | Create a person. 88 | 89 | <name>Alice</name> 90 | 91 | <age>30</age> 92 | 93 | Produce the following output (use xml tags): 94 | 95 | <person> 96 | <name/> 97 | <age/> 98 | </person> 99 | """, 100 | ) 101 | 102 | 103 | def test_prompt_render_with_list_output() -> None: 104 | @rg.prompt 105 | async def generate_numbers(count: int) -> list[int]: 106 | """Generate a list of numbers.""" 107 | 108 | rendered = generate_numbers.render(5) 109 | assert rendered == dedent( 110 | """\ 111 | Generate a list of numbers. 112 | 113 | <count>5</count> 114 | 115 | Produce the following output for each item (use xml tags): 116 | 117 | <int></int> 118 | """, 119 | ) 120 | 121 | 122 | def test_prompt_render_with_tuple_output() -> None: 123 | @rg.prompt 124 | async def create_user(username: str) -> tuple[str, int]: 125 | """Create a new user.""" 126 | 127 | rendered = create_user.render("johndoe") 128 | assert rendered == dedent( 129 | """\ 130 | Create a new user. 
131 | 132 | <username>johndoe</username> 133 | 134 | Produce the following outputs (use xml tags): 135 | 136 | <str></str> 137 | 138 | <int></int> 139 | """, 140 | ) 141 | 142 | 143 | def test_prompt_render_with_tuple_output_ctx() -> None: 144 | @rg.prompt 145 | async def create_user(username: str) -> tuple[Annotated[str, rg.Ctx(tag="id")], int]: 146 | """Create a new user.""" 147 | 148 | rendered = create_user.render("johndoe") 149 | assert rendered == dedent( 150 | """\ 151 | Create a new user. 152 | 153 | <username>johndoe</username> 154 | 155 | Produce the following outputs (use xml tags): 156 | 157 | <id></id> 158 | 159 | <int></int> 160 | """, 161 | ) 162 | 163 | 164 | def test_prompt_render_with_dataclass_output() -> None: 165 | @dataclass 166 | class User: 167 | username: str 168 | email: str 169 | age: int 170 | 171 | @rg.prompt 172 | async def register_user(username: str, email: str, age: int) -> User: 173 | """Register a new user: {{ username}}.""" 174 | 175 | rendered = register_user.render("johndoe", "johndoe@example.com", 25) 176 | assert rendered == dedent( 177 | """\ 178 | Register a new user: johndoe. 179 | 180 | <email>johndoe@example.com</email> 181 | 182 | <age>25</age> 183 | 184 | Produce the following outputs (use xml tags): 185 | 186 | <username></username> 187 | 188 | <email></email> 189 | 190 | <age></age> 191 | """, 192 | ) 193 | 194 | 195 | def test_prompt_render_with_chat_return() -> None: 196 | @rg.prompt 197 | async def foo(input_: str) -> Chat: 198 | """Do something.""" 199 | 200 | rendered = foo.render("bar") 201 | assert rendered == dedent( 202 | """\ 203 | Do something. 204 | 205 | <input>bar</input> 206 | """, 207 | ) 208 | 209 | 210 | def test_prompt_render_ctx_in_dataclass() -> None: 211 | @dataclass 212 | class User: 213 | username: str 214 | email: Annotated[str, rg.Ctx(prefix="The user email:", example="[test@email.com]")] 215 | age: Annotated[int, rg.Ctx(tag="override")] 216 | 217 | @rg.prompt 218 | async def register_user(username: str, email: str, age: int) -> User: 219 | """Register a new user: {{ username }}.""" 220 | 221 | rendered = register_user.render("johndoe", "john@email.com", 30) 222 | assert rendered == dedent( 223 | """\ 224 | Register a new user: johndoe. 225 | 226 | <email>john@email.com</email> 227 | 228 | <age>30</age> 229 | 230 | Produce the following outputs (use xml tags): 231 | 232 | <username></username> 233 | 234 | The user email: 235 | <email>[test@email.com]</email> 236 | 237 | <override></override> 238 | """, 239 | ) 240 | 241 | 242 | def test_prompt_parse_fail_nested_input() -> None: 243 | async def foo(arg: list[list[str]]) -> Chat: ... 244 | 245 | with pytest.raises(TypeError): 246 | rg.prompt(foo) 247 | 248 | async def bar(arg: tuple[int, str, tuple[str]]) -> Chat: ... 249 | 250 | with pytest.raises(TypeError): 251 | rg.prompt(bar) 252 | 253 | 254 | def test_prompt_parse_fail_unique_output() -> None: 255 | async def foo(arg: int) -> tuple[str, str]: ... 256 | 257 | with pytest.raises(TypeError): 258 | rg.prompt(foo) 259 | -------------------------------------------------------------------------------- /rigging/error.py: -------------------------------------------------------------------------------- 1 | """ 2 | We try to avoid creating custom exceptions unless they are necessary. 3 | 4 | We use the built-in and pydantic exceptions as much as possible.
5 | """ 6 | 7 | import functools 8 | import typing as t 9 | 10 | import typing_extensions as te 11 | 12 | if t.TYPE_CHECKING: 13 | from rigging.chat import PipelineStep 14 | from rigging.message import Message 15 | 16 | 17 | # User Throwable Exceptions 18 | 19 | 20 | class Stop(Exception): # noqa: N818 21 | """ 22 | Raise inside a pipeline to indicate a stopping condition. 23 | 24 | Example: 25 | ``` 26 | import rigging as rg 27 | 28 | async def read_file(path: str) -> str: 29 | "Read the contents of a file." 30 | 31 | if no_more_files(path): 32 | raise rg.Stop("There are no more files to read.") 33 | 34 | ... 35 | 36 | chat = await pipeline.using(read_file).run() 37 | ``` 38 | """ 39 | 40 | def __init__(self, message: str): 41 | super().__init__(message) 42 | self.message = message 43 | """The message associated with the stop.""" 44 | 45 | 46 | # Warnings 47 | 48 | 49 | class PipelineWarning(Warning): 50 | """ 51 | Base class for all pipeline warnings. 52 | 53 | This is used to indicate that something unexpected happened during the pipeline execution, 54 | but it is not critical enough to stop the execution. 55 | """ 56 | 57 | 58 | class ToolWarning(Warning): 59 | """ 60 | Base class for all tool warnings. 61 | 62 | This is used to indicate that something unexpected happened during the tool execution, 63 | but it is not critical enough to stop the execution. 64 | """ 65 | 66 | 67 | class MessageWarning(Warning): 68 | """ 69 | Base class for all message warnings. 70 | 71 | This is used to indicate that something unexpected happened during the message processing, 72 | but it is not critical enough to stop the execution. 73 | """ 74 | 75 | 76 | class TokenizerWarning(Warning): 77 | """ 78 | Base class for all tokenization warnings. 79 | 80 | This is used to indicate that something unexpected happened during the tokenization process, 81 | but it is not critical enough to stop the execution. 82 | """ 83 | 84 | 85 | class GeneratorWarning(Warning): 86 | """ 87 | Base class for all generator warnings. 88 | 89 | This is used to indicate that something unexpected happened during the generator execution, 90 | but it is not critical enough to stop the execution. 91 | """ 92 | 93 | 94 | # System Exceptions 95 | 96 | 97 | class UnknownToolError(Exception): 98 | """ 99 | Raised when the an api tool call is made for an unknown tool. 100 | """ 101 | 102 | def __init__(self, tool_name: str): 103 | super().__init__(f"Unknown tool call was requested for '{tool_name}'") 104 | self.tool_name = tool_name 105 | """The name of the tool which was unknown.""" 106 | 107 | 108 | class ToolDefinitionError(Exception): 109 | """ 110 | Raised when a tool cannot be properly defined. 111 | """ 112 | 113 | def __init__(self, message: str): 114 | super().__init__(message) 115 | 116 | 117 | class ExhaustedMaxRoundsError(Exception): 118 | """ 119 | Raised when the maximum number of rounds is exceeded while generating. 120 | """ 121 | 122 | def __init__(self, max_rounds: int): 123 | super().__init__(f"Exhausted max rounds ({max_rounds}) while generating") 124 | self.max_rounds = max_rounds 125 | """The number of rounds which was exceeded.""" 126 | 127 | 128 | class MessagesExhaustedMaxRoundsError(ExhaustedMaxRoundsError): 129 | """ 130 | Raised when the maximum number of rounds is exceeded while generating messages. 
131 | """ 132 | 133 | def __init__(self, max_rounds: int, messages: list["Message"]): 134 | super().__init__(max_rounds) 135 | self.messages = messages 136 | """The messages which were being generated when the exception occurred.""" 137 | 138 | 139 | class CompletionExhaustedMaxRoundsError(ExhaustedMaxRoundsError): 140 | """ 141 | Raised when the maximum number of rounds is exceeded while generating completions. 142 | """ 143 | 144 | def __init__(self, max_rounds: int, completion: str): 145 | super().__init__(max_rounds) 146 | self.completion = completion 147 | """The completion which was being generated when the exception occurred.""" 148 | 149 | 150 | class MaxDepthError(Exception): 151 | """ 152 | Raised when the maximum depth is exceeded while generating. 153 | """ 154 | 155 | def __init__(self, max_depth: int, step: "PipelineStep", reference: str): 156 | super().__init__(f"Exceeded max depth ({max_depth}) while generating ('{reference}')") 157 | self.max_depth = max_depth 158 | """The maximum depth of nested pipeline generations which was exceeded.""" 159 | self.step = step 160 | """The pipeline step which cause the depth error.""" 161 | 162 | 163 | class InvalidGeneratorError(Exception): 164 | """ 165 | Raised when an invalid identifier is specified when getting a generator. 166 | """ 167 | 168 | def __init__(self, model: str): 169 | super().__init__(f"Invalid model specified: {model}") 170 | 171 | 172 | class InvalidTokenizerError(Exception): 173 | """ 174 | Raised when an invalid tokenizer is specified. 175 | """ 176 | 177 | def __init__(self, tokenizer: str): 178 | super().__init__(f"Invalid tokenizer specified: {tokenizer}") 179 | self.tokenizer = tokenizer 180 | """The name of the tokenizer which was invalid.""" 181 | 182 | 183 | class MissingModelError(Exception): 184 | """ 185 | Raised when a model is missing when parsing a message. 186 | """ 187 | 188 | def __init__(self, content: str): 189 | super().__init__(content) 190 | 191 | 192 | class ProcessingError(Exception): 193 | """ 194 | Raised when an error occurs during internal generator processing. 195 | """ 196 | 197 | def __init__(self, content: str): 198 | super().__init__(content) 199 | 200 | 201 | P = te.ParamSpec("P") 202 | R = t.TypeVar("R") 203 | 204 | 205 | def raise_as( 206 | error_type: type[Exception], 207 | message: str, 208 | ) -> t.Callable[[t.Callable[P, R]], t.Callable[P, R]]: 209 | "When the wrapped function raises an exception, `raise ... from` with the new error type." 
210 | 211 | def _raise_as(func: t.Callable[P, R]) -> t.Callable[P, R]: 212 | @functools.wraps(func) 213 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 214 | try: 215 | return func(*args, **kwargs) 216 | except Exception as e: 217 | error = error_type(message) 218 | raise error from e 219 | 220 | if wrapper.__doc__ is None: 221 | wrapper.__doc__ = "" 222 | 223 | wrapper.__doc__ += f"\n\nRaises:\n {error_type.__name__}{': ' + message}" 224 | 225 | return wrapper 226 | 227 | return _raise_as 228 | -------------------------------------------------------------------------------- /rigging/generator/vllm_.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import inspect 3 | import typing as t 4 | 5 | import torch # type: ignore [import-not-found, import-untyped, unused-ignore] 6 | import vllm # type: ignore [import-not-found,import-untyped, unused-ignore] 7 | 8 | from rigging.generator.base import ( 9 | GeneratedMessage, 10 | GeneratedText, 11 | GenerateParams, 12 | Generator, 13 | trace_messages, 14 | trace_str, 15 | ) 16 | 17 | if t.TYPE_CHECKING: 18 | from rigging.message import Message 19 | 20 | # Any batch over this size will trigger a dedicated 21 | # cache warmup step 22 | CACHE_TRIGGER = 8 23 | 24 | 25 | class VLLMGenerator(Generator): 26 | """ 27 | Generator backed by the vLLM library for local model loading. 28 | 29 | Find more information about supported models and formats [in their docs.](https://docs.vllm.ai/en/latest/index.html) 30 | 31 | Warning: 32 | The use of VLLM requires the `vllm` package to be installed directly or by 33 | installing rigging as `rigging[all]`. 34 | 35 | Note: 36 | This generator doesn't leverage any async capabilities. 37 | 38 | Note: 39 | The model load into memory will occur lazily when the first generation is requested. 40 | If you'd want to force this to happen earlier, you can use the 41 | [`.load()`][rigging.generator.Generator.load] method. 42 | 43 | To unload, call [`.unload()`][rigging.generator.Generator.unload]. 
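Example:
    A minimal sketch, assuming the `vllm!` provider prefix is registered
    with rigging (the model name here is just a placeholder):

    ```
    import rigging as rg

    generator = rg.get_generator("vllm!microsoft/Phi-3-mini-4k-instruct")
    chat = await generator.chat("Say hello!").run()
    ```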
44 | """ 45 | 46 | dtype: str = "auto" 47 | """Tensor dtype passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)""" 48 | quantization: str | None = None 49 | """Quantiziation passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)""" 50 | gpu_memory_utilization: float = 0.9 51 | """Memory utilization passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)""" 52 | enforce_eager: bool = False 53 | """Eager enforcement passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)""" 54 | trust_remote_code: bool = False 55 | """Trust remote code passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)""" 56 | 57 | # TODO: We should look at leveraging the AsyncLLMEngine or an 58 | # async alternative to the LLM class to allow for async generation 59 | 60 | _llm: vllm.LLM | None = None 61 | 62 | @property 63 | def llm(self) -> vllm.LLM: 64 | """The underlying [`vLLM model`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) instance.""" 65 | # Lazy initialization 66 | if self._llm is None: 67 | self._llm = vllm.LLM( 68 | self.model, 69 | dtype=self.dtype, 70 | quantization=self.quantization, 71 | gpu_memory_utilization=self.gpu_memory_utilization, 72 | enforce_eager=self.enforce_eager, 73 | trust_remote_code=self.trust_remote_code, 74 | ) 75 | return self._llm 76 | 77 | @classmethod 78 | def from_obj( 79 | cls, 80 | model: str, 81 | llm: vllm.LLM, 82 | *, 83 | params: GenerateParams | None = None, 84 | ) -> "VLLMGenerator": 85 | """Create a generator from an existing vLLM instance. 86 | 87 | Args: 88 | llm: The vLLM instance to create the generator from. 89 | 90 | Returns: 91 | The VLLMGenerator instance. 92 | """ 93 | generator = cls(model=model, params=params or GenerateParams(), api_key=None) 94 | generator._llm = llm 95 | return generator 96 | 97 | def load(self) -> "VLLMGenerator": 98 | _ = self.llm 99 | return self 100 | 101 | def unload(self) -> "VLLMGenerator": 102 | del self._llm 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | return self 106 | 107 | def _generate( 108 | self, 109 | texts: list[str], 110 | params: t.Sequence[GenerateParams], 111 | ) -> list[GeneratedText]: 112 | sampling_params_args = list( 113 | inspect.signature(vllm.SamplingParams.__init__).parameters.keys(), 114 | ) 115 | sampling_params = ( 116 | [ 117 | vllm.SamplingParams( 118 | **{ 119 | k: v 120 | for k, v in self.params.merge_with(p).to_dict().items() 121 | if k in sampling_params_args 122 | }, 123 | ) 124 | for p in params 125 | ] 126 | if params 127 | else None 128 | ) 129 | 130 | # Do a cache warmup step if we have a lot of texts 131 | if len(texts) > CACHE_TRIGGER: 132 | self.llm.generate( 133 | prompts=min(texts, key=len), 134 | use_tqdm=False, 135 | ) 136 | 137 | outputs = self.llm.generate( 138 | prompts=texts, 139 | sampling_params=sampling_params, 140 | use_tqdm=False, 141 | ) 142 | return [ 143 | GeneratedText( 144 | text=o.outputs[-1].text, 145 | stop_reason=o.outputs[-1].finish_reason or "unknown", 146 | extra={ 147 | "request_id": o.request_id, 148 | "metrics": o.metrics, 149 | "stop_token": o.outputs[-1].stop_reason, 150 | }, 151 | ) 152 | for o in outputs 153 | ] 154 | 155 | async def generate_messages( 156 | self, 157 | messages: t.Sequence[t.Sequence["Message"]], 158 | params: t.Sequence[GenerateParams], 159 | ) -> t.Sequence[GeneratedMessage]: 160 | message_dicts = [[m.to_openai() for m in _messages] for _messages in messages] 161 | tokenizer = 
self.llm.get_tokenizer() 162 | if not hasattr(tokenizer, "apply_chat_template"): 163 | raise RuntimeError( 164 | "The tokenizer does not support the apply_chat_template method.", 165 | ) 166 | 167 | texts = tokenizer.apply_chat_template( 168 | message_dicts, 169 | add_generation_prompt=True, 170 | tokenize=False, 171 | ) 172 | generated_texts = self._generate(t.cast("list[str]", texts), params=params) 173 | generated = [g.to_generated_message() for g in generated_texts] 174 | 175 | for i, (in_messages, out_message) in enumerate(zip(messages, generated, strict=False)): 176 | trace_messages(in_messages, f"Messages {i + 1}/{len(in_messages)}") 177 | trace_messages([out_message], f"Response {i + 1}/{len(in_messages)}") 178 | 179 | return generated 180 | 181 | async def generate_texts( 182 | self, 183 | texts: t.Sequence[str], 184 | params: t.Sequence[GenerateParams], 185 | ) -> t.Sequence[GeneratedText]: 186 | generated = self._generate(list(texts), params=params) 187 | 188 | for i, (text, response) in enumerate(zip(texts, generated, strict=False)): 189 | trace_str(text, f"Text {i + 1}/{len(texts)}") 190 | trace_str(response, f"Generated {i + 1}/{len(texts)}") 191 | 192 | return generated 193 | -------------------------------------------------------------------------------- /tests/test_watchers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing as t 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | import rigging as rg 8 | from rigging.chat import Chat 9 | from rigging.message import Message 10 | 11 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001, FBT001, FBT002, N803 12 | 13 | 14 | @pytest.fixture 15 | def sample_chats() -> list[Chat]: 16 | chat1 = Chat( 17 | messages=[ 18 | Message(role="user", content="Hello"), 19 | Message(role="assistant", content="Hi there!"), 20 | ], 21 | ) 22 | chat2 = Chat( 23 | messages=[ 24 | Message(role="user", content="How are you?"), 25 | Message(role="assistant", content="I'm doing well!"), 26 | ], 27 | ) 28 | return [chat1, chat2] 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_write_chats_to_jsonl(tmp_path: Path, sample_chats: list[Chat]) -> None: 33 | output_file = tmp_path / "chats.jsonl" 34 | watcher = rg.watchers.write_chats_to_jsonl(output_file) 35 | 36 | await watcher(sample_chats) 37 | 38 | assert output_file.exists() 39 | 40 | # read and verify contents 41 | with output_file.open() as f: 42 | lines = f.readlines() 43 | assert len(lines) == 2 44 | 45 | # verify each chat was written correctly 46 | for i, line in enumerate(lines): 47 | saved_chat = json.loads(line) 48 | original_chat = sample_chats[i] 49 | for k, message in enumerate(saved_chat["messages"]): 50 | assert message["role"] == original_chat.messages[k].role 51 | assert message["content"] == [ 52 | {"text": original_chat.messages[k].content, "type": "text"}, 53 | ] 54 | 55 | 56 | @pytest.mark.asyncio 57 | async def test_write_chats_to_jsonl_append(tmp_path: Path, sample_chats: list[Chat]) -> None: 58 | output_file = tmp_path / "chats.jsonl" 59 | watcher = rg.watchers.write_chats_to_jsonl(output_file) 60 | 61 | # write first batch 62 | await watcher(sample_chats[:1]) 63 | 64 | # write second batch 65 | await watcher(sample_chats[1:]) 66 | 67 | with output_file.open() as f: 68 | lines = f.readlines() 69 | assert len(lines) == 2 70 | 71 | 72 | @pytest.mark.asyncio 73 | async def test_write_chats_to_jsonl_replace(tmp_path: Path, sample_chats: list[Chat]) -> None: 74 | output_file = tmp_path / "chats.jsonl" 
75 | 76 | # write initial content 77 | watcher1 = rg.watchers.write_chats_to_jsonl(output_file) 78 | await watcher1(sample_chats) 79 | 80 | # create new watcher with replace=True 81 | watcher2 = rg.watchers.write_chats_to_jsonl(output_file, replace=True) 82 | 83 | # write only one chat - should replace previous content 84 | await watcher2(sample_chats[:1]) 85 | 86 | with output_file.open() as f: 87 | lines = f.readlines() 88 | assert len(lines) == 1 89 | saved_chat = json.loads(lines[0]) 90 | original_chat = sample_chats[0] 91 | for k, message in enumerate(saved_chat["messages"]): 92 | assert message["role"] == original_chat.messages[k].role 93 | assert message["content"] == [ 94 | {"text": original_chat.messages[k].content, "type": "text"}, 95 | ] 96 | 97 | # write another chat - should append since already replaced once 98 | await watcher2(sample_chats[1:2]) 99 | 100 | with output_file.open() as f: 101 | lines = f.readlines() 102 | assert len(lines) == 2 103 | 104 | # verify both chats were written correctly 105 | for i, line in enumerate(lines): 106 | saved_chat = json.loads(line) 107 | original_chat = sample_chats[i] 108 | for j, message in enumerate(saved_chat["messages"]): 109 | assert message["role"] == original_chat.messages[j].role 110 | assert message["content"] == [ 111 | {"text": original_chat.messages[j].content, "type": "text"}, 112 | ] 113 | 114 | 115 | class MockS3Client: 116 | class exceptions: # noqa: N801 117 | class ClientError(Exception): 118 | def __init__(self, code: str): 119 | self.response = {"Error": {"Code": code}} 120 | 121 | class Body: 122 | def __init__(self, content: str): 123 | self.content = content 124 | 125 | def read(self) -> bytes: 126 | return self.content.encode() 127 | 128 | def __init__(self) -> None: 129 | self.buckets: dict[str, t.Any] = {"test-bucket": {}} 130 | 131 | def head_object(self, Bucket: str, Key: str) -> t.Any: 132 | if Bucket not in self.buckets: 133 | raise self.exceptions.ClientError("404") 134 | if Key not in self.buckets[Bucket]: 135 | raise self.exceptions.ClientError("404") 136 | return self.buckets[Bucket][Key] 137 | 138 | def get_object(self, Bucket: str, Key: str) -> t.Any: 139 | if Bucket not in self.buckets: 140 | raise self.exceptions.ClientError("404") 141 | if Key not in self.buckets[Bucket]: 142 | raise self.exceptions.ClientError("404") 143 | return {"Body": MockS3Client.Body(self.buckets[Bucket][Key])} 144 | 145 | def delete_object(self, Bucket: str, Key: str) -> None: 146 | if Bucket not in self.buckets: 147 | raise self.exceptions.ClientError("404") 148 | if Key not in self.buckets[Bucket]: 149 | raise self.exceptions.ClientError("404") 150 | del self.buckets[Bucket][Key] 151 | 152 | def put_object(self, Bucket: str, Key: str, Body: str) -> None: 153 | self.buckets[Bucket][Key] = Body 154 | 155 | 156 | @pytest.mark.asyncio 157 | async def test_write_chats_to_s3(sample_chats: list[Chat]) -> None: 158 | s3_mock_client = MockS3Client() 159 | 160 | bucket = "test-bucket" 161 | key = "test/chats.jsonl" 162 | 163 | expected_content = "" 164 | for chat in sample_chats[:1]: 165 | expected_content += chat.model_dump_json() + "\n" 166 | 167 | watcher = rg.watchers.write_chats_to_s3(s3_mock_client, bucket, key) # type: ignore [arg-type] 168 | 169 | # write first batch 170 | await watcher(sample_chats[:1]) 171 | 172 | got = s3_mock_client.get_object(Bucket=bucket, Key=key) 173 | assert got["Body"].read() == expected_content.encode() 174 | 175 | # write second batch 176 | await watcher(sample_chats[1:]) 177 | 178 | 
expected_content = "" 179 | for chat in sample_chats: 180 | expected_content += chat.model_dump_json() + "\n" 181 | 182 | got = s3_mock_client.get_object(Bucket=bucket, Key=key) 183 | assert got["Body"].read() == expected_content.encode() 184 | 185 | # create a new watcher with replace=True 186 | watcher = rg.watchers.write_chats_to_s3(s3_mock_client, bucket, key, replace=True) # type: ignore [arg-type] 187 | 188 | # write a single chat 189 | await watcher(sample_chats[:1]) 190 | 191 | expected_content = "" 192 | for chat in sample_chats[:1]: 193 | expected_content += chat.model_dump_json() + "\n" 194 | 195 | # verify it's been replaced 196 | got = s3_mock_client.get_object(Bucket=bucket, Key=key) 197 | assert got["Body"].read() == expected_content.encode() 198 | 199 | # write second batch 200 | await watcher(sample_chats[1:]) 201 | 202 | expected_content = "" 203 | for chat in sample_chats: 204 | expected_content += chat.model_dump_json() + "\n" 205 | 206 | # verify replace happens only once 207 | got = s3_mock_client.get_object(Bucket=bucket, Key=key) 208 | assert got["Body"].read() == expected_content.encode() 209 | -------------------------------------------------------------------------------- /docs/topics/iterating-and-batching.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Iterating and Batching" 3 | description: "Iterating over messages, params, and generators, as well as batching of requests." 4 | public: true 5 | --- 6 | 7 | Rigging has good support for iterating over messages, parameters, and generators, as well as large batching of requests. How efficiently these mechanisms operates is dependent on the underlying generator that's being used, but Rigging has been developed with scale in mind. 8 | 9 | ## Multiple Generations 10 | 11 | The `run_many` functions let you scale out generation N times with the same inputs: 12 | 13 | - `ChatPipeline.run_many()` 14 | - `CompletionPipeline.run_many()` 15 | - `Prompt.run_many()` 16 | 17 | <CodeGroup> 18 | ```python Run Many Code 19 | import rigging as rg 20 | 21 | async def check_animal(chats: list[rg.Chat]) -> list[rg.Chat]: 22 | return [ 23 | await chat.continue_(f"Why did you pick that animal?").meta(questioned=True).run() 24 | if any(a in chat.last.content.lower() for a in ["cat", "dog", "cow", "mouse"]) 25 | else chat 26 | for chat in chats 27 | ] 28 | 29 | chats = ( 30 | await 31 | rg.get_generator("gpt-3.5-turbo") 32 | .chat("Tell me a joke about an animal.") 33 | .map(check_animal) 34 | .run_many(3) 35 | ) 36 | 37 | for i, chat in enumerate(chats): 38 | questioned = chat.metadata.get("questioned", False) 39 | print(f"--- Chat {i+1} (?: {questioned}) ---") 40 | print(chat.conversation) 41 | print() 42 | ``` 43 | 44 | ```text Outputs 45 | --- Chat 1 (?: False) --- 46 | [user]: Tell me a joke about an animal. 47 | 48 | [assistant]: Why did the spider go to the computer? 49 | 50 | To check his website! 51 | 52 | --- Chat 2 (?: False) --- 53 | [user]: Tell me a joke about an animal. 54 | 55 | [assistant]: Why did the chicken join a band? Because it had the drumsticks! 56 | 57 | --- Chat 3 (?: True) --- 58 | [user]: Tell me a joke about an animal. 59 | 60 | [assistant]: Why don't elephants use computers? 61 | 62 | Because they're afraid of the mouse! 63 | 64 | [user]: Why did you pick that animal? 65 | 66 | [assistant]: I chose an elephant because they are known for their intelligence and gentle nature, making them a popular subject for jokes and humorous anecdotes. 
69 | 
70 | ## Batching Inputs
71 | 
72 | The `run_batch` functions let you batch across a set of inputs:
73 | 
74 | - `ChatPipeline.run_batch()`
75 | - `CompletionPipeline.run_batch()`
76 | 
77 | As processing proceeds with things like `.then()` or `.map()`, the chats will resolve individually and collapse into the final results.
78 | 
79 | <CodeGroup>
80 | ```python Batching Inputs
81 | import rigging as rg
82 | from rigging.model import CommaDelimitedAnswer
83 | 
84 | pipeline = (
85 |     rg.get_generator('gpt-4-turbo')
86 |     .chat({
87 |         "role": "system",
88 |         "content": f"Always respond with {CommaDelimitedAnswer.xml_tags()} tags."}
89 |     )
90 |     .until_parsed_as(CommaDelimitedAnswer, attempt_recovery=True)
91 | )
92 | 
93 | many = [f"Give me 3 famous {thing}" for thing in ["authors", "painters", "musicians", "hackers"]]
94 | 
95 | chats = await pipeline.run_batch(many, on_failed='skip')
96 | 
97 | for i, chat in enumerate(chats):
98 |     print(f"--- Chat {i+1} ({len(chat)}) ---")
99 |     print(chat.last.parse(CommaDelimitedAnswer).items)
100 |     print()
101 | ```
102 | 
103 | ```text Output
104 | --- Chat 1 (2) ---
105 | ['Leonardo da Vinci', 'Vincent van Gogh', 'Pablo Picasso']
106 | 
107 | --- Chat 2 (2) ---
108 | ['Michael Jackson', 'Beyonce', 'The Beatles']
109 | ```
110 | </CodeGroup>
111 | 
112 | <Tip>
113 | **Skipping failed results**
114 | 
115 | Passing `on_failed='skip'` to `.run_batch`, or configuring a pipeline with `.catch(..., on_failed='skip')` will cause the function to ignore any parsing errors like `ExhaustedMaxRoundsError` and only return successful chats.
116 | </Tip>
117 | 
118 | ## Batching Parameters
119 | 
120 | In addition to batching against input messages or strings, you can fix a single input and build a batch across a set of generation parameters. If either argument to `.run_batch` is a single item, it is scaled out to match the length of the other.
121 | 
122 | <CodeGroup>
123 | ```python Batching
124 | import rigging as rg
125 | 
126 | pipeline = rg.get_generator("gpt-3.5-turbo").chat()
127 | 
128 | chats = await pipeline.run_batch(
129 |     ["Tell me a short fact about a Japanese city."],
130 |     [rg.GenerateParams(temperature=t) for t in [0.6, 0.9, 1.2, 1.5, 1.8]]
131 | )
132 | 
133 | for i, chat in enumerate(chats):
134 |     print(f"--- Chat {i+1} ---")
135 |     print(chat.generator_id)
136 |     print()
137 |     print(chat.conversation)
138 |     print()
139 | ```
140 | 
141 | ```text Output
142 | --- Chat 1 ---
143 | litellm!gpt-3.5-turbo,temperature=0.6
144 | 
145 | [assistant]: Tokyo, the capital city of Japan, is the most populous
146 | metropolitan area in the world, with over 37 million residents.
147 | 
148 | --- Chat 2 ---
149 | litellm!gpt-3.5-turbo,temperature=0.9
150 | 
151 | [assistant]: Tokyo is the largest metropolitan area in the world,
152 | with a population of over 37 million people.
153 | 
154 | --- Chat 3 ---
155 | litellm!gpt-3.5-turbo,temperature=1.2
156 | 
157 | [assistant]: Kyoto, a city in Japan known for its historic temples
158 | and gardens, was once the capital of Japan for over 1,000 years from
159 | 794 until the capital was moved to Tokyo in 1869.
160 | 161 | --- Chat 4 --- 162 | litellm!gpt-3.5-turbo,temperature=1.5 163 | 164 | [assistant]: Nagoya, Japan is known for being one of the leading 165 | manufacturing and industrial regions in the country, with a strong 166 | automotive presence including major factories for Toyota, Honda, and Mitsubishi. 167 | 168 | --- Chat 5 --- 169 | litellm!gpt-3.5-turbo,temperature=1.8 170 | 171 | [assistant]: Sendai is the largest city in the Tohoku region of 172 | Japan and is known for its incredible natural scenery, such as the 173 | nearby Sendai Bay and Zuihoden mausoleum. 174 | ``` 175 | </CodeGroup> 176 | 177 | ## Iterating over Models 178 | 179 | The `run_over` functions let you execute generation over a set of generators: 180 | 181 | - `ChatPipeline.run_over()` 182 | - `CompletionPipeline.run_over()` 183 | - `Prompt.run_over()` 184 | 185 | Generators can be passed as string identifiers or full instances of `Generator`. By default the original generator associated with the `ChatPipeline` is included in the iteration, configurable with the `include_original` parameter. 186 | 187 | Much like the `run_many` and `run_batch` functions, you can control the handling of failures with the `on_failed` parameter. 188 | 189 | <CodeGroup> 190 | ```python Run Over 191 | import rigging as rg 192 | from rigging.model import Answer 193 | 194 | QUESTION = "What is the capital of France?" 195 | ANSWER = "paris" 196 | 197 | async def score_output(chats: list[rg.Chat]) -> list[rg.Chat]: 198 | return [ 199 | chat.meta(correct=chat.last.parse(Answer).content.lower() == ANSWER) 200 | for chat in chats 201 | ] 202 | 203 | chats = ( 204 | await 205 | rg.get_generator("gpt-3.5-turbo") 206 | .chat([ 207 | {"role": "system", "content": f"Always respond in one word between {Answer.xml_tags()} tags."}, 208 | {"role": "user", "content": QUESTION} 209 | ]) 210 | .until_parsed_as(Answer, max_rounds=3) 211 | .map(score_output) 212 | .run_over("gpt-4-turbo", "claude-3-haiku-20240307,temperature=0.5", "claude-3-sonnet-20240229") 213 | ) 214 | 215 | for chat in chats: 216 | print("Model: ", chat.generator.model) 217 | print("Msg: ", chat.last.content) 218 | print("Meta: ", chat.metadata) 219 | print() 220 | ``` 221 | 222 | ```text Outputs 223 | Model: gpt-4-turbo 224 | Msg: <answer>Paris</answer> 225 | Meta: {'correct': True} 226 | 227 | Model: claude-3-haiku-20240307 228 | Msg: <answer>Paris</answer> 229 | Meta: {'correct': True} 230 | 231 | Model: claude-3-sonnet-20240229 232 | Msg: <answer>Paris</answer> 233 | Meta: {'correct': True} 234 | 235 | Model: openai/gpt-3.5-turbo 236 | Msg: <answer>Paris</answer> 237 | Meta: {'correct': True} 238 | ``` 239 | </CodeGroup> 240 | -------------------------------------------------------------------------------- /examples/tokenize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8f675f31", 6 | "metadata": {}, 7 | "source": [ 8 | "### Load a math dataset\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "f679b928", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import typing as t\n", 19 | "\n", 20 | "import datasets\n", 21 | "import contextlib\n", 22 | "import rigging as rg\n", 23 | "\n", 24 | "def is_basic_question(sample: dict[str, t.Any]) -> bool:\n", 25 | " with contextlib.suppress(ValueError):\n", 26 | " float(sample[\"answer\"])\n", 27 | " return True\n", 28 | " return False\n", 29 | " \n", 30 | "dataset = [\n", 31 | " 
rg.Message(\"user\", sample[\"problem\"], metadata={**sample})\n", 32 | " for sample in datasets.load_dataset(\"HuggingFaceH4/MATH-500\", split=\"test\").filter(is_basic_question)\n", 33 | "]\n", 34 | "\n", 35 | "print(f\"Loaded {len(dataset)} basic questions from MATH-500.\")" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "33fe0a3f", 41 | "metadata": {}, 42 | "source": [ 43 | "### Define tools\n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "28abc318", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "import rigging as rg\n", 54 | "from io import StringIO\n", 55 | "from contextlib import redirect_stdout\n", 56 | "\n", 57 | "# Dedicated calculator tool\n", 58 | "\n", 59 | "class Calculator:\n", 60 | " @rg.tool_method(catch=True)\n", 61 | " def add(self, x: float, y: float) -> float:\n", 62 | " \"\"\"Adds two numbers.\"\"\"\n", 63 | " return x + y\n", 64 | "\n", 65 | " @rg.tool_method(catch=True)\n", 66 | " def subtract(self, x: float, y: float) -> float:\n", 67 | " \"\"\"Subtracts the second number from the first.\"\"\"\n", 68 | " return x - y\n", 69 | "\n", 70 | " @rg.tool_method(catch=True)\n", 71 | " def multiply(self, x: float, y: float) -> float:\n", 72 | " \"\"\"Multiplies two numbers.\"\"\"\n", 73 | " return x * y\n", 74 | "\n", 75 | " @rg.tool_method(catch=True)\n", 76 | " def divide(self, x: float, y: float) -> float:\n", 77 | " \"\"\"Divides the first number by the second.\"\"\"\n", 78 | " if y == 0:\n", 79 | " raise ValueError(\"Cannot divide by zero.\")\n", 80 | " return x / y\n", 81 | "\n", 82 | "calculator = Calculator()\n", 83 | "\n", 84 | "# Python execution tool\n", 85 | "\n", 86 | "@rg.tool(catch=True)\n", 87 | "def execute_python(code: str) -> str:\n", 88 | " \"\"\"\n", 89 | " Executes Python code and returns stdout from the execution.\n", 90 | " \n", 91 | " - Use print() to output results.\n", 92 | " - Be thoughtful with indentation and syntax.\n", 93 | " \"\"\"\n", 94 | " output = StringIO()\n", 95 | " with redirect_stdout(output):\n", 96 | " exec(code)\n", 97 | " return output.getvalue()\n", 98 | " " 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "d1ea856c", 104 | "metadata": {}, 105 | "source": [ 106 | "### Define our agents\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "2962c1ac", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "import contextlib\n", 117 | "\n", 118 | "# Define a model for parsing the agent answer\n", 119 | "\n", 120 | "class Answer(rg.Model):\n", 121 | " value: str\n", 122 | "\n", 123 | "# Callback to inspect results after the agent execution\n", 124 | "\n", 125 | "async def inspect_results(chat: rg.Chat) -> rg.Chat:\n", 126 | " used_tools = any(msg.tool_calls for msg in chat.messages if msg.role == \"assistant\")\n", 127 | " answer = chat.message_metadata[\"answer\"]\n", 128 | " agent_answer = chat.last.try_parse(Answer)\n", 129 | "\n", 130 | " correct = False\n", 131 | " with contextlib.suppress(ValueError):\n", 132 | " correct = agent_answer is not None and float(agent_answer.value) == float(answer)\n", 133 | " \n", 134 | " return chat.meta(\n", 135 | " correct=correct,\n", 136 | " gave_answer=answer is not None,\n", 137 | " agent_answer=agent_answer.value if agent_answer else None,\n", 138 | " used_tools=used_tools,\n", 139 | " true_answer=answer\n", 140 | " )\n", 141 | "\n", 142 | "\n", 143 | "# Build our core pipeline\n", 144 | "\n", 145 | "pipeline = (\n", 146 | " 
rg.get_generator(\"groq/meta-llama/llama-4-maverick-17b-128e-instruct\")\n", 147 | " .chat(\"Answer math questions and return basic floats between <answer></answer> tags.\")\n", 148 | " .until_parsed_as(Answer)\n", 149 | " .then(inspect_results)\n", 150 | " .catch(Exception, on_failed=\"include\")\n", 151 | ")\n", 152 | "\n", 153 | "# Define 3 agents with different capabilities\n", 154 | "\n", 155 | "agent_no_tools = pipeline.clone().meta(variant=\"no_tools\")\n", 156 | "agent_with_calculator = pipeline.clone().using(calculator, mode=\"xml\").meta(variant=\"with_calculator\")\n", 157 | "agent_with_python = pipeline.clone().using(execute_python, mode=\"xml\").meta(variant=\"with_python\")\n" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "e0dae7e4", 163 | "metadata": {}, 164 | "source": [ 165 | "### Run 10 samples through our 3 agents\n" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "39bf862f", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "import random\n", 176 | "\n", 177 | "samples = random.sample(dataset, 10)\n", 178 | "\n", 179 | "chats_no_tools = await agent_no_tools.run_batch(samples)\n", 180 | "chats_with_calculator = await agent_with_calculator.run_batch(samples)\n", 181 | "chats_with_python = await agent_with_python.run_batch(samples)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "id": "b8305557", 187 | "metadata": {}, 188 | "source": [ 189 | "### Calculate success rates\n" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "436079a5", 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "def get_success_rate(chats: t.List[rg.Chat]) -> float:\n", 200 | " return sum(chat.metadata.get(\"correct\", False) for chat in chats) / len(chats)\n", 201 | "\n", 202 | "no_tools_success_rate = get_success_rate(chats_no_tools)\n", 203 | "with_calculator_success_rate = get_success_rate(chats_with_calculator)\n", 204 | "with_python_success_rate = get_success_rate(chats_with_python)\n", 205 | "\n", 206 | "print(f\"Success rate without tools: {no_tools_success_rate:.2%}\")\n", 207 | "print(f\"Success rate with calculator: {with_calculator_success_rate:.2%}\")\n", 208 | "print(f\"Success rate with Python: {with_python_success_rate:.2%}\")" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "12655a0c", 214 | "metadata": {}, 215 | "source": [ 216 | "### Tokenize the chat with a tokenizer\n" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "id": "3a7ef6dc", 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "tokenized = await chats_with_python.to_tokens('Qwen/Qwen2.5-1.5B-Instruct', transform=\"json-with-tag\")\n", 227 | "\n", 228 | "print(tokenized[0].metadata)\n", 229 | "print(tokenized[0].slices)" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": ".venv", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.10.14" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /.hooks/generate_docs.py: 
-------------------------------------------------------------------------------- 1 | import argparse # noqa: INP001 2 | import re 3 | import typing as t 4 | from pathlib import Path 5 | 6 | from markdown import Markdown # type: ignore [import-untyped] 7 | from markdownify import MarkdownConverter # type: ignore [import-untyped] 8 | from markupsafe import Markup 9 | from mkdocstrings_handlers.python._internal.config import PythonConfig 10 | from mkdocstrings_handlers.python._internal.handler import ( 11 | PythonHandler, 12 | ) 13 | 14 | # ruff: noqa: T201 15 | 16 | 17 | class CustomMarkdownConverter(MarkdownConverter): # type: ignore [misc] 18 | # Strip extra whitespace from code blocks 19 | def convert_pre(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: 20 | return super().convert_pre(el, text.strip(), parent_tags) 21 | 22 | # bold items with doc-section-title in a span class 23 | def convert_span(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: # noqa: ARG002 24 | if "doc-section-title" in el.get("class", []): 25 | return f"**{text.strip()}**" 26 | return text 27 | 28 | # Remove the div wrapper for inline descriptions 29 | def convert_div(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: 30 | if "doc-md-description" in el.get("class", []): 31 | return text.strip() 32 | return super().convert_div(el, text, parent_tags) 33 | 34 | # Map mkdocstrings details classes to Mintlify callouts 35 | def convert_details(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: # noqa: ARG002 36 | classes = el.get("class", []) 37 | 38 | # Handle source code details specially 39 | if "quote" in classes: 40 | summary = el.find("summary") 41 | if summary: 42 | file_path = summary.get_text().replace("Source code in ", "").strip() 43 | content = text[text.find("```") :] 44 | return f'\n<Accordion title="Source code in {file_path}" icon="code">\n{content}\n</Accordion>\n' 45 | 46 | callout_map = { 47 | "note": "Note", 48 | "warning": "Warning", 49 | "info": "Info", 50 | "tip": "Tip", 51 | } 52 | 53 | callout_type = None 54 | for cls in classes: 55 | if cls in callout_map: 56 | callout_type = callout_map[cls] 57 | break 58 | 59 | if not callout_type: 60 | return text 61 | 62 | content = text.strip() 63 | if content.startswith(callout_type): 64 | content = content[len(callout_type) :].strip() 65 | 66 | return f"\n<{callout_type}>\n{content}\n</{callout_type}>\n" 67 | 68 | def convert_table(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: 69 | # Check if this is a highlighttable (source code with line numbers) 70 | if "highlighttable" in el.get("class", []): 71 | code_cells = el.find_all("td", class_="code") 72 | if code_cells: 73 | code = code_cells[0].get_text() 74 | code = code.strip() 75 | code = code.replace("```", "~~~") 76 | return f"\n```python\n{code}\n```\n" 77 | 78 | return super().convert_table(el, text, parent_tags) 79 | 80 | 81 | class AutoDocGenerator: 82 | def __init__(self, source_paths: list[str], theme: str = "material", **options: t.Any) -> None: 83 | self.source_paths = source_paths 84 | self.theme = theme 85 | self.handler = PythonHandler(PythonConfig.from_data(), base_dir=Path.cwd()) 86 | self.options = options 87 | 88 | self.handler._update_env( # noqa: SLF001 89 | Markdown(), 90 | config={"mdx": ["toc"]}, 91 | ) 92 | 93 | md = Markdown(extensions=["fenced_code"]) 94 | 95 | def simple_convert_markdown( 96 | text: str, 97 | heading_level: int, 98 | html_id: str = "", 99 | **kwargs: t.Any, 100 | ) -> t.Any: 101 | return Markup(md.convert(text) if text 
else "") # noqa: S704 # nosec 102 | 103 | self.handler.env.filters["convert_markdown"] = simple_convert_markdown 104 | 105 | def generate_docs_for_module( 106 | self, 107 | module_path: str, 108 | ) -> str: 109 | options = self.handler.get_options( 110 | { 111 | "docstring_section_style": "list", 112 | "merge_init_into_class": True, 113 | "show_signature_annotations": True, 114 | "separate_signature": True, 115 | "show_source": True, 116 | "show_labels": False, 117 | "show_bases": False, 118 | **self.options, 119 | }, 120 | ) 121 | 122 | module_data = self.handler.collect(module_path, options) 123 | html = self.handler.render(module_data, options) 124 | 125 | return str( 126 | CustomMarkdownConverter( 127 | code_language="python", 128 | ).convert(html), 129 | ) 130 | 131 | def process_mdx_file(self, file_path: Path) -> bool: 132 | content = file_path.read_text(encoding="utf-8") 133 | original_content = content 134 | 135 | # Find the header comment block 136 | header_match = re.search( 137 | r"\{\s*/\*\s*((?:::.*?\n?)*)\s*\*/\s*\}", 138 | content, 139 | re.MULTILINE | re.DOTALL, 140 | ) 141 | 142 | if not header_match: 143 | return False 144 | 145 | header = header_match.group(0) 146 | module_lines = header_match.group(1).strip().split("\n") 147 | 148 | # Generate content for each module 149 | markdown_blocks = [] 150 | for line in module_lines: 151 | if line.startswith(":::"): 152 | module_path = line.strip()[3:].strip() 153 | if module_path: 154 | markdown = self.generate_docs_for_module(module_path) 155 | markdown_blocks.append(markdown) 156 | 157 | keep_end = content.find(header) + len(header) 158 | new_content = content[:keep_end] + "\n\n" + "\n".join(markdown_blocks) 159 | 160 | # Write back if changed 161 | if new_content != original_content: 162 | file_path.write_text(new_content, encoding="utf-8") 163 | print(f"[+] Updated: {file_path}") 164 | return True 165 | 166 | return False 167 | 168 | def process_directory(self, directory: Path, pattern: str = "**/*.mdx") -> int: 169 | if not directory.exists(): 170 | print(f"[!] 
Directory does not exist: {directory}")
171 |             return 0
172 | 
173 |         files_processed = 0
174 |         files_modified = 0
175 | 
176 |         for mdx_file in directory.glob(pattern):
177 |             if mdx_file.is_file():
178 |                 files_processed += 1
179 |                 if self.process_mdx_file(mdx_file):
180 |                     files_modified += 1
181 | 
182 |         return files_modified
183 | 
184 | 
185 | def main() -> None:
186 |     """Main entry point for the script."""
187 | 
188 |     parser = argparse.ArgumentParser(description="Generate auto-docs for MDX files")
189 |     parser.add_argument("--directory", help="Directory containing MDX files", default="docs")
190 |     parser.add_argument("--pattern", default="**/*.mdx", help="File pattern to match")
191 |     parser.add_argument(
192 |         "--source-paths",
193 |         nargs="+",
194 |         default=["dreadnode"],
195 |         help="Python source paths for module discovery",
196 |     )
197 |     parser.add_argument(
198 |         "--show-if-no-docstring",
199 |         # use a store_true flag: argparse `type=bool` treats any non-empty string as True
200 |         action="store_true",
201 |         help="Show module/class/function even if no docstring is present",
202 |     )
203 |     parser.add_argument("--theme", default="material", help="Theme to use for rendering")
204 | 
205 |     args = parser.parse_args()
206 | 
207 |     # Create generator
208 |     generator = AutoDocGenerator(
209 |         source_paths=args.source_paths,
210 |         theme=args.theme,
211 |         show_if_no_docstring=args.show_if_no_docstring,
212 |     )
213 | 
214 |     # Process directory
215 |     directory = Path(args.directory)
216 |     modified_count = generator.process_directory(directory, args.pattern)
217 | 
218 |     print(f"\n[+] Auto-doc generation complete. {modified_count} files were updated.")
219 | 
220 | 
221 | if __name__ == "__main__":
222 |     main()
223 | 
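# Note: the CLI above is the primary interface, but `AutoDocGenerator` can also
# be driven directly from Python. A minimal sketch (the argument values below
# are illustrative, not project defaults):
#
#   from pathlib import Path
#
#   generator = AutoDocGenerator(source_paths=["rigging"])
#   modified = generator.process_directory(Path("docs"), "**/*.mdx")
#   print(f"{modified} files updated")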
--------------------------------------------------------------------------------
/rigging/generator/transformers_.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import typing as t
3 | 
4 | import torch  # type: ignore [import-not-found, import-untyped, unused-ignore]
5 | import transformers  # type: ignore [import-not-found, import-untyped, unused-ignore]
6 | from transformers import (  # type: ignore [import-not-found, attr-defined, unused-ignore]
7 |     AutoModelForCausalLM,
8 |     AutoTokenizer,
9 |     TextGenerationPipeline,
10 | )
11 | 
12 | from rigging.generator.base import (
13 |     GeneratedMessage,
14 |     GeneratedText,
15 |     GenerateParams,
16 |     Generator,
17 |     trace_messages,
18 |     trace_str,
19 | )
20 | 
21 | if t.TYPE_CHECKING:
22 |     from rigging.message import Message
23 | 
24 | DEFAULT_MAX_TOKENS = 1024
25 | """Lifting the default max tokens from transformers"""
26 | 
27 | 
28 | class TransformersGenerator(Generator):
29 |     """
30 |     Generator backed by the Transformers library for local model loading.
31 | 
32 |     Warning:
33 |         The use of Transformers requires the `transformers` package to be installed directly or by
34 |         installing rigging as `rigging[all]`.
35 | 
36 |     Warning:
37 |         The `transformers` library is expansive with many different models, tokenizers,
38 |         options, constructors, etc. We do our best to implement a consistent interface,
39 |         but there may be limitations. Where needed, use
40 |         [`.from_obj()`][rigging.generator.transformers_.TransformersGenerator.from_obj].
41 | 
42 |     Note:
43 |         This generator doesn't leverage any async capabilities.
44 | 
45 |     Note:
46 |         The model will be loaded into memory lazily when the first generation is requested.
47 |         If you want to force this to happen earlier, you can use the
48 |         [`.load()`][rigging.generator.Generator.load] method.
49 | 
50 |         To unload, call [`.unload()`][rigging.generator.Generator.unload].
51 |     """
52 | 
53 |     torch_dtype: str = "auto"
54 |     """Torch dtype passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
55 |     device_map: str = "auto"
56 |     """Device map passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
57 |     trust_remote_code: bool = False
58 |     """Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
59 |     load_in_8bit: bool = False
60 |     """Load in 8 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
61 |     load_in_4bit: bool = False
62 |     """Load in 4 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
63 | 
64 |     _llm: AutoModelForCausalLM | None = None
65 |     _tokenizer: AutoTokenizer | None = None
66 |     _pipeline: TextGenerationPipeline | None = None
67 | 
68 |     @property
69 |     def llm(self) -> AutoModelForCausalLM:
70 |         """The underlying `AutoModelForCausalLM` instance."""
71 |         # Lazy initialization
72 |         if self._llm is None:
73 |             llm_kwargs = self.model_dump(
74 |                 exclude_unset=True,
75 |                 include={
76 |                     "torch_dtype",
77 |                     "device_map",
78 |                     "trust_remote_code",
79 |                     "load_in_8bit",
80 |                     "load_in_4bit",
81 |                 },
82 |             )
83 |             self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs)  # type: ignore [no-untyped-call, assignment, unused-ignore] # nosec
84 |         if self._llm is None:
85 |             raise ValueError(f"Failed to load model '{self.model}'")
86 |         return self._llm
87 | 
88 |     @property
89 |     def tokenizer(self) -> AutoTokenizer:
90 |         """The underlying `AutoTokenizer` instance."""
91 |         if self._tokenizer is None:
92 |             self._tokenizer = AutoTokenizer.from_pretrained(self.model)  # type: ignore [no-untyped-call, unused-ignore] # nosec
93 |         return self._tokenizer
94 | 
95 |     @property
96 |     def pipeline(self) -> TextGenerationPipeline:
97 |         """The underlying `TextGenerationPipeline` instance."""
98 |         if self._pipeline is None:
99 |             self._pipeline = transformers.pipeline(  # type: ignore [attr-defined, call-overload, assignment, unused-ignore]
100 |                 "text-generation",
101 |                 return_full_text=False,
102 |                 model=self.llm,  # type: ignore [arg-type, unused-ignore]
103 |                 tokenizer=self.tokenizer,  # type: ignore [arg-type, unused-ignore]
104 |             )
105 |         return self._pipeline  # type: ignore [return-value, unused-ignore]
106 | 
107 |     @classmethod
108 |     def from_obj(
109 |         cls,
110 |         model: t.Any,
111 |         tokenizer: AutoTokenizer,
112 |         *,
113 |         pipeline: TextGenerationPipeline | None = None,
114 |         params: GenerateParams | None = None,
115 |     ) -> "TransformersGenerator":
116 |         """
117 |         Create a new instance of TransformersGenerator from an already loaded model and tokenizer.
118 | 
119 |         Args:
120 |             model: The loaded model for text generation.
121 |             tokenizer: The tokenizer associated with the model.
122 |             pipeline: The text generation pipeline. Defaults to None.
123 |             params: Generation parameters. Defaults to None.
124 | 
125 |         Returns:
126 |             The TransformersGenerator instance.
127 |         """
128 |         instance = cls(model=model, params=params or GenerateParams(), api_key=None)
129 |         instance._llm = model
130 |         instance._tokenizer = tokenizer
131 |         instance._pipeline = pipeline
132 |         return instance
133 | 
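    # Usage sketch for `from_obj` (assumes `transformers` is installed; the
    # checkpoint name below is purely illustrative):
    #
    #   model = AutoModelForCausalLM.from_pretrained("gpt2")
    #   tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #   generator = TransformersGenerator.from_obj(model, tokenizer)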
134 |     def load(self) -> "TransformersGenerator":
135 |         _ = self.pipeline
136 |         return self
137 | 
138 |     def unload(self) -> "TransformersGenerator":
139 |         self._pipeline = None  # assign None rather than `del` so the lazy properties can re-load
140 |         self._llm = None
141 |         self._tokenizer = None
142 |         gc.collect()
143 |         torch.cuda.empty_cache()
144 |         return self
145 | 
146 |     def _generate(
147 |         self,
148 |         inputs: t.Sequence[str] | t.Sequence[t.Sequence[dict[str, str]]],
149 |         params: t.Sequence[GenerateParams],
150 |     ) -> list[GeneratedText]:
151 |         param_set = {p.model_dump_json() for p in params}
152 |         if len(param_set) != 1:
153 |             raise ValueError("All GenerateParams must be identical for this generator")
154 | 
155 |         # Generation Args + Fixups
156 |         if self.params.max_tokens is None:
157 |             self.params.max_tokens = DEFAULT_MAX_TOKENS
158 | 
159 |         kwargs = self.params.merge_with(params[0]).to_dict()
160 |         if "max_tokens" in kwargs:
161 |             kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
162 |         if any(k in kwargs for k in ["temperature", "top_k", "top_p"]):
163 |             kwargs["do_sample"] = True
164 | 
165 |         outputs = self.pipeline(inputs, **kwargs)  # type: ignore [call-overload]
166 | 
167 |         # TODO: We do strip() here as it's often needed, but I think
168 |         # we should return and standardize this behavior.
169 |         return [GeneratedText(text=o["generated_text"].strip()) for o in outputs]
170 | 
171 |     async def generate_messages(
172 |         self,
173 |         messages: t.Sequence[t.Sequence["Message"]],
174 |         params: t.Sequence[GenerateParams],
175 |     ) -> t.Sequence[GeneratedMessage]:
176 |         message_dicts = [
177 |             [m.to_openai(compatibility_flags={"content_as_str"}) for m in _messages]
178 |             for _messages in messages
179 |         ]
180 |         outputs = self._generate(message_dicts, params)
181 |         generated = [o.to_generated_message() for o in outputs]
182 | 
183 |         for i, (in_messages, out_message) in enumerate(zip(messages, generated, strict=False)):
184 |             trace_messages(in_messages, f"Messages {i + 1}/{len(messages)}")
185 |             trace_messages([out_message], f"Response {i + 1}/{len(messages)}")
186 | 
187 |         return generated
188 | 
189 |     async def generate_texts(
190 |         self,
191 |         texts: t.Sequence[str],
192 |         params: t.Sequence[GenerateParams],
193 |     ) -> t.Sequence[GeneratedText]:
194 |         generated = self._generate(texts, params)
195 | 
196 |         for i, (text, response) in enumerate(zip(texts, generated, strict=False)):
197 |             trace_str(text, f"Text {i + 1}/{len(texts)}")
198 |             trace_str(response, f"Generated {i + 1}/{len(texts)}")
199 | 
200 |         return generated
201 | 
--------------------------------------------------------------------------------
/docs/api/parsing.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: rigging.parsing
3 | ---
4 | 
5 | {/*
6 | ::: rigging.parsing
7 | */}
8 | 
9 | Parsing helpers for extracting rigging models from text
10 | 
11 | parse
12 | -----
13 | 
14 | ```python
15 | parse(
16 |     text: str, model_type: type[ModelT]
17 | ) -> tuple[ModelT, slice]
18 | ```
19 | 
20 | Parses a single model from text.
21 | 
22 | **Parameters:**
23 | 
24 | * **`text`**
25 |   (`str`)
26 |   –The content to parse.
27 | * **`model_type`**
28 |   (`type[ModelT]`)
29 |   –The type of model to parse.
30 | 
31 | **Returns:**
32 | 
33 | * `tuple[ModelT, slice]`
34 |   –The parsed model.
35 | 36 | **Raises:** 37 | 38 | * `ValueError` 39 | –If no models of the given type are found and `fail_on_missing` is set to `True`. 40 | 41 | <Accordion title="Source code in rigging/parsing.py" icon="code"> 42 | ```python 43 | def parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice]: 44 | """ 45 | Parses a single model from text. 46 | 47 | Args: 48 | text: The content to parse. 49 | model_type: The type of model to parse. 50 | 51 | Returns: 52 | The parsed model. 53 | 54 | Raises: 55 | ValueError: If no models of the given type are found and `fail_on_missing` is set to `True`. 56 | """ 57 | return try_parse_many(text, model_type, fail_on_missing=True)[0] 58 | ``` 59 | 60 | 61 | </Accordion> 62 | 63 | parse\_many 64 | ----------- 65 | 66 | ```python 67 | parse_many( 68 | text: str, *types: type[ModelT] 69 | ) -> list[tuple[ModelT, slice]] 70 | ``` 71 | 72 | Parses multiple models of the specified non-identical types from text. 73 | 74 | **Parameters:** 75 | 76 | * **`text`** 77 | (`str`) 78 | –The content to parse. 79 | * **`*types`** 80 | (`type[ModelT]`, default: 81 | `()` 82 | ) 83 | –The types of models to parse. 84 | 85 | **Returns:** 86 | 87 | * `list[tuple[ModelT, slice]]` 88 | –A list of parsed models. 89 | 90 | **Raises:** 91 | 92 | * `MissingModelError` 93 | –If any of the models are missing. 94 | 95 | <Accordion title="Source code in rigging/parsing.py" icon="code"> 96 | ```python 97 | def parse_many(text: str, *types: type["ModelT"]) -> list[tuple["ModelT", slice]]: 98 | """ 99 | Parses multiple models of the specified non-identical types from text. 100 | 101 | Args: 102 | text: The content to parse. 103 | *types: The types of models to parse. 104 | 105 | Returns: 106 | A list of parsed models. 107 | 108 | Raises: 109 | MissingModelError: If any of the models are missing. 110 | """ 111 | return try_parse_many(text, *types, fail_on_missing=True) 112 | ``` 113 | 114 | 115 | </Accordion> 116 | 117 | parse\_set 118 | ---------- 119 | 120 | ```python 121 | parse_set( 122 | text: str, 123 | model_type: type[ModelT], 124 | *, 125 | minimum: int | None = None, 126 | ) -> list[tuple[ModelT, slice]] 127 | ``` 128 | 129 | Parses a set of models with the specified identical type from text. 130 | 131 | **Parameters:** 132 | 133 | * **`text`** 134 | (`str`) 135 | –The content to parse. 136 | * **`model_type`** 137 | (`type[ModelT]`) 138 | –The type of models to parse. 139 | * **`minimum`** 140 | (`int | None`, default: 141 | `None` 142 | ) 143 | –The minimum number of models required. 144 | 145 | **Returns:** 146 | 147 | * `list[tuple[ModelT, slice]]` 148 | –A list of parsed models. 149 | 150 | **Raises:** 151 | 152 | * `MissingModelError` 153 | –If the minimum number of models is not met. 154 | 155 | <Accordion title="Source code in rigging/parsing.py" icon="code"> 156 | ```python 157 | def parse_set( 158 | text: str, 159 | model_type: type["ModelT"], 160 | *, 161 | minimum: int | None = None, 162 | ) -> list[tuple["ModelT", slice]]: 163 | """ 164 | Parses a set of models with the specified identical type from text. 165 | 166 | Args: 167 | text: The content to parse. 168 | model_type: The type of models to parse. 169 | minimum: The minimum number of models required. 170 | 171 | Returns: 172 | A list of parsed models. 173 | 174 | Raises: 175 | MissingModelError: If the minimum number of models is not met. 
176 |     """
177 |     return try_parse_set(text, model_type, minimum=minimum, fail_on_missing=True)
178 | ```
179 | 
180 | 
181 | </Accordion>
182 | 
183 | try\_parse
184 | ----------
185 | 
186 | ```python
187 | try_parse(
188 |     text: str, model_type: type[ModelT]
189 | ) -> tuple[ModelT, slice] | None
190 | ```
191 | 
192 | Tries to parse a model from text.
193 | 
194 | **Parameters:**
195 | 
196 | * **`text`**
197 |   (`str`)
198 |   –The content to parse.
199 | * **`model_type`**
200 |   (`type[ModelT]`)
201 |   –The type of model to search for.
202 | 
203 | **Returns:**
204 | 
205 | * `tuple[ModelT, slice] | None`
206 |   –The first model that matches the given model type, or None if no match is found.
207 | 
208 | <Accordion title="Source code in rigging/parsing.py" icon="code">
209 | ```python
210 | def try_parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice] | None:
211 |     """
212 |     Tries to parse a model from text.
213 | 
214 |     Args:
215 |         text: The content to parse.
216 |         model_type: The type of model to search for.
217 | 
218 |     Returns:
219 |         The first model that matches the given model type, or None if no match is found.
220 |     """
221 |     return next(iter(try_parse_many(text, model_type)), None)
222 | ```
223 | 
224 | 
225 | </Accordion>
226 | 
227 | try\_parse\_many
228 | ----------------
229 | 
230 | ```python
231 | try_parse_many(
232 |     text: str,
233 |     *types: type[ModelT],
234 |     fail_on_missing: bool = False,
235 | ) -> list[tuple[ModelT, slice]]
236 | ```
237 | 
238 | Tries to parse multiple models of the specified non-identical types from text.
239 | 
240 | **Parameters:**
241 | 
242 | * **`text`**
243 |   (`str`)
244 |   –The content to parse.
245 | * **`*types`**
246 |   (`type[ModelT]`, default:
247 |   `()`
248 |   )
249 |   –The types of models to parse.
250 | * **`fail_on_missing`**
251 |   (`bool`, default:
252 |   `False`
253 |   )
254 |   –Whether to raise an exception if a model type is missing.
255 | 
256 | **Returns:**
257 | 
258 | * `list[tuple[ModelT, slice]]`
259 |   –A list of parsed models.
260 | 
261 | **Raises:**
262 | 
263 | * `MissingModelError`
264 |   –If a model type is missing and `fail_on_missing` is True.
265 | * `Exception`
266 |   –If the model is malformed and `fail_on_missing` is True.
267 | 
268 | <Accordion title="Source code in rigging/parsing.py" icon="code">
269 | ```python
270 | def try_parse_many(
271 |     text: str,
272 |     *types: type["ModelT"],
273 |     fail_on_missing: bool = False,
274 | ) -> list[tuple["ModelT", slice]]:
275 |     """
276 |     Tries to parse multiple models of the specified non-identical types from text.
277 | 
278 |     Args:
279 |         text: The content to parse.
280 |         *types: The types of models to parse.
281 |         fail_on_missing: Whether to raise an exception if a model type is missing.
282 | 
283 |     Returns:
284 |         A list of parsed models.
285 | 
286 |     Raises:
287 |         MissingModelError: If a model type is missing and `fail_on_missing` is True.
288 |         Exception: If the model is malformed and `fail_on_missing` is True.
289 | """ 290 | model: ModelT 291 | parsed: list[tuple[ModelT, slice]] = [] 292 | 293 | try: 294 | for model_class in types: 295 | for model, slice_ in model_class.from_text(text): 296 | parsed.append((model, slice_)) 297 | except Exception: 298 | if fail_on_missing: 299 | raise 300 | 301 | return sorted(parsed, key=lambda x: x[1].start) 302 | ``` 303 | 304 | 305 | </Accordion> 306 | 307 | try\_parse\_set 308 | --------------- 309 | 310 | ```python 311 | try_parse_set( 312 | text: str, 313 | model_type: type[ModelT], 314 | *, 315 | minimum: int | None = None, 316 | fail_on_missing: bool = False, 317 | ) -> list[tuple[ModelT, slice]] 318 | ``` 319 | 320 | Tries to parse a set of models with the specified identical type from text. 321 | 322 | **Parameters:** 323 | 324 | * **`text`** 325 | (`str`) 326 | –The content to parse. 327 | * **`model_type`** 328 | (`type[ModelT]`) 329 | –The type of model to parse. 330 | * **`minimum`** 331 | (`int | None`, default: 332 | `None` 333 | ) 334 | –The minimum number of models expected. 335 | * **`fail_on_missing`** 336 | (`bool`, default: 337 | `False` 338 | ) 339 | –Whether to raise an exception if models are missing. 340 | 341 | **Returns:** 342 | 343 | * `list[tuple[ModelT, slice]]` 344 | –The parsed models. 345 | 346 | **Raises:** 347 | 348 | * `MissingModelError` 349 | –If the number of parsed models is less than the minimum required. 350 | 351 | <Accordion title="Source code in rigging/parsing.py" icon="code"> 352 | ```python 353 | def try_parse_set( 354 | text: str, 355 | model_type: type["ModelT"], 356 | *, 357 | minimum: int | None = None, 358 | fail_on_missing: bool = False, 359 | ) -> list[tuple["ModelT", slice]]: 360 | """ 361 | Tries to parse a set of models with the specified identical type from text. 362 | 363 | Args: 364 | text: The content to parse. 365 | model_type: The type of model to parse. 366 | minimum: The minimum number of models expected. 367 | fail_on_missing: Whether to raise an exception if models are missing. 368 | 369 | Returns: 370 | The parsed models. 371 | 372 | Raises: 373 | MissingModelError: If the number of parsed models is less than the minimum required. 374 | """ 375 | models = try_parse_many(text, model_type, fail_on_missing=fail_on_missing) 376 | if minimum is not None and len(models) < minimum: 377 | raise MissingModelError(f"Expected at least {minimum} {model_type.__name__} in message") 378 | return models 379 | ``` 380 | 381 | 382 | </Accordion> -------------------------------------------------------------------------------- /docs/api/error.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | title: rigging.error 3 | --- 4 | 5 | {/* 6 | ::: rigging.error 7 | */} 8 | 9 | We try to avoid creating custom exceptions unless they are necessary. 10 | 11 | We use the built-in and pydantic exceptions as much as possible. 12 | 13 | CompletionExhaustedMaxRoundsError 14 | --------------------------------- 15 | 16 | ```python 17 | CompletionExhaustedMaxRoundsError( 18 | max_rounds: int, completion: str 19 | ) 20 | ``` 21 | 22 | Raised when the maximum number of rounds is exceeded while generating completions. 
23 | 
24 | <Accordion title="Source code in rigging/error.py" icon="code">
25 | ```python
26 | def __init__(self, max_rounds: int, completion: str):
27 |     super().__init__(max_rounds)
28 |     self.completion = completion
29 |     """The completion which was being generated when the exception occurred."""
30 | ```
31 | 
32 | 
33 | </Accordion>
34 | 
35 | ### completion
36 | 
37 | ```python
38 | completion = completion
39 | ```
40 | 
41 | The completion which was being generated when the exception occurred.
42 | 
43 | ExhaustedMaxRoundsError
44 | -----------------------
45 | 
46 | ```python
47 | ExhaustedMaxRoundsError(max_rounds: int)
48 | ```
49 | 
50 | Raised when the maximum number of rounds is exceeded while generating.
51 | 
52 | <Accordion title="Source code in rigging/error.py" icon="code">
53 | ```python
54 | def __init__(self, max_rounds: int):
55 |     super().__init__(f"Exhausted max rounds ({max_rounds}) while generating")
56 |     self.max_rounds = max_rounds
57 |     """The number of rounds which was exceeded."""
58 | ```
59 | 
60 | 
61 | </Accordion>
62 | 
63 | ### max\_rounds
64 | 
65 | ```python
66 | max_rounds = max_rounds
67 | ```
68 | 
69 | The number of rounds which was exceeded.
70 | 
71 | GeneratorWarning
72 | ----------------
73 | 
74 | Base class for all generator warnings.
75 | 
76 | This is used to indicate that something unexpected happened during the generator execution,
77 | but it is not critical enough to stop the execution.
78 | 
79 | InvalidGeneratorError
80 | ---------------------
81 | 
82 | ```python
83 | InvalidGeneratorError(model: str)
84 | ```
85 | 
86 | Raised when an invalid identifier is specified when getting a generator.
87 | 
88 | <Accordion title="Source code in rigging/error.py" icon="code">
89 | ```python
90 | def __init__(self, model: str):
91 |     super().__init__(f"Invalid model specified: {model}")
92 | ```
93 | 
94 | 
95 | </Accordion>
96 | 
97 | InvalidTokenizerError
98 | ---------------------
99 | 
100 | ```python
101 | InvalidTokenizerError(tokenizer: str)
102 | ```
103 | 
104 | Raised when an invalid tokenizer is specified.
105 | 
106 | <Accordion title="Source code in rigging/error.py" icon="code">
107 | ```python
108 | def __init__(self, tokenizer: str):
109 |     super().__init__(f"Invalid tokenizer specified: {tokenizer}")
110 |     self.tokenizer = tokenizer
111 |     """The name of the tokenizer which was invalid."""
112 | ```
113 | 
114 | 
115 | </Accordion>
116 | 
117 | ### tokenizer
118 | 
119 | ```python
120 | tokenizer = tokenizer
121 | ```
122 | 
123 | The name of the tokenizer which was invalid.
124 | 
125 | MaxDepthError
126 | -------------
127 | 
128 | ```python
129 | MaxDepthError(
130 |     max_depth: int, step: PipelineStep, reference: str
131 | )
132 | ```
133 | 
134 | Raised when the maximum depth is exceeded while generating.
135 | 
136 | <Accordion title="Source code in rigging/error.py" icon="code">
137 | ```python
138 | def __init__(self, max_depth: int, step: "PipelineStep", reference: str):
139 |     super().__init__(f"Exceeded max depth ({max_depth}) while generating ('{reference}')")
140 |     self.max_depth = max_depth
141 |     """The maximum depth of nested pipeline generations which was exceeded."""
142 |     self.step = step
143 |     """The pipeline step which caused the depth error."""
144 | ```
145 | 
146 | 
147 | </Accordion>
148 | 
149 | ### max\_depth
150 | 
151 | ```python
152 | max_depth = max_depth
153 | ```
154 | 
155 | The maximum depth of nested pipeline generations which was exceeded.
156 | 
157 | ### step
158 | 
159 | ```python
160 | step = step
161 | ```
162 | 
163 | The pipeline step which caused the depth error.
164 | 
165 | MessageWarning
166 | --------------
167 | 
168 | Base class for all message warnings.
169 | 
170 | This is used to indicate that something unexpected happened during the message processing,
171 | but it is not critical enough to stop the execution.
172 | 
173 | MessagesExhaustedMaxRoundsError
174 | -------------------------------
175 | 
176 | ```python
177 | MessagesExhaustedMaxRoundsError(
178 |     max_rounds: int, messages: list[Message]
179 | )
180 | ```
181 | 
182 | Raised when the maximum number of rounds is exceeded while generating messages.
183 | 
184 | <Accordion title="Source code in rigging/error.py" icon="code">
185 | ```python
186 | def __init__(self, max_rounds: int, messages: list["Message"]):
187 |     super().__init__(max_rounds)
188 |     self.messages = messages
189 |     """The messages which were being generated when the exception occurred."""
190 | ```
191 | 
192 | 
193 | </Accordion>
194 | 
195 | ### messages
196 | 
197 | ```python
198 | messages = messages
199 | ```
200 | 
201 | The messages which were being generated when the exception occurred.
202 | 
203 | MissingModelError
204 | -----------------
205 | 
206 | ```python
207 | MissingModelError(content: str)
208 | ```
209 | 
210 | Raised when a model is missing while parsing a message.
211 | 
212 | <Accordion title="Source code in rigging/error.py" icon="code">
213 | ```python
214 | def __init__(self, content: str):
215 |     super().__init__(content)
216 | ```
217 | 
218 | 
219 | </Accordion>
220 | 
221 | PipelineWarning
222 | ---------------
223 | 
224 | Base class for all pipeline warnings.
225 | 
226 | This is used to indicate that something unexpected happened during the pipeline execution,
227 | but it is not critical enough to stop the execution.
228 | 
229 | ProcessingError
230 | ---------------
231 | 
232 | ```python
233 | ProcessingError(content: str)
234 | ```
235 | 
236 | Raised when an error occurs during internal generator processing.
237 | 
238 | <Accordion title="Source code in rigging/error.py" icon="code">
239 | ```python
240 | def __init__(self, content: str):
241 |     super().__init__(content)
242 | ```
243 | 
244 | 
245 | </Accordion>
246 | 
247 | Stop
248 | ----
249 | 
250 | ```python
251 | Stop(message: str)
252 | ```
253 | 
254 | Raise inside a pipeline to indicate a stopping condition.
255 | 
256 | Example
257 | 
258 | ```python
259 | import rigging as rg
260 | 
261 | async def read_file(path: str) -> str:
262 |     "Read the contents of a file."
263 | 
264 |     if no_more_files(path):
265 |         raise rg.Stop("There are no more files to read.")
266 | 
267 |     ...
268 | 
269 | chat = await pipeline.using(read_file).run()
270 | ```
271 | 
272 | 
273 | <Accordion title="Source code in rigging/error.py" icon="code">
274 | ```python
275 | def __init__(self, message: str):
276 |     super().__init__(message)
277 |     self.message = message
278 |     """The message associated with the stop."""
279 | ```
280 | 
281 | 
282 | </Accordion>
283 | 
284 | ### message
285 | 
286 | ```python
287 | message = message
288 | ```
289 | 
290 | The message associated with the stop.
291 | 
292 | TokenizerWarning
293 | ----------------
294 | 
295 | Base class for all tokenization warnings.
296 | 
297 | This is used to indicate that something unexpected happened during the tokenization process,
298 | but it is not critical enough to stop the execution.
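Assuming these warning classes subclass Python's built-in `Warning` and are importable from `rigging.error` as documented here, they can be managed with the standard `warnings` module. A minimal sketch:

```python
import warnings

from rigging.error import PipelineWarning, TokenizerWarning

# Escalate pipeline warnings to errors (e.g. during tests),
# while silencing tokenizer noise.
warnings.simplefilter("error", PipelineWarning)
warnings.filterwarnings("ignore", category=TokenizerWarning)
```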
299 | 
300 | ToolDefinitionError
301 | -------------------
302 | 
303 | ```python
304 | ToolDefinitionError(message: str)
305 | ```
306 | 
307 | Raised when a tool cannot be properly defined.
308 | 
309 | <Accordion title="Source code in rigging/error.py" icon="code">
310 | ```python
311 | def __init__(self, message: str):
312 |     super().__init__(message)
313 | ```
314 | 
315 | 
316 | </Accordion>
317 | 
318 | ToolWarning
319 | -----------
320 | 
321 | Base class for all tool warnings.
322 | 
323 | This is used to indicate that something unexpected happened during the tool execution,
324 | but it is not critical enough to stop the execution.
325 | 
326 | UnknownToolError
327 | ----------------
328 | 
329 | ```python
330 | UnknownToolError(tool_name: str)
331 | ```
332 | 
333 | Raised when an API tool call is made for an unknown tool.
334 | 
335 | <Accordion title="Source code in rigging/error.py" icon="code">
336 | ```python
337 | def __init__(self, tool_name: str):
338 |     super().__init__(f"Unknown tool call was requested for '{tool_name}'")
339 |     self.tool_name = tool_name
340 |     """The name of the tool which was unknown."""
341 | ```
342 | 
343 | 
344 | </Accordion>
345 | 
346 | ### tool\_name
347 | 
348 | ```python
349 | tool_name = tool_name
350 | ```
351 | 
352 | The name of the tool which was unknown.
353 | 
354 | raise\_as
355 | ---------
356 | 
357 | ```python
358 | raise_as(
359 |     error_type: type[Exception], message: str
360 | ) -> t.Callable[[t.Callable[P, R]], t.Callable[P, R]]
361 | ```
362 | 
363 | When the wrapped function raises an exception, `raise ... from` with the new error type.
364 | 
365 | <Accordion title="Source code in rigging/error.py" icon="code">
366 | ```python
367 | def raise_as(
368 |     error_type: type[Exception],
369 |     message: str,
370 | ) -> t.Callable[[t.Callable[P, R]], t.Callable[P, R]]:
371 |     "When the wrapped function raises an exception, `raise ... from` with the new error type."
372 | 
373 |     def _raise_as(func: t.Callable[P, R]) -> t.Callable[P, R]:
374 |         @functools.wraps(func)
375 |         def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
376 |             try:
377 |                 return func(*args, **kwargs)
378 |             except Exception as e:
379 |                 error = error_type(message)
380 |                 raise error from e
381 | 
382 |         if wrapper.__doc__ is None:
383 |             wrapper.__doc__ = ""
384 | 
385 |         wrapper.__doc__ += f"\n\nRaises:\n    {error_type.__name__}{': ' + message}"
386 | 
387 |         return wrapper
388 | 
389 |     return _raise_as
390 | ```
391 | 
392 | 
393 | </Accordion>
--------------------------------------------------------------------------------
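A short usage sketch for `raise_as`, based on the source above (the wrapped function and message here are illustrative):

```python
from rigging.error import ProcessingError, raise_as

@raise_as(ProcessingError, "Failed to normalize generator output")
def normalize(text: str) -> str:
    return text.strip().lower()
```

Any exception raised inside `normalize` is re-raised as a `ProcessingError` with the original exception chained via `from`.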