├── .hooks
│   ├── .gitkeep
│   ├── linters
│   │   └── yamllint.yaml
│   ├── prettier.sh
│   ├── post_merge.sh
│   ├── generate_pr_description.py
│   ├── check_pinned_hash_dependencies.py
│   └── generate_docs.py
├── rigging
│   ├── py.typed
│   ├── version.py
│   ├── tokenizer
│   │   ├── __init__.pyi
│   │   ├── __init__.py
│   │   └── transformers_.py
│   ├── tools
│   │   ├── __init__.py
│   │   └── robopages.py
│   ├── generator
│   │   ├── __init__.pyi
│   │   ├── __init__.py
│   │   ├── vllm_.py
│   │   └── transformers_.py
│   ├── transform
│   │   ├── base.py
│   │   └── __init__.py
│   ├── caching.py
│   ├── logging.py
│   ├── __init__.py
│   ├── interact.py
│   ├── parsing.py
│   └── error.py
├── tests
│   ├── __init__.py
│   ├── generators.py
│   ├── test_completion_pipeline.py
│   ├── test_generator_ids.py
│   ├── test_model.py
│   ├── test_prompt.py
│   └── test_watchers.py
├── CODEOWNERS
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yaml
│   │   ├── bug_report.yaml
│   │   └── feature_request.yaml
│   ├── workflows
│   │   ├── semantic-prs.yaml
│   │   ├── meta-labeler.yaml
│   │   ├── meta-sync-labels.yaml
│   │   ├── ci.yml
│   │   ├── docs-update.yaml
│   │   ├── publish.yml
│   │   ├── rigging_pr_description.yml
│   │   ├── semgrep.yaml
│   │   ├── template-sync.yaml
│   │   └── renovate.yaml
│   ├── renovate.json5
│   ├── labeler.yaml
│   ├── CONTRIBUTING.md
│   ├── labels.yaml
│   └── CODE_OF_CONDUCT.md
├── docs
│   ├── assets
│   │   └── tracing_logfire.png
│   ├── topics
│   │   ├── logging.mdx
│   │   ├── completions.mdx
│   │   ├── tracing.mdx
│   │   ├── workflow.mdx
│   │   ├── serialization.mdx
│   │   └── iterating-and-batching.mdx
│   ├── install.mdx
│   ├── docs.json
│   └── api
│       ├── logging.mdx
│       ├── interact.mdx
│       ├── parsing.mdx
│       └── error.mdx
├── .vscode
│   └── settings.json
├── examples
│   ├── robopages.py
│   ├── chat.py
│   └── tokenize.ipynb
├── LICENSE
├── .pre-commit-config.yaml
├── .gitignore
├── pyproject.toml
├── .secrets.baseline
└── README.md
/.hooks/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rigging/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @dreadnode/team
2 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Notes
2 |
3 | -
4 |
5 | ---
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | blank_issues_enabled: false
3 |
--------------------------------------------------------------------------------
/rigging/version.py:
--------------------------------------------------------------------------------
1 | import importlib_metadata
2 |
3 | VERSION = importlib_metadata.version("rigging")
4 |
--------------------------------------------------------------------------------
/docs/assets/tracing_logfire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dreadnode/rigging/HEAD/docs/assets/tracing_logfire.png
--------------------------------------------------------------------------------
/.hooks/linters/yamllint.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | extends: default
3 |
4 | rules:
5 | line-length:
6 | max: 400
7 | level: warning
8 | truthy: false
9 | comments:
10 | min-spaces-from-content: 1
11 | braces: disable
12 | indentation: disable
13 |
--------------------------------------------------------------------------------
/rigging/tokenizer/__init__.pyi:
--------------------------------------------------------------------------------
1 | from rigging.tokenizer.base import (
2 | TokenizedChat,
3 | Tokenizer,
4 | TokenSlice,
5 | get_tokenizer,
6 | register_tokenizer,
7 | )
8 | from rigging.tokenizer.transformers_ import TransformersTokenizer
9 |
10 | __all__ = [
11 | "TokenSlice",
12 | "TokenizedChat",
13 | "Tokenizer",
14 | "TransformersTokenizer",
15 | "get_tokenizer",
16 | "register_tokenizer",
17 | ]
18 |
--------------------------------------------------------------------------------
/.github/workflows/semantic-prs.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: "Semantic Lints PR"
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | types:
8 | - opened
9 | - edited
10 | - synchronize
11 | - reopened
12 |
13 | permissions:
14 | pull-requests: read
15 |
16 | jobs:
17 | main:
18 | name: Validate PR title
19 | runs-on: ubuntu-latest
20 | steps:
21 | - uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
22 | env:
23 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
24 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "[python]": {
3 | "editor.formatOnSave": true,
4 | "editor.codeActionsOnSave": {
5 | "source.fixAll": "explicit",
6 | "source.organizeImports": "explicit"
7 | },
8 | "editor.defaultFormatter": "charliermarsh.ruff"
9 | },
10 | "python.testing.pytestArgs": [
11 | "tests"
12 | ],
13 | "python.testing.unittestEnabled": false,
14 | "python.testing.pytestEnabled": true,
15 | "mypy.runUsingActiveInterpreter": true,
16 | "debugpy.debugJustMyCode": false,
17 | "jupyter.debugJustMyCode": false
18 | }
19 |
--------------------------------------------------------------------------------
/rigging/tools/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module handles tool interaction with rigging generation.
3 | """
4 |
5 | from rigging.tools.base import (
6 | FunctionCall,
7 | FunctionDefinition,
8 | Tool,
9 | ToolCall,
10 | ToolChoice,
11 | ToolDefinition,
12 | ToolMode,
13 | tool,
14 | tool_method,
15 | )
16 | from rigging.tools.mcp import as_mcp, mcp
17 | from rigging.tools.robopages import robopages
18 |
19 | __all__ = [
20 | "FunctionCall",
21 | "FunctionDefinition",
22 | "Tool",
23 | "ToolCall",
24 | "ToolChoice",
25 | "ToolDefinition",
26 | "ToolMode",
27 | "as_mcp",
28 | "mcp",
29 | "robopages",
30 | "tool",
31 | "tool_method",
32 | ]
33 |
--------------------------------------------------------------------------------
/examples/robopages.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | import litellm
5 | import logfire
6 |
7 | import rigging as rg
8 |
9 | # Constants
10 | SYSTEM_PROMPT = "Enumerate all open TCP ports on 127.0.0.1 and provide a vulnerability report"
11 |
12 | # LOGFIRE_PROJECT = "rigging-demos"
13 | logfire.configure()
14 | os.environ.setdefault("LOGFIRE_TOKEN", "") # (1)!
15 | litellm.callbacks = ["logfire"]
16 |
17 | rg.logging.configure_logging("debug")
18 |
19 |
20 | async def main():
21 | tools = rg.robopages("http://localhost:8000")
22 | chat = await rg.get_generator("gpt-4").chat(SYSTEM_PROMPT).using(*tools).run()
23 | print(chat.conversation)
24 |
25 |
26 | if __name__ == "__main__":
27 | asyncio.run(main())
28 |
--------------------------------------------------------------------------------
/.github/renovate.json5:
--------------------------------------------------------------------------------
1 | {
2 | $schema: "https://docs.renovatebot.com/renovate-schema.json",
3 | extends: [
4 | "config:base",
5 | ":disableRateLimiting",
6 | ":semanticCommits",
7 | ":enablePreCommit",
8 | ":automergeDigest",
9 | ":automergeBranch",
10 | ],
11 | dependencyDashboardTitle: "Renovate Dashboard 🤖",
12 | suppressNotifications: ["prIgnoreNotification"],
13 | rebaseWhen: "conflicted",
14 | commitBodyTable: true,
15 | "pre-commit": {
16 | enabled: true,
17 | },
18 | enabledManagers: [
19 | "github-actions",
20 | "dockerfile",
21 | "docker-compose",
22 | "pre-commit"
23 | ],
24 | ignorePaths: [],
25 | packageRules: [
26 | {
27 | description: "Auto merge non-major updates",
28 | matchUpdateTypes: ["minor", "patch"],
29 | automerge: true,
30 | automergeType: "pr",
31 | },
32 | ],
33 | }
34 |
--------------------------------------------------------------------------------
/.github/workflows/meta-labeler.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: "Labeler"
3 | on:
4 | pull_request_target:
5 | branches: ["main"]
6 | types: ["opened", "synchronize"]
7 |
8 | permissions:
9 | actions: read
10 | contents: read
11 | issues: write
12 | pull-requests: write
13 |
14 | jobs:
15 | labeler:
16 | name: Labeler
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Generate Token
20 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1
21 | id: app-token
22 | with:
23 | app-id: "${{ secrets.BOT_APP_ID }}"
24 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}"
25 |
26 | - name: Labeler
27 | uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
28 | with:
29 | configuration-path: .github/labeler.yaml
30 | repo-token: "${{ steps.app-token.outputs.token }}"
31 |
--------------------------------------------------------------------------------
/docs/topics/logging.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Logging"
3 | description: "How to view and configure logging"
4 | public: true
5 | ---
6 |
7 | Rigging uses [loguru](https://loguru.readthedocs.io/) for its logging. By default it disables its logger, allowing users to choose when and how to gather messages.
8 |
9 | If you want to let rigging messages flow into loguru, you should enable it:
10 |
11 | ```python
12 | from loguru import logger
13 |
14 | logger.enable('rigging')
15 | ```
16 |
17 | If you want sane default handlers with dual console & file logging, you can use the `rigging.logging.configure_logging` function to configure loguru.
18 |
19 | ```python
20 | from rigging.logging import configure_logging
21 |
22 | configure_logging(
23 | 'info', # stderr level
24 | 'out.log', # log file (optional)
25 | 'trace' # log file level
26 | )
27 | ```
28 | *(This will remove existing handlers, so you might prefer to configure them yourself)*
29 |
--------------------------------------------------------------------------------
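A minimal sketch of the manual alternative mentioned in `docs/topics/logging.mdx` above: enable the `rigging` logger and attach your own loguru handlers instead of calling `configure_logging`. The sink targets, levels, and rotation below are illustrative assumptions, not values taken from this repository.

```python
import sys

from loguru import logger

# Rigging disables its logger by default; enable it so messages flow through loguru.
logger.enable("rigging")

# Attach handlers yourself rather than calling rigging.logging.configure_logging,
# which removes any handlers already registered.
logger.remove()  # drop loguru's default stderr handler (optional)
logger.add(sys.stderr, level="INFO")  # console output
logger.add("rigging.log", level="DEBUG", rotation="10 MB")  # file output; path and rotation are assumptions
```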
/rigging/generator/__init__.pyi:
--------------------------------------------------------------------------------
1 | from rigging.generator.base import (
2 | GeneratedMessage,
3 | GeneratedText,
4 | GenerateParams,
5 | Generator,
6 | StopReason,
7 | Usage,
8 | chat,
9 | complete,
10 | get_generator,
11 | get_identifier,
12 | register_generator,
13 | )
14 | from rigging.generator.http import HTTPGenerator
15 | from rigging.generator.litellm_ import LiteLLMGenerator
16 | from rigging.generator.transformers_ import TransformersGenerator
17 | from rigging.generator.vllm_ import VLLMGenerator
18 |
19 | __all__ = [
20 | "GenerateParams",
21 | "GeneratedMessage",
22 | "GeneratedText",
23 | "Generator",
24 | "HTTPGenerator",
25 | "LiteLLMGenerator",
26 | "StopReason",
27 | "TransformersGenerator",
28 | "Usage",
29 | "VLLMGenerator",
30 | "chat",
31 | "complete",
32 | "get_generator",
33 | "get_generator",
34 | "get_identifier",
35 | "register_generator",
36 | ]
37 |
--------------------------------------------------------------------------------
/docs/install.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Installation"
3 | description: "Install and set up Rigging"
4 | public: true
5 | ---
6 |
7 | We publish every version to PyPI:
8 |
9 |
10 | ```bash pip
11 | pip install -U rigging
12 | ```
13 |
14 | ```bash uv
15 | uv pip install -U rigging
16 | ```
17 |
18 | ```bash poetry
19 | poetry add rigging
20 | ```
21 |
22 |
23 | If you want all the extras (vLLM, transformers, examples), specify the `all` extra:
24 |
25 |
26 | ```bash pip
27 | pip install -U rigging[all]
28 | ```
29 |
30 | ```bash uv
31 | uv pip install -U rigging[all]
32 | ```
33 |
34 | ```bash poetry
35 | poetry add rigging[all]
36 | ```
37 |
38 |
39 | If you want to build from source:
40 |
41 | ```bash
42 | cd rigging/
43 | poetry install
44 | ```
45 |
46 | ## Migration Guides
47 |
48 | - **[Migrating from v2 -> v3](/topics/migrations#migrating-from-v2x-to-v3x)**
49 | - **[Migrating from v1 -> v2](/topics/migrations#migrating-from-v1x-to-v2x)**
50 |
--------------------------------------------------------------------------------
/examples/chat.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import click
4 |
5 | import rigging as rg
6 |
7 |
8 | async def main(generator_id: str, system_prompt: str) -> None:
9 | base_pipeline = rg.get_generator(generator_id).chat({"role": "system", "content": system_prompt})
10 | await rg.interact(base_pipeline)
11 |
12 |
13 | @click.command()
14 | @click.option(
15 | "-g",
16 | "--generator-id",
17 | type=str,
18 | required=True,
19 | help="Rigging generator identifier (gpt-4, mistral/mistral-medium, etc.)",
20 | )
21 | @click.option(
22 | "-s",
23 | "--system-prompt",
24 | type=str,
25 | default="You are a helpful assistant.",
26 | help="System prompt to use for the generator",
27 | )
28 | def cli(
29 | generator_id: str,
30 | system_prompt: str,
31 | ) -> None:
32 | """
33 | Rigging example of a basic terminal chat interaction.
34 | """
35 |
36 | asyncio.run(main(generator_id, system_prompt))
37 |
38 |
39 | if __name__ == "__main__":
40 | cli()
41 |
--------------------------------------------------------------------------------
/rigging/transform/base.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from rigging.generator import GenerateParams
4 | from rigging.message import (
5 | Message,
6 | )
7 |
8 | if t.TYPE_CHECKING:
9 | from rigging.chat import Chat
10 |
11 |
12 | @t.runtime_checkable
13 | class PostTransform(t.Protocol):
14 | def __call__(
15 | self,
16 | chat: "Chat",
17 | /,
18 | ) -> "t.Awaitable[Chat]":
19 | """
20 | Passed the completed chat to transform.
21 | """
22 | ...
23 |
24 |
25 | @t.runtime_checkable
26 | class Transform(t.Protocol):
27 | def __call__(
28 | self,
29 | messages: list[Message],
30 | params: GenerateParams,
31 | /,
32 | ) -> t.Awaitable[tuple[list[Message], GenerateParams, PostTransform | None]]:
33 | """
34 | Passed messages and params to transform.
35 |
36 | May return an optional post-transform callback to be executed to unwind the transformation.
37 | """
38 | ...
39 |
--------------------------------------------------------------------------------
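A minimal sketch of a callable that satisfies the `Transform` protocol defined above: it receives the message list and `GenerateParams` positionally, may adjust either, and returns an optional `PostTransform` to unwind the change after generation (here `None`, since nothing needs unwinding). `GenerateParams` exposing a `temperature` field is an assumption.

```python
from rigging.generator import GenerateParams
from rigging.message import Message
from rigging.transform.base import PostTransform


async def default_temperature_transform(
    messages: list[Message],
    params: GenerateParams,
    /,
) -> tuple[list[Message], GenerateParams, PostTransform | None]:
    # Fill in a temperature only when the caller did not provide one.
    if params.temperature is None:
        params = params.model_copy(update={"temperature": 0.7})

    # Nothing to unwind after generation, so no post-transform is returned.
    return messages, params, None
```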
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 dreadnode
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/rigging/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Tokenizers encode chats and associated message data into tokens for training and inference.
3 | """
4 |
5 | import typing as t
6 |
7 | from rigging.tokenizer.base import (
8 | TokenizedChat,
9 | Tokenizer,
10 | TokenSlice,
11 | get_tokenizer,
12 | register_tokenizer,
13 | )
14 |
15 |
16 | def get_transformers_lazy() -> type[Tokenizer]:
17 | try:
18 | from rigging.tokenizer.transformers_ import TransformersTokenizer
19 | except ImportError as e:
20 | raise ImportError(
21 | "TransformersTokenizer is not available. Please install `transformers` or use `rigging[llm]`.",
22 | ) from e
23 |
24 | return TransformersTokenizer
25 |
26 |
27 | register_tokenizer("transformers", get_transformers_lazy)
28 |
29 | __all__ = [
30 | "TokenSlice",
31 | "TokenizedChat",
32 | "Tokenizer",
33 | "get_tokenizer",
34 | "register_tokenizer",
35 | ]
36 |
37 |
38 | def __getattr__(name: str) -> t.Any:
39 | if name == "TransformersTokenizer":
40 | return get_transformers_lazy()
41 | raise AttributeError(f"module {__name__} has no attribute {name}")
42 |
--------------------------------------------------------------------------------
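The lazy-registration pattern above can be reused for a third-party tokenizer: register a factory so the heavy import only happens when the tokenizer is actually requested. The package and class names below are hypothetical.

```python
from rigging.tokenizer import Tokenizer, register_tokenizer


def get_my_tokenizer_lazy() -> type[Tokenizer]:
    # Deferred import so the optional dependency is only required when used.
    from my_package.tokenizers import MyTokenizer  # hypothetical module and class

    return MyTokenizer


register_tokenizer("my-tokenizer", get_my_tokenizer_lazy)
```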
/.github/workflows/meta-sync-labels.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: "Meta Sync labels"
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: ["main"]
7 | paths: [".github/labels.yaml"]
8 |
9 | permissions:
10 | actions: read
11 | contents: read
12 | issues: write
13 | pull-requests: write
14 |
15 | jobs:
16 | labels:
17 | name: Sync Labels
18 | runs-on: ubuntu-latest
19 | steps:
20 | - name: Generate Token
21 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1
22 | id: app-token
23 | with:
24 | app-id: "${{ secrets.BOT_APP_ID }}"
25 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}"
26 |
27 | - name: Set up git repository
28 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
29 | with:
30 | token: "${{ steps.app-token.outputs.token }}"
31 |
32 | - name: Sync Labels
33 | uses: EndBug/label-sync@52074158190acb45f3077f9099fea818aa43f97a # v2.3.3
34 | with:
35 | config-file: .github/labels.yaml
36 | token: "${{ steps.app-token.outputs.token }}"
37 | delete-other-labels: true
38 |
--------------------------------------------------------------------------------
/.hooks/prettier.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euo pipefail
3 |
4 | # Check if npm is installed
5 | if ! command -v npm &> /dev/null; then
6 | echo 'Error: npm is not installed.' >&2
7 | exit 1
8 | fi
9 |
10 | # Check if Prettier is installed, install it if missing
11 | if ! command -v prettier &> /dev/null; then
12 | echo 'Error: Prettier is not installed.' >&2
13 | echo 'Installing Prettier...'
14 | npm install -g prettier
15 | fi
16 |
17 | # Verify Prettier is installed
18 | if ! command -v prettier &> /dev/null; then
19 | echo 'Error: Prettier installation failed.' >&2
20 | exit 1
21 | fi
22 |
23 | # Run Prettier on staged .json, .yaml, and .yml files
24 | echo "Running Prettier on staged files..."
25 |
26 | # List all staged files, filter for the desired extensions, and run Prettier
27 | git diff --cached --name-only --diff-filter=d |
28 | grep -E '\.(json|ya?ml)$' |
29 | xargs -I {} prettier --write {}
30 |
31 | # Add the files back to staging area as Prettier may have modified them
32 | git diff --name-only --diff-filter=d |
33 | grep -E '\.(json|ya?ml)$' |
34 | xargs git add
35 |
36 | echo "Prettier formatting completed."
37 | exit 0
38 |
--------------------------------------------------------------------------------
/rigging/caching.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from loguru import logger
4 |
5 | if t.TYPE_CHECKING:
6 | from rigging.message import Message
7 |
8 | CacheMode = t.Literal["latest"]
9 | """
10 | How to handle cache_control entries on messages.
11 |
12 | - latest: Assign cache_control to the latest 2 non-assistant messages in the pipeline before inference.
13 | """
14 |
15 |
16 | def apply_cache_mode_to_messages(
17 | mode: CacheMode | None,
18 | messages: "list[list[Message]]",
19 | ) -> "list[list[Message]]":
20 | if mode is None:
21 | return messages
22 |
23 | if mode != "latest":
24 | logger.warning(
25 | f"Unknown caching mode '{mode}', defaulting to 'latest'",
26 | )
27 | mode = "latest"
28 |
29 | # first remove existing cache settings
30 | updated: list[list[Message]] = []
31 | for _messages in messages:
32 | updated = [
33 | *updated,
34 | [m.clone().cache(cache_control=False) for m in _messages],
35 | ]
36 |
37 | # then apply the latest cache settings
38 | for _messages in updated:
39 | for message in [m for m in _messages if m.role != "assistant"][-2:]:
40 | message.cache(cache_control=True)
41 |
42 | return updated
43 |
--------------------------------------------------------------------------------
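A small sketch of how the `latest` cache mode behaves, using the helper above: existing `cache_control` flags are cleared, then the last two non-assistant messages in each conversation are re-marked. Constructing `Message` objects with `role`/`content` keyword arguments is an assumption here.

```python
from rigging.caching import apply_cache_mode_to_messages
from rigging.message import Message

conversation = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="Summarize the design doc."),
    Message(role="assistant", content="Here is a summary..."),
    Message(role="user", content="Now list any open questions."),
]

# The helper operates on a batch of conversations; this batch holds just one.
updated = apply_cache_mode_to_messages("latest", [conversation])

# The system prompt and the final user message (the last two non-assistant
# messages) now carry cache_control; every other message does not.
```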
/.hooks/post_merge.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Get pre-merge hash from the target branch
4 | old_hash=$(git show ORIG_HEAD:poetry.lock | md5sum 2> /dev/null || echo "")
5 |
6 | # Get current hash
7 | new_hash=$(md5sum poetry.lock 2> /dev/null || echo "")
8 |
9 | # Compare and run poetry install if changed
10 | if [ "$old_hash" != "$new_hash" ]; then
11 | echo "📦 Root dependencies changed. Running poetry install..."
12 | poetry install || {
13 | echo "❌ Failed to update dependencies"
14 | exit 1
15 | }
16 | echo "✅ Root dependencies updated!"
17 | else
18 | echo "📦 No root dependency changes"
19 | fi
20 |
21 | # Get pre-merge hash from the target branch
22 | old_hash=$(git show ORIG_HEAD:components/api/poetry.lock | md5sum 2> /dev/null || echo "")
23 |
24 | # Get current hash
25 | new_hash=$(md5sum components/api/poetry.lock 2> /dev/null || echo "")
26 |
27 | # Compare and run poetry install if changed
28 | if [ "$old_hash" != "$new_hash" ]; then
29 | echo "📦 API dependencies changed. Running poetry install..."
30 | cd components/api || exit
31 | if ! poetry install --with dev; then
32 | echo "❌ Failed to update dependencies"
33 | exit 1
34 | fi
35 | echo "✅ API dependencies updated!"
36 | else
37 | echo "📦 No API dependency changes"
38 | fi
39 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Lint, Typecheck, and Test
3 |
4 | on:
5 | push:
6 | branches: [main]
7 | pull_request:
8 | branches: [main]
9 |
10 | jobs:
11 | ci:
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: ["3.10", "3.11", "3.12"]
16 |
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - name: Checkout code
21 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
22 |
23 | - name: Install Poetry
24 | uses: abatilo/actions-poetry@b8f6fe29ba2eb78e0d45ccbf41cd14154c4e25b2
25 |
26 | - name: Configure Poetry
27 | run: |
28 | poetry config virtualenvs.create true --local
29 | poetry config virtualenvs.in-project true --local
30 |
31 | - name: Setup Python ${{ matrix.python-version }}
32 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548
33 | with:
34 | python-version: ${{ matrix.python-version }}
35 | cache: "poetry"
36 |
37 | - name: Install package
38 | run: poetry install --all-extras
39 |
40 | - name: Lint
41 | run: poetry run ruff check --output-format=github rigging
42 |
43 | - name: Typecheck
44 | if: always()
45 | run: poetry run mypy rigging
46 |
47 | - name: Test
48 | if: always()
49 | run: poetry run pytest
50 |
--------------------------------------------------------------------------------
/.github/workflows/docs-update.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Trigger Docs Update
3 |
4 | on:
5 | push:
6 | branches: [main]
7 | paths:
8 | - "docs/**"
9 | - ".hooks/generate_docs.py"
10 | - ".github/workflows/docs-update.yaml"
11 | workflow_dispatch:
12 |
13 | jobs:
14 | notify-docs:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1
18 | id: app-token
19 | with:
20 | app-id: ${{ vars.UPDATE_DOCS_APP_ID }}
21 | private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }}
22 | owner: "${{ github.repository_owner }}"
23 | repositories: |
24 | docs
25 |
26 | - name: Trigger docs repository workflow
27 | uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0
28 | with:
29 | token: ${{ steps.app-token.outputs.token }}
30 | repository: dreadnode/docs
31 | event-type: docs-update
32 | client-payload: |
33 | {
34 | "repository": "${{ github.repository }}",
35 | "ref": "${{ github.ref }}",
36 | "sha": "${{ github.sha }}",
37 | "source_dir": "docs",
38 | "target_dir": "open-source/rigging",
39 | "nav_target": "Open Source/Rigging"
40 | }
41 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Build and Publish
3 |
4 | on:
5 | push:
6 | tags: ["v*"]
7 |
8 | jobs:
9 | build-and-publish:
10 | name: Build and Publish
11 | environment: protected
12 | permissions:
13 | contents: read
14 | id-token: write
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Checkout code
18 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
19 |
20 | - name: Setup Python
21 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548
22 | with:
23 | python-version: "3.14"
24 |
25 | - name: Install Poetry
26 | uses: abatilo/actions-poetry@b8f6fe29ba2eb78e0d45ccbf41cd14154c4e25b2
27 |
28 | - name: Configure Poetry
29 | run: |
30 | poetry config virtualenvs.create true --local
31 | poetry config virtualenvs.in-project true --local
32 |
33 | - name: Install package
34 | run: poetry install
35 |
36 | - name: Validate version
37 | run: |
38 | TAG_VERSION=${GITHUB_REF#refs/tags/v}
39 | POETRY_VERSION=$(poetry version -s)
40 |
41 | if [ "$TAG_VERSION" != "$POETRY_VERSION" ]; then
42 | echo "Tag ($TAG_VERSION) doesn't match pyproject.toml ($POETRY_VERSION)"
43 | exit 1
44 | fi
45 |
46 | - name: Build package
47 | run: poetry build
48 |
49 | - name: Publish to PyPI
50 | uses: pypa/gh-action-pypi-publish@3317ede93a4981d0fc490510c6fcf8bf0e92ed05
51 |
--------------------------------------------------------------------------------
/rigging/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from rigging.transform.base import PostTransform, Transform
2 | from rigging.transform.json_tools import (
3 | JsonToolMode,
4 | make_tools_to_json_transform,
5 | tools_to_json_in_xml_transform,
6 | tools_to_json_transform,
7 | tools_to_json_with_tag_transform,
8 | )
9 | from rigging.transform.pythonic_tools import (
10 | make_tools_to_pythonic_transform,
11 | tools_to_pythonic_transform,
12 | )
13 | from rigging.transform.xml_tools import make_tools_to_xml_transform
14 |
15 |
16 | def get_transform(identifier: str) -> Transform:
17 | """
18 | Get a well-known transform by its identifier.
19 |
20 | Args:
21 | identifier: The identifier of the transform to retrieve.
22 |
23 | Returns:
24 | The corresponding transform callable.
25 | """
26 | match identifier:
27 | case "json":
28 | return tools_to_json_transform
29 | case "json-in-xml":
30 | return tools_to_json_in_xml_transform
31 | case "json-with-tag":
32 | return tools_to_json_with_tag_transform
33 | case "pythonic":
34 | return tools_to_pythonic_transform
35 | case _:
36 | raise ValueError(f"Unknown transform identifier: {identifier}")
37 |
38 |
39 | __all__ = [
40 | "JsonToolMode",
41 | "PostTransform",
42 | "Transform",
43 | "make_tools_to_json_transform",
44 | "make_tools_to_pythonic_transform",
45 | "make_tools_to_xml_transform",
46 | "tools_to_json_in_xml_transform",
47 | "tools_to_json_transform",
48 | "tools_to_json_with_tag_transform",
49 | "tools_to_pythonic_transform",
50 | ]
51 |
--------------------------------------------------------------------------------
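A quick usage sketch for `get_transform` as defined above; the identifiers map directly onto the branches of the `match` statement.

```python
from rigging.transform import get_transform

# Resolve one of the well-known tool-calling transforms by name.
json_transform = get_transform("json-in-xml")

# Unknown identifiers raise ValueError.
try:
    get_transform("yaml")
except ValueError as error:
    print(error)  # Unknown transform identifier: yaml
```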
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: "🚨 Bug Report"
3 | description: File a bug report
4 | title: "🚨 [BUG] - "
5 | labels: ["bug", "triage"]
6 | assignees:
7 | - octocat
8 | body:
9 | - type: markdown
10 | attributes:
11 | value: |
12 | Thanks for taking the time to fill out this bug report!
13 |
14 | - type: textarea
15 | id: what-happened
16 | attributes:
17 | label: What happened?
18 | description: Also tell us, what did you expect to happen?
19 | placeholder: |
20 | Steps to reproduce the behavior:
21 | 1.
22 | 2.
23 | 3.
24 |
25 | Expected behavior:
26 | ...
27 |
28 | Actual behavior:
29 | ...
30 | validations:
31 | required: true
32 |
33 | - type: textarea
34 | id: possible-fix
35 | attributes:
36 | label: Any suggestions for fixing this bug?
37 | description: If you have an idea to fix this bug, we'd love to hear it!
38 | validations:
39 | required: false
40 |
41 | - type: textarea
42 | id: logs
43 | attributes:
44 | label: Relevant log output
45 | description: Please copy and paste any relevant log output.
46 | render: shell
47 |
48 | - type: textarea
49 | id: environment
50 | attributes:
51 | label: Details about your environment
52 | description: Please provide the following information about your environment.
53 | placeholder: |
54 | ## Your Environment
55 | - Go Version:
56 | - Operating System:
57 | - Browser (if applicable):
58 | - Relevant env vars
59 |
60 | Tell us what you see!
61 | validations:
62 | required: false
63 |
--------------------------------------------------------------------------------
/.github/workflows/rigging_pr_description.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Update PR Description with Rigging
3 |
4 | on:
5 | pull_request:
6 | types: [opened, synchronize]
7 |
8 | jobs:
9 | update-description:
10 | runs-on: ubuntu-latest
11 | permissions:
12 | pull-requests: write
13 | contents: read
14 |
15 | steps:
16 | - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
17 | with:
18 | fetch-depth: 0 # full history for proper diffing
19 |
20 | - name: Set up Python
21 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
22 | with:
23 | python-version: "3.14"
24 |
25 | - name: Install uv
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install uv
29 |
30 | - name: Generate PR Description
31 | id: description
32 | env:
33 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
34 | run: |
35 | DESCRIPTION="$(uv run --no-project .hooks/generate_pr_description.py --base-ref "origin/${{ github.base_ref }}" --exclude "./*.lock")"
36 | {
37 | echo "description<<EOF"
38 | echo "$DESCRIPTION"
39 | echo "EOF"
40 | } >> "$GITHUB_OUTPUT"
41 |
42 | - name: Update PR Description
43 | uses: nefrob/pr-description@4dcc9f3ad5ec06b2a197c5f8f93db5e69d2fdca7 # v1.2.0
44 | with:
45 | token: ${{ secrets.GITHUB_TOKEN }}
46 | content: |
47 |
48 | ---
49 |
50 | ## Generated Summary
51 |
52 | ${{ steps.description.outputs.description }}
53 |
54 | This summary was generated with ❤️ by [rigging](https://docs.dreadnode.io/rigging/)
55 |
--------------------------------------------------------------------------------
/.github/workflows/semgrep.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Semgrep Analysis
3 | on:
4 | merge_group:
5 | pull_request:
6 | branches:
7 | - main
8 | types:
9 | - opened
10 | - synchronize
11 | - reopened
12 | push:
13 | branches:
14 | - main
15 | schedule:
16 | - cron: "0 0 * * *" # Run daily at midnight UTC
17 |
18 | concurrency:
19 | group: pre-commit-${{ github.run_id }}
20 | cancel-in-progress: true
21 |
22 | permissions:
23 | actions: read
24 | checks: write
25 | contents: read
26 | pull-requests: write # Allows merge queue updates
27 | security-events: write # Required for GitHub Security tab
28 |
29 | jobs:
30 | semgrep:
31 | name: Semgrep Analysis
32 | runs-on: ubuntu-latest
33 | container:
34 | image: returntocorp/semgrep
35 |
36 | # Skip any PR created by dependabot to avoid permission issues:
37 | if: (github.actor != 'dependabot[bot]')
38 |
39 | steps:
40 | - name: Set up git repository
41 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
42 | with:
43 | token: ${{ secrets.GITHUB_TOKEN }}
44 |
45 | - name: Configure Git Safe Directory
46 | run: git config --global --add safe.directory "${GITHUB_WORKSPACE}"
47 |
48 | - name: Semgrep Analysis
49 | env:
50 | SEMGREP_RULES: >-
51 | p/python
52 | p/security-audit
53 | p/secrets
54 | p/owasp-top-ten
55 | p/supply-chain
56 | SEMGREP_TIMEOUT: 300 # 5-minute timeout per rule
57 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
58 | run: |
59 | semgrep ci \
60 | --config="${SEMGREP_RULES}" \
61 | --timeout="${SEMGREP_TIMEOUT}" \
62 | --sarif --output=semgrep-results.sarif
63 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: "💡 Feature Request"
3 | description: Create a new ticket for a new feature request
4 | title: "💡 [REQUEST] - "
5 | labels: ["question"]
6 | body:
7 | - type: textarea
8 | id: implementation_pr
9 | attributes:
10 | label: "Implementation PR"
11 | description: Associated pull request
12 | placeholder: "# Pull Request ID"
13 | validations:
14 | required: false
15 | - type: textarea
16 | id: reference_issues
17 | attributes:
18 | label: "Reference Issues"
19 | description: Related issues
20 | placeholder: "# Issue ID(s)"
21 | validations:
22 | required: false
23 | - type: textarea
24 | id: summary
25 | attributes:
26 | label: "Summary"
27 | description: Provide a brief explanation of the feature
28 | placeholder: Describe your feature request
29 | validations:
30 | required: true
31 | - type: textarea
32 | id: basic_example
33 | attributes:
34 | label: "Basic Example"
35 | description: Provide some basic examples of your feature
36 | placeholder: A few specific details about your feature request
37 | validations:
38 | required: true
39 | - type: textarea
40 | id: drawbacks
41 | attributes:
42 | label: "Drawbacks"
43 | description: What are the drawbacks/impacts of your feature request?
44 | placeholder: Identify the drawbacks and impacts while remaining neutral on your feature request
45 | validations:
46 | required: true
47 | - type: textarea
48 | id: unresolved_question
49 | attributes:
50 | label: "Unresolved questions"
51 | description: What questions remain unresolved?
52 | placeholder: Identify any unresolved issues
53 | validations:
54 | required: false
55 |
--------------------------------------------------------------------------------
/.github/labeler.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Area Labels
3 | area/docs:
4 | - changed-files:
5 | - any-glob-to-any-file: "docs/**/*"
6 |
7 | area/examples:
8 | - changed-files:
9 | - any-glob-to-any-file: "examples/**/*"
10 |
11 | area/github:
12 | - changed-files:
13 | - any-glob-to-any-file: ".github/**/*"
14 |
15 | area/pre-commit:
16 | - changed-files:
17 | - any-glob-to-any-file: ".pre-commit-config.yaml"
18 | - any-glob-to-any-file: ".hooks/**/*"
19 |
20 | area/python:
21 | - changed-files:
22 | - any-glob-to-any-file: "pyproject.toml"
23 | - any-glob-to-any-file: "requirements.txt"
24 | - any-glob-to-any-file: "*.py"
25 |
26 | area/security:
27 | - changed-files:
28 | - any-glob-to-any-file: "SECURITY.md"
29 | - any-glob-to-any-file: "secrets.baseline"
30 |
31 | area/taskfiles:
32 | - changed-files:
33 | - any-glob-to-any-file: "Taskfile.yaml"
34 |
35 | area/tests:
36 | - changed-files:
37 | - any-glob-to-any-file: "tests/**/*"
38 |
39 | area/workspace:
40 | - changed-files:
41 | - any-glob-to-any-file: "python.code-workspace"
42 |
43 | # Development Labels
44 | area/dev:
45 | - changed-files:
46 | - any-glob-to-any-file: "dev/**/*"
47 |
48 | # Semantic Type Labels
49 | type/digest:
50 | - head-branch: ["^renovate/"]
51 | - head-branch: ["^deps/"]
52 |
53 | type/patch:
54 | - any: ["title:/^(?:Fix|Patch|Update)/"]
55 |
56 | type/minor:
57 | - any: ["title:/^(?:Add|Feature|Improve)/"]
58 |
59 | type/major:
60 | - any: ["title:/^(?:BREAKING)/"]
61 |
62 | type/break:
63 | - any: ["body:/BREAKING CHANGE:/"]
64 |
65 | # Documentation Labels
66 | type/docs:
67 | - changed-files:
68 | - any-glob-to-any-file: "docs/**/*"
69 | - any-glob-to-any-file: "*.md"
70 |
71 | # Core Files Labels
72 | type/core:
73 | - changed-files:
74 | - any-glob-to-any-file: "CODEOWNERS"
75 | - any-glob-to-any-file: "LICENSE"
76 | - any-glob-to-any-file: "README.md"
77 |
--------------------------------------------------------------------------------
/rigging/logging.py:
--------------------------------------------------------------------------------
1 | """
2 | We use loguru for logging. This module provides a function to configure logging handlers.
3 |
4 | To just enable rigging logs to flow, call `logger.enable("rigging")` after importing the module.
5 | """
6 |
7 | import pathlib
8 | import sys
9 | import typing as t
10 |
11 | from loguru import logger
12 |
13 | g_configured: bool = False
14 |
15 | LogLevelList = ["trace", "debug", "info", "success", "warning", "error", "critical"]
16 | LogLevelLiteral = t.Literal["trace", "debug", "info", "success", "warning", "error", "critical"]
17 | """Valid logging levels."""
18 |
19 |
20 | def configure_logging(
21 | log_level: LogLevelLiteral = "info",
22 | log_file: pathlib.Path | None = None,
23 | log_file_level: LogLevelLiteral = "debug",
24 | ) -> None:
25 | """
26 | Configures common loguru handlers.
27 |
28 | Args:
29 | log_level: The desired log level.
30 | log_file: The path to the log file. If None, logging
31 | will only be done to the console.
32 | log_file_level: The log level for the log file.
33 | """
34 | global g_configured # noqa: PLW0603
35 |
36 | if g_configured:
37 | return
38 |
39 | logger.enable("rigging")
40 |
41 | logger.level("TRACE", color="", icon="[T]")
42 | logger.level("DEBUG", color="", icon="[_]")
43 | logger.level("INFO", color="", icon="[=]")
44 | logger.level("SUCCESS", color="", icon="[+]")
45 | logger.level("WARNING", color="", icon="[-]")
46 | logger.level("ERROR", color="", icon="[!]")
47 | logger.level("CRITICAL", color="", icon="[x]")
48 |
49 | custom_format = "{time:HH:mm:ss.SSS} | {level.icon} {message}"
50 |
51 | logger.remove()
52 | logger.add(sys.stderr, format=custom_format, level=log_level.upper())
53 |
54 | if log_file is not None:
55 | logger.add(log_file, format=custom_format, level=log_file_level.upper())
56 | logger.info(f"Logging to {log_file}")
57 |
58 | g_configured = True
59 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to this project
2 |
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 |
6 | ## Pull Request Guidelines
7 |
8 | We actively welcome your pull requests.
9 |
10 | 1. Fork the repo and create your branch from `main`.
11 | 2. If you've added code that should be tested, add tests.
12 | 3. If you've changed APIs, update the documentation.
13 | 4. Ensure the test suite passes.
14 | 5. Make sure your code lints.
15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
16 |
17 | ### PR Description Format
18 |
19 | We use a standardized format for pull request descriptions to ensure
20 | consistency and clarity:
21 |
22 | 1. **Title**: Use a clear, concise title that summarizes the changes
23 | 2. **Key Changes**: List the most important updates
24 | 3. **Added**: Document new features or files
25 | 4. **Changed**: Highlight modifications to existing code
26 | 5. **Removed**: Note any deletions or removals
27 |
28 | Example:
29 |
30 | ```markdown
31 | ### Add device configuration automation
32 |
33 | **Key Changes:**
34 |
35 | - Implement dynamic device configuration
36 | - Add automated setup scripts
37 | - Update documentation
38 |
39 | **Added:**
40 |
41 | - New device setup module
42 | - Configuration templates
43 | - Setup guide
44 |
45 | **Changed:**
46 |
47 | - Refactored device initialization
48 | - Updated configuration format
49 | - Modified setup process
50 |
51 | **Removed:**
52 |
53 | - Legacy device configs
54 | - Deprecated setup scripts
55 | ```
56 |
57 | ## Contributor License Agreement ("CLA")
58 |
59 | In order to accept your pull request, we need you to submit a CLA. You only need
60 | to do this once to work on any of our open source projects.
61 |
62 | Complete your CLA here:
63 |
64 | ## Issues
65 |
66 | We use GitHub issues to track public bugs. Please ensure your description is
67 | clear and has sufficient instructions to be able to reproduce the issue.
68 |
69 | ## License
70 |
71 | By contributing to this project, you agree that your contributions will be licensed
72 | under the LICENSE file in the root directory of this source tree.
73 |
--------------------------------------------------------------------------------
/rigging/generator/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Generators produce completions for a given set of messages or text.
3 | """
4 |
5 | import typing as t
6 |
7 | from rigging.generator.base import (
8 | GeneratedMessage,
9 | GeneratedText,
10 | GenerateParams,
11 | Generator,
12 | StopReason,
13 | Usage,
14 | chat,
15 | complete,
16 | get_generator,
17 | get_identifier,
18 | register_generator,
19 | )
20 | from rigging.generator.http import HTTPGenerator, HttpHook, HttpHookAction, HTTPSpec
21 | from rigging.generator.litellm_ import LiteLLMGenerator
22 |
23 | register_generator("litellm", LiteLLMGenerator)
24 | register_generator("http", HTTPGenerator)
25 | register_generator(
26 | "base",
27 | Generator,
28 | ) # TODO: Helper while we sort out generators being required in so many places.
29 |
30 |
31 | def get_vllm_lazy() -> type[Generator]:
32 | try:
33 | from rigging.generator.vllm_ import VLLMGenerator
34 | except ImportError as e:
35 | raise ImportError(
36 | "VLLMGenerator is not available. Please install `vllm` or use `rigging[llm]`.",
37 | ) from e
38 |
39 | return VLLMGenerator
40 |
41 |
42 | register_generator("vllm", get_vllm_lazy)
43 |
44 |
45 | def get_transformers_lazy() -> type[Generator]:
46 | try:
47 | from rigging.generator.transformers_ import TransformersGenerator
48 | except ImportError as e:
49 | raise ImportError(
50 | "TransformersGenerator is not available. Please install `transformers` or use `rigging[llm]`.",
51 | ) from e
52 |
53 | return TransformersGenerator
54 |
55 |
56 | register_generator("transformers", get_transformers_lazy)
57 |
58 | __all__ = [
59 | "GenerateParams",
60 | "GeneratedMessage",
61 | "GeneratedText",
62 | "Generator",
63 | "HTTPGenerator",
64 | "HTTPGenerator",
65 | "HTTPSpec",
66 | "HttpHook",
67 | "HttpHookAction",
68 | "LiteLLMGenerator",
69 | "StopReason",
70 | "Usage",
71 | "chat",
72 | "complete",
73 | "get_generator",
74 | "get_generator",
75 | "get_identifier",
76 | "register_generator",
77 | ]
78 |
79 |
80 | def __getattr__(name: str) -> t.Any:
81 | if name == "VLLMGenerator":
82 | return get_vllm_lazy()
83 | if name == "TransformersGenerator":
84 | return get_transformers_lazy()
85 | raise AttributeError(f"module {__name__} has no attribute {name}")
86 |
--------------------------------------------------------------------------------
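A minimal end-to-end sketch tying the registry above to the identifiers used elsewhere in this repo (`examples/chat.py`, `examples/robopages.py`): `get_generator` resolves an identifier such as `gpt-4` through the registered providers, and the returned generator starts chat pipelines. The model name and prompt are illustrative.

```python
import asyncio

import rigging as rg


async def main() -> None:
    # "gpt-4" resolves through the default litellm provider registered above;
    # the "vllm" and "transformers" providers are loaded lazily on first use.
    generator = rg.get_generator("gpt-4")
    chat = await generator.chat("Say hello in one short sentence.").run()
    print(chat.conversation)


if __name__ == "__main__":
    asyncio.run(main())
```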
/docs/docs.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://mintlify.com/docs.json",
3 | "theme": "mint",
4 | "name": "Docs",
5 | "colors": {
6 | "primary": "#ea580c",
7 | "light": "#F47150",
8 | "dark": "#333333"
9 | },
10 | "background": {
11 | "color": {
12 | "light": "#e3e3e8",
13 | "dark": "#09090b"
14 | }
15 | },
16 | "navigation": {
17 | "groups": [
18 | {
19 | "group": "Getting Started",
20 | "pages": ["intro", "install"]
21 | },
22 | {
23 | "group": "Usage",
24 | "pages": [
25 | "topics/workflow",
26 | "topics/pipelines",
27 | "topics/prompt-functions",
28 | "topics/chats-and-messages",
29 | "topics/generators",
30 | "topics/data-models",
31 | "topics/tools",
32 | "topics/transforms",
33 | "topics/message-slicing",
34 | "topics/tokenization",
35 | "topics/iterating-and-batching",
36 | "topics/tracing",
37 | "topics/completions",
38 | "topics/migrations",
39 | "topics/serialization",
40 | "topics/logging"
41 | ]
42 | },
43 | {
44 | "group": "API",
45 | "pages": [
46 | "api/chat",
47 | "api/completion",
48 | "api/generator",
49 | "api/model",
50 | "api/message",
51 | "api/prompt",
52 | "api/tools",
53 | "api/transform",
54 | "api/tokenize",
55 | "api/data",
56 | "api/watchers",
57 | "api/parsing",
58 | "api/interact",
59 | "api/logging",
60 | "api/error",
61 | "api/util"
62 | ]
63 | }
64 | ]
65 | },
66 | "navbar": {
67 | "links": [
68 | {
69 | "label": "Home",
70 | "href": "https://docs.dreadnode.io"
71 | },
72 | {
73 | "label": "Support",
74 | "href": "mailto:support@dreadnode.io"
75 | },
76 | {
77 | "label": "Blog",
78 | "href": "https://dreadnode.io/blog"
79 | }
80 | ],
81 | "primary": {
82 | "type": "button",
83 | "label": "Platform",
84 | "href": "https://platform.dreadnode.io"
85 | }
86 | },
87 | "footer": {
88 | "socials": {
89 | "x": "https://x.com/dreadnode",
90 | "github": "https://github.com/dreadnode",
91 | "linkedin": "https://linkedin.com/company/dreadnode"
92 | }
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/.github/workflows/template-sync.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Template Sync
3 | on:
4 | # checkov:skip=CKV_GHA_7: "Workflow dispatch inputs are required for manual debugging and configuration"
5 | workflow_dispatch:
6 | inputs:
7 | dryRun:
8 | description: Dry Run
9 | default: "false"
10 | required: false
11 | logLevel:
12 | description: Log Level
13 | default: "debug"
14 | required: false
15 |
16 | schedule:
17 | # Run on the 1st of every month at 00:00 UTC
18 | - cron: "0 0 1 * *"
19 |
20 | push:
21 | branches: ["main"]
22 | paths:
23 | - ".github/**"
24 | - ".hooks/**"
25 | - ".pre-commit-config.yaml"
26 | - ".mdlrc"
27 | - ".editorconfig"
28 | - "Taskfile.yaml"
29 | - ".task/**"
30 |
31 | permissions:
32 | contents: write
33 | pull-requests: write
34 |
35 | concurrency:
36 | group: ${{ github.workflow }}-${{ github.run_number || github.ref }}
37 | cancel-in-progress: true
38 |
39 | jobs:
40 | template-sync:
41 | name: Template Sync
42 | runs-on: ubuntu-latest
43 | steps:
44 | - name: Generate Token
45 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1
46 | id: app-token
47 | with:
48 | app-id: "${{ secrets.BOT_APP_ID }}"
49 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}"
50 | owner: "${{ github.repository_owner }}"
51 |
52 | - name: Checkout
53 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
54 | with:
55 | token: "${{ steps.app-token.outputs.token }}"
56 |
57 | - name: Template Sync
58 | uses: AndreasAugustin/actions-template-sync@v2
59 | with:
60 | source_gh_token: ${{ steps.app-token.outputs.token }}
61 | git_user_name: github-actions[bot]
62 | git_user_email: github-actions[bot]@users.noreply.github.com
63 | pr_title: "chore: sync infrastructure files with template"
64 | pr_labels: sync,template
65 | pr_body: |
66 | 🤖 A new version of the python template files is available.
67 |
68 | This PR was automatically created to sync the following:
69 | - GitHub Actions workflows
70 | - Pre-commit hooks and configs
71 | - Task definitions
72 | - Editor configs and linter rules
73 |
74 | Please review the changes carefully before merging.
75 | source_repo_path: dreadnode/python-template
76 | steps: "prechecks,pull,commit,push,pr"
77 | upstream_branch: main
78 |
--------------------------------------------------------------------------------
/.hooks/generate_pr_description.py:
--------------------------------------------------------------------------------
1 | # /// script
2 | # requires-python = ">=3.10"
3 | # dependencies = [
4 | # "rigging",
5 | # "typer",
6 | # ]
7 | # ///
8 |
9 | import asyncio
10 | import subprocess
11 | import typing as t
12 |
13 | import typer
14 |
15 | import rigging as rg
16 |
17 | TRUNCATION_WARNING = "\n---\n**Note**: Due to the large size of this diff, some content has been truncated."
18 |
19 |
20 | @rg.prompt
21 | def generate_pr_description(diff: str) -> t.Annotated[str, rg.Ctx("markdown")]: # type: ignore[empty-body]
22 | """
23 | Analyze the provided git diff and create a PR description in markdown format.
24 |
25 |
26 | - Keep the summary concise and informative.
27 | - Use bullet points to structure important statements.
28 | - Focus on key modifications and potential impact - if any.
29 | - Do not add in general advice or best-practice information.
30 | - Write like a developer who authored the changes.
31 | - Prefer flat bullet lists over nested.
32 | - Do not include any title structure.
33 | - If there are no changes, just provide "No relevant changes."
34 | - Order your bullet points by importance.
35 |
36 | """
37 |
38 |
39 | def get_diff(base_ref: str, source_ref: str, *, exclude: list[str] | None = None) -> str:
40 | """
41 | Get the git diff between two branches.
42 | """
43 |
44 | merge_base = subprocess.run(
45 | ["git", "merge-base", source_ref, base_ref],
46 | capture_output=True,
47 | text=True,
48 | check=True,
49 | ).stdout.strip()
50 |
51 | diff_command = ["git", "diff", "--no-color", merge_base, source_ref]
52 | if exclude:
53 | diff_command.extend(["--", ".", *[f":(exclude){path}" for path in exclude]])
54 |
55 | diff_text = subprocess.run(
56 | diff_command,
57 | capture_output=True,
58 | text=True,
59 | check=True,
60 | ).stdout
61 |
62 | return diff_text
63 |
64 |
65 | def main(
66 | base_ref: str = "origin/main",
67 | source_ref: str = "HEAD",
68 | generator_id: str = "openai/o3-mini",
69 | max_diff_lines: int = 10_000,
70 | exclude: list[str] | None = None,
71 | ) -> None:
72 | """
73 | Use rigging to generate a PR description from a git diff.
74 | """
75 |
76 | diff = get_diff(base_ref, source_ref, exclude=exclude)
77 | diff_lines = diff.split("\n")
78 | if len(diff_lines) > max_diff_lines:
79 | diff = "\n".join(diff_lines[:max_diff_lines]) + TRUNCATION_WARNING
80 |
81 | description = asyncio.run(generate_pr_description.bind(generator_id)(diff))
82 |
83 | print(description)
84 |
85 |
86 | if __name__ == "__main__":
87 | typer.run(main)
88 |
--------------------------------------------------------------------------------
/docs/api/logging.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: rigging.logging
3 | ---
4 |
5 | {/*
6 | ::: rigging.logging
7 | */}
8 |
9 | We use loguru for logging. This module provides a function to configure logging handlers.
10 |
11 | To just enable rigging logs to flow, call `logger.enable("rigging")` after importing the module.
12 |
13 | LogLevelLiteral
14 | ---------------
15 |
16 | ```python
17 | LogLevelLiteral = Literal[
18 | "trace",
19 | "debug",
20 | "info",
21 | "success",
22 | "warning",
23 | "error",
24 | "critical",
25 | ]
26 | ```
27 |
28 | Valid logging levels.
29 |
30 | configure\_logging
31 | ------------------
32 |
33 | ```python
34 | configure_logging(
35 | log_level: LogLevelLiteral = "info",
36 | log_file: Path | None = None,
37 | log_file_level: LogLevelLiteral = "debug",
38 | ) -> None
39 | ```
40 |
41 | Configures common loguru handlers.
42 |
43 | **Parameters:**
44 |
45 | * **`log_level`**
46 | (`LogLevelLiteral`, default:
47 | `'info'`
48 | )
49 | –The desired log level.
50 | * **`log_file`**
51 | (`Path | None`, default:
52 | `None`
53 | )
54 | –The path to the log file. If None, logging
55 | will only be done to the console.
56 | * **`log_file_level`**
57 | (`LogLevelLiteral`, default:
58 | `'debug'`
59 | )
60 | –The log level for the log file.
61 |
62 |
63 | ```python
64 | def configure_logging(
65 | log_level: LogLevelLiteral = "info",
66 | log_file: pathlib.Path | None = None,
67 | log_file_level: LogLevelLiteral = "debug",
68 | ) -> None:
69 | """
70 | Configures common loguru handlers.
71 |
72 | Args:
73 | log_level: The desired log level.
74 | log_file: The path to the log file. If None, logging
75 | will only be done to the console.
76 | log_file_level: The log level for the log file.
77 | """
78 | global g_configured # noqa: PLW0603
79 |
80 | if g_configured:
81 | return
82 |
83 | logger.enable("rigging")
84 |
85 | logger.level("TRACE", color="", icon="[T]")
86 | logger.level("DEBUG", color="", icon="[_]")
87 | logger.level("INFO", color="", icon="[=]")
88 | logger.level("SUCCESS", color="", icon="[+]")
89 | logger.level("WARNING", color="", icon="[-]")
90 | logger.level("ERROR", color="", icon="[!]")
91 | logger.level("CRITICAL", color="", icon="[x]")
92 |
93 | custom_format = "{time:HH:mm:ss.SSS} | {level.icon} {message}"
94 |
95 | logger.remove()
96 | logger.add(sys.stderr, format=custom_format, level=log_level.upper())
97 |
98 | if log_file is not None:
99 | logger.add(log_file, format=custom_format, level=log_file_level.upper())
100 | logger.info(f"Logging to {log_file}")
101 |
102 | g_configured = True
103 | ```
104 |
105 |
106 |
--------------------------------------------------------------------------------
/tests/generators.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from rigging import Message
4 | from rigging.generator import GenerateParams, Generator
5 | from rigging.generator.base import GeneratedMessage, GeneratedText
6 |
7 | # ruff: noqa: S101, ARG002
8 |
9 |
10 | class FixedGenerator(Generator):
11 | text: str
12 |
13 | async def generate_messages(
14 | self,
15 | messages: t.Sequence[t.Sequence[Message]],
16 | params: t.Sequence[GenerateParams],
17 | ) -> t.Sequence[GeneratedMessage]:
18 | return [GeneratedMessage.from_text(self.text, stop_reason="stop") for _ in messages]
19 |
20 | async def generate_texts(
21 | self,
22 | texts: t.Sequence[str],
23 | params: t.Sequence[GenerateParams],
24 | ) -> t.Sequence[GeneratedText]:
25 | return [GeneratedText.from_text(self.text, stop_reason="stop") for _ in texts]
26 |
27 |
28 | class EchoGenerator(Generator):
29 | async def generate_messages(
30 | self,
31 | messages: t.Sequence[t.Sequence[Message]],
32 | params: t.Sequence[GenerateParams],
33 | ) -> t.Sequence[GeneratedMessage]:
34 | return [GeneratedMessage.from_text(m[-1].content, stop_reason="stop") for m in messages]
35 |
36 | async def generate_texts(
37 | self,
38 | texts: t.Sequence[str],
39 | params: t.Sequence[GenerateParams],
40 | ) -> t.Sequence[GeneratedText]:
41 | return [GeneratedText.from_text(t, stop_reason="stop") for t in texts]
42 |
43 |
44 | class CallbackGenerator(Generator):
45 | message_callback: t.Callable[["CallbackGenerator", t.Sequence[Message]], str] | None = None
46 | text_callback: t.Callable[["CallbackGenerator", str], str] | None = None
47 |
48 | async def generate_messages(
49 | self,
50 | messages: t.Sequence[t.Sequence[Message]],
51 | params: t.Sequence[GenerateParams],
52 | ) -> t.Sequence[GeneratedMessage]:
53 | assert self.message_callback is not None
54 | return [
55 | GeneratedMessage.from_text(self.message_callback(self, m), stop_reason="stop")
56 | for m in messages
57 | ]
58 |
59 | async def generate_texts(
60 | self,
61 | texts: t.Sequence[str],
62 | params: t.Sequence[GenerateParams],
63 | ) -> t.Sequence[GeneratedText]:
64 | assert len(texts) == 1
65 | assert self.text_callback is not None
66 | return [
67 | GeneratedText.from_text(self.text_callback(self, text), stop_reason="stop")
68 | for text in texts
69 | ]
70 |
71 |
72 | class FailingGenerator(Generator):
73 | _exception: Exception = RuntimeError("Intentional failure")
74 |
75 | async def generate_messages(
76 | self,
77 | messages: t.Sequence[t.Sequence[Message]],
78 | params: t.Sequence[GenerateParams],
79 | ) -> t.Sequence[GeneratedMessage]:
80 | raise self._exception
81 |
82 | async def generate_texts(
83 | self,
84 | texts: t.Sequence[str],
85 | params: t.Sequence[GenerateParams],
86 | ) -> t.Sequence[GeneratedText]:
87 | raise self._exception
88 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | repos:
3 | - repo: https://github.com/pre-commit/pre-commit-hooks
4 | rev: v6.0.0
5 | hooks:
6 | - id: check-added-large-files
7 | args: [--maxkb=10240]
8 | - id: check-case-conflict
9 | - id: check-merge-conflict
10 | - id: check-executables-have-shebangs
11 | - id: check-json
12 | - id: check-shebang-scripts-are-executable
13 | - id: check-symlinks
14 | - id: check-yaml
15 | - id: detect-private-key
16 | - id: end-of-file-fixer
17 | exclude: ^docs/
18 | - id: trailing-whitespace
19 |
20 | - repo: https://github.com/rhysd/actionlint
21 | rev: v1.7.9
22 | hooks:
23 | - id: actionlint
24 |
25 | - repo: https://github.com/adrienverge/yamllint.git
26 | rev: v1.37.1
27 | hooks:
28 | - id: yamllint
29 | entry: yamllint --strict -c .hooks/linters/yamllint.yaml
30 |
31 | - repo: https://github.com/codespell-project/codespell
32 | rev: v2.4.1
33 | hooks:
34 | - id: codespell
35 | entry: codespell -q 3 -f --skip=".git,.github" -L te,Sie,braket,astroid
36 |
37 | # Python code security
38 | - repo: https://github.com/PyCQA/bandit
39 | rev: 1.9.2
40 | hooks:
41 | - id: bandit
42 | name: Code security checks
43 | args: ["-c", "pyproject.toml"]
44 | additional_dependencies: ["bandit[toml]"]
45 | exclude: ^tests/
46 |
47 | - repo: https://github.com/Yelp/detect-secrets
48 | rev: v1.5.0
49 | hooks:
50 | - id: detect-secrets
51 | args: ["--baseline", ".secrets.baseline", "--exclude-files", "examples/*"]
52 | exclude: .secrets.baseline
53 |
54 | # Clean jupyter notebook outputs
55 | - repo: https://github.com/kynan/nbstripout
56 | rev: 0.8.2
57 | hooks:
58 | - id: nbstripout
59 | args: [--keep-id]
60 |
61 | # - repo: https://github.com/astral-sh/ruff-pre-commit
62 | # rev: v0.11.7
63 | # hooks:
64 | # - id: ruff
65 | # args: [--fix]
66 | # - id: ruff-format
67 |
68 | # - repo: https://github.com/pre-commit/mirrors-mypy
69 | # rev: v1.15.0
70 | # hooks:
71 | # - id: mypy
72 | # additional_dependencies:
73 | # - "pydantic"
74 | # - "types-PyYAML"
75 | # - "types-requests"
76 | # - "types-setuptools"
77 |
78 | - repo: local
79 | hooks:
80 | # Ensure our GH actions are pinned to a specific hash
81 | - id: check-github-actions
82 | name: Check GitHub Actions for Pinned Dependencies
83 | entry: .hooks/check_pinned_hash_dependencies.py
84 | language: python
85 | files: \.github/.*\.yml$
86 |
87 | - id: prettier
88 | name: Run prettier
89 | entry: .hooks/prettier.sh
90 | language: script
91 | types: [json, yaml]
92 |
93 | # Generate documentation
94 | - id: generate-docs
95 | name: Generate docs
96 | entry: poetry run python .hooks/generate_docs.py
97 | language: system
98 | pass_filenames: false
99 | always_run: true
100 |
--------------------------------------------------------------------------------
/rigging/__init__.py:
--------------------------------------------------------------------------------
1 | from rigging import (
2 | caching,
3 | data,
4 | error,
5 | generator,
6 | logging,
7 | model,
8 | parsing,
9 | tokenizer,
10 | tools,
11 | transform,
12 | watchers,
13 | )
14 | from rigging.chat import (
15 | Chat,
16 | ChatPipeline,
17 | MapChatCallback,
18 | PipelineStep,
19 | PipelineStepContextManager,
20 | PipelineStepGenerator,
21 | ThenChatCallback,
22 | )
23 | from rigging.completion import (
24 | Completion,
25 | CompletionPipeline,
26 | MapCompletionCallback,
27 | ThenCompletionCallback,
28 | )
29 | from rigging.error import Stop
30 | from rigging.generator import (
31 | GeneratedMessage,
32 | GeneratedText,
33 | GenerateParams,
34 | Generator,
35 | chat,
36 | complete,
37 | get_generator,
38 | register_generator,
39 | )
40 | from rigging.interact import interact
41 | from rigging.message import (
42 | ContentAudioInput,
43 | ContentImageUrl,
44 | ContentText,
45 | Message,
46 | MessageDict,
47 | Messages,
48 | MessageSlice,
49 | )
50 | from rigging.model import Model, attr, element, wrapped
51 | from rigging.prompt import Ctx, Prompt, prompt
52 | from rigging.tokenizer import TokenizedChat, Tokenizer, get_tokenizer, register_tokenizer
53 | from rigging.tools import Tool, as_mcp, mcp, robopages, tool, tool_method
54 | from rigging.transform import PostTransform, Transform
55 | from rigging.util import await_
56 | from rigging.version import VERSION
57 |
58 | __version__ = VERSION
59 |
60 | __all__ = [
61 | "Chat",
62 | "ChatFormatter",
63 | "ChatPipeline",
64 | "Completion",
65 | "CompletionPipeline",
66 | "ContentAudioInput",
67 | "ContentImageUrl",
68 | "ContentText",
69 | "Ctx",
70 | "Decoder",
71 | "Encoder",
72 | "GenerateParams",
73 | "GeneratedMessage",
74 | "GeneratedText",
75 | "Generator",
76 | "MapChatCallback",
77 | "MapCompletionCallback",
78 | "Message",
79 | "MessageDict",
80 | "MessageSlice",
81 | "Messages",
82 | "Model",
83 | "PipelineStep",
84 | "PipelineStepContextManager",
85 | "PipelineStepGenerator",
86 | "PostTransform",
87 | "Prompt",
88 | "Stop",
89 | "ThenChatCallback",
90 | "ThenCompletionCallback",
91 | "TokenSlice",
92 | "TokenizedChat",
93 | "Tokenizer",
94 | "Tool",
95 | "Transform",
96 | "as_mcp",
97 | "attr",
98 | "await_",
99 | "caching",
100 | "chat",
101 | "complete",
102 | "data",
103 | "element",
104 | "error",
105 | "find_in_tokens",
106 | "generator",
107 | "get_generator",
108 | "get_tokenizer",
109 | "interact",
110 | "logging",
111 | "mcp",
112 | "model",
113 | "parsing",
114 | "prompt",
115 | "register_generator",
116 | "register_tokenizer",
117 | "robopages",
118 | "tokenizer",
119 | "tool",
120 | "tool_method",
121 | "tools",
122 | "transform",
123 | "watchers",
124 | "wrapped",
125 | ]
126 |
127 | from loguru import logger
128 |
129 | logger.disable("rigging")
130 |
--------------------------------------------------------------------------------
/.github/workflows/renovate.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Renovate
3 | on:
4 | # checkov:skip=CKV_GHA_7: "Workflow dispatch inputs are required for manual debugging and configuration"
5 | workflow_dispatch:
6 | inputs:
7 | dryRun:
8 | description: Dry Run
9 | default: "false"
10 | required: false
11 | logLevel:
12 | description: Log Level
13 | default: "debug"
14 | required: false
15 | version:
16 | description: Renovate version
17 | default: latest
18 | required: false
19 | schedule:
20 | # Run every evening at 20:00 UTC (8:00 PM UTC)
21 | - cron: "0 20 * * *"
22 | push:
23 | branches: ["main"]
24 | paths:
25 | - .github/renovate.json5
26 | - .github/renovate/**.json5
27 |
28 | permissions:
29 | contents: read
30 | pull-requests: write
31 | issues: write
32 |
33 | concurrency:
34 | group: ${{ github.workflow }}-${{ github.run_number || github.ref }}
35 | cancel-in-progress: true
36 |
37 | # Retrieve BOT_USER_ID via `curl -s "https://api.github.com/users/${BOT_USERNAME}%5Bbot%5D" | jq .id`
38 | env:
39 | WORKFLOW_DRY_RUN: false
40 | WORKFLOW_LOG_LEVEL: debug
41 | WORKFLOW_VERSION: latest # 37.59.8
42 | RENOVATE_PLATFORM: github
43 | RENOVATE_PLATFORM_COMMIT: true
44 | RENOVATE_ONBOARDING_CONFIG_FILE_NAME: .github/renovate.json5
45 | RENOVATE_AUTODISCOVER: true
46 | RENOVATE_AUTODISCOVER_FILTER: "${{ github.repository }}"
47 | RENOVATE_GIT_AUTHOR: "${{ secrets.BOT_USERNAME }} <${{ secrets.BOT_USER_ID }}+${{ secrets.BOT_USERNAME }}[bot]@users.noreply.github.com>"
48 |
49 | jobs:
50 | renovate:
51 | name: Renovate
52 | runs-on: ubuntu-latest
53 | steps:
54 | - name: Generate Token
55 | uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1
56 | id: app-token
57 | with:
58 | app-id: "${{ secrets.BOT_APP_ID }}"
59 | private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}"
60 |
61 | - name: Checkout
62 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
63 | with:
64 | token: "${{ steps.app-token.outputs.token }}"
65 |
66 | - name: Override default config from dispatch variables
67 | run: |
68 | echo "RENOVATE_DRY_RUN=${{ github.event.inputs.dryRun || env.WORKFLOW_DRY_RUN }}" >> "${GITHUB_ENV}"
69 | echo "LOG_LEVEL=${{ github.event.inputs.logLevel || env.WORKFLOW_LOG_LEVEL }}" >> "${GITHUB_ENV}"
70 |
71 | - name: Delete old dashboard
72 | run: |
73 | ISSUE_NUMBER=$(gh issue list -S 'Renovate Dashboard 🤖' --json number -q '.[0].number')
74 | if [ "$ISSUE_NUMBER" != "null" ] && [ -n "$ISSUE_NUMBER" ]; then
75 | gh issue close "$ISSUE_NUMBER"
76 | else
77 | echo "No issue found to close."
78 | fi
79 | env:
80 | GITHUB_TOKEN: "${{ steps.app-token.outputs.token }}"
81 |
82 | - name: Renovate
83 | uses: renovatebot/github-action@822441559e94f98b67b82d97ab89fe3003b0a247 # v44.2.0
84 | with:
85 | configurationFile: "${{ env.RENOVATE_ONBOARDING_CONFIG_FILE_NAME }}"
86 | token: "${{ steps.app-token.outputs.token }}"
87 | renovate-version: "${{ github.event.inputs.version || env.WORKFLOW_VERSION }}"
88 |
--------------------------------------------------------------------------------
/rigging/tools/robopages.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for integrating tools from a Robopages server.
3 | """
4 |
5 | import re
6 | import typing as t
7 |
8 | import httpx
9 | import requests
10 | from loguru import logger
11 | from pydantic import TypeAdapter
12 |
13 | from rigging.tools.base import Tool, ToolDefinition
14 |
15 | DEFAULT_HTTP_TIMEOUT = 10
16 |
17 |
18 | def make_execute_on_server(url: str, tool_name: str) -> t.Callable[..., t.Any]:
19 | async def execute_on_server(**kwargs: t.Any) -> t.Any:
20 | async with httpx.AsyncClient() as client:
21 | response = await client.post(
22 | f"{url}/process",
23 | json=[
24 | {
25 | "type": "function",
26 | "function": {
27 | "name": tool_name,
28 | "arguments": kwargs,
29 | },
30 | },
31 | ],
32 | )
33 | if response.status_code not in [200, 400]:
34 | response.raise_for_status()
35 |
36 | if response.status_code == 400: # noqa: PLR2004
37 | result = response.content.decode()
38 | else:
39 | result = response.json()[0]["content"]
40 |
41 | return result
42 |
43 | return execute_on_server
44 |
45 |
46 | def robopages(url: str, *, name_filter: str | None = None) -> list[Tool[..., t.Any]]:
47 | """
48 | Create a list of tools from a Robopages server.
49 |
50 | Args:
51 | url: The URL of the Robopages server.
52 | name_filter: A regular expression to filter the tools by name.
53 |
54 | Returns:
55 | A list of integrated tools which leverage the Robopages server.
56 |
57 | Example:
58 | ```
59 | import rigging as rg
60 |
61 |         tools = rg.robopages("http://localhost:8080")
62 |
63 | chat = (
64 | await rg.get_generator('gpt-4o')
65 | .chat('Please use tools')
66 | .using(*tools)
67 | .run()
68 | )
69 |
70 | print(chat.conversation)
71 | ```
72 | """
73 |
74 | filter_regex = re.compile(name_filter) if name_filter else None
75 |
76 | response = requests.get(url, params={"flavor": "openai"}, timeout=DEFAULT_HTTP_TIMEOUT)
77 | response.raise_for_status()
78 | tools_data = response.json()
79 |
80 | adapter = TypeAdapter(list[ToolDefinition])
81 | tool_definitions = adapter.validate_python(tools_data)
82 |
83 | logger.info(f"Fetched {len(tool_definitions)} functions from Robopages ({url})")
84 |
85 | tools: list[Tool[..., t.Any]] = []
86 | for definition in tool_definitions:
87 | function = definition.function
88 |
89 | if filter_regex and not filter_regex.search(function.name):
90 | logger.debug(f"Skipping function {function.name}")
91 | continue
92 |
93 | tools.append(
94 | Tool(
95 | name=function.name,
96 | description=function.description or "",
97 | parameters_schema=function.parameters or {},
98 | fn=make_execute_on_server(url, function.name),
99 | ),
100 | )
101 |
102 | return tools
103 |
--------------------------------------------------------------------------------
/tests/test_completion_pipeline.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from rigging.completion import Completion, CompletionPipeline
4 | from rigging.generator import GenerateParams, get_generator
5 |
6 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001
7 |
8 |
9 | def test_completion_generator_id() -> None:
10 | generator = get_generator("gpt-3.5")
11 | completion = Completion("foo", "bar", generator)
12 | assert completion.generator_id == "gpt-3.5"
13 |
14 | completion.generator = None
15 | assert completion.generator_id is None
16 |
17 |
18 | def test_completion_properties() -> None:
19 | generator = get_generator("gpt-3.5")
20 | completion = Completion("foo", "bar", generator)
21 | assert completion.text == "foo"
22 | assert completion.generated == "bar"
23 | assert completion.generator == generator
24 | assert len(completion) == len("foo") + len("bar")
25 | assert completion.all == "foobar"
26 |
27 |
28 | def test_completion_restart() -> None:
29 | generator = get_generator("gpt-3.5")
30 | completion = Completion("foo", "bar", generator)
31 | assert len(completion.restart()) == 3
32 | assert len(completion.restart(include_all=True)) == 6
33 |
34 | assert len(completion.fork("baz")) == 6
35 | assert len(completion.continue_("baz")) == 9
36 |
37 | completion.generator = None
38 | with pytest.raises(ValueError):
39 | completion.restart()
40 |
41 |
42 | def test_completion_clone() -> None:
43 | generator = get_generator("gpt-3.5")
44 | original = Completion("foo", "bar", generator).meta(key="value")
45 | clone = original.clone()
46 | assert clone.text == original.text
47 | assert clone.generated == original.generated
48 | assert clone.metadata == original.metadata
49 |
50 | clone_2 = original.clone(only_messages=True)
51 | assert clone.metadata != clone_2.metadata
52 |
53 |
54 | def test_completion_pipeline_with() -> None:
55 | pipeline = CompletionPipeline(get_generator("gpt-3.5"), "foo")
56 | with_pipeline = pipeline.with_(GenerateParams(max_tokens=123))
57 | assert with_pipeline == pipeline
58 | assert with_pipeline.params is not None
59 | assert with_pipeline.params.max_tokens == 123
60 |
61 | with_pipeline_2 = with_pipeline.with_(top_p=0.5)
62 | assert with_pipeline_2 != with_pipeline
63 | assert with_pipeline_2.params is not None
64 | assert with_pipeline_2.params.max_tokens == 123
65 | assert with_pipeline_2.params.top_p == 0.5
66 |
67 |
68 | def test_completion_pipeline_fork() -> None:
69 | pipeline = CompletionPipeline(get_generator("gpt-3.5"), "foo")
70 | forked_1 = pipeline.fork("bar")
71 | forked_2 = pipeline.fork("baz")
72 |
73 | assert pipeline != forked_1 != forked_2
74 | assert pipeline.text == "foo"
75 | assert forked_1.text == "foobar"
76 | assert forked_2.text == "foobaz"
77 |
78 |
79 | def test_completion_pipeline_meta() -> None:
80 | pipeline = CompletionPipeline(get_generator("gpt-3.5"), "foo")
81 | with_meta = pipeline.meta(key="value")
82 | assert with_meta == pipeline
83 | assert with_meta.metadata == {"key": "value"}
84 |
85 |
86 | def test_completion_pipeline_apply() -> None:
87 | pipeline = CompletionPipeline(get_generator("gpt-3.5"), "Hello $name")
88 | applied = pipeline.apply(name="World", noexist="123")
89 | assert pipeline != applied
90 | assert pipeline.text == "Hello $name"
91 | assert applied.text == "Hello World"
92 |
--------------------------------------------------------------------------------
/docs/topics/completions.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Completions"
3 | description: "Using Rigging for text completions"
4 | public: true
5 | ---
6 |
7 | The majority of Rigging was built around "instruct" or "chat" LLM interfaces where a base model has been tuned to work with a structured layer on top of raw text completion. We typically find that base models are more unpredictable with their outputs, tend to be more sensitive to small changes in their context windows, and require frequent use of stop tokens to prevent unnecessary generation.
8 |
9 | However, there are some places where completing raw text and working with base models might be desirable:
10 |
11 | - Fewer restrictions on the types of content they will generate
12 | - Speeding up generation and lowering token usage by discouraging verbose responses
13 | - Leveraging prompts from popular libraries like [LangChain](https://python.langchain.com/) which assume
14 | a completions-style interface
15 |
16 | ## Interface Parity
17 |
18 | While we try to maintain parity between the "Chat" and "Completions" interfaces in Rigging, you'll
19 | find some deviations here and there. The transition to completions should be simple if you are familiar
20 | with the rest of Rigging. Here are the highlights:
21 |
22 | - `chat` -> `complete`
23 | - `Chat` -> `Completion`
24 | - `ChatPipeline` -> `CompletionPipeline`
25 | - `generate_messages` -> `generate_texts`
26 |
27 | On all of these interfaces, you'll note that sequences of `Message` objects have been
28 | replaced with basic `str` objects for both inputs and outputs.
29 |
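For a side-by-side feel, here is a minimal sketch of how the two interfaces mirror each other (the model id and prompt text are purely illustrative):

```python
import rigging as rg

generator = rg.get_generator("gpt-4o")

# Chat interface: message-style input in, a Chat object back.
chat = await generator.chat("Say hello.").run()
print(chat.last.content)

# Completion interface: plain strings in, a Completion object back.
completion = await generator.complete("Say hello.").run()
print(completion.generated)
```
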
30 | ## Translator Example
31 |
32 | Let's build a simple translator object that we can store as a `CompletionPipeline`
33 | and use to quickly translate a phrase into 3 different languages.
34 |
35 | ```python {13, 15}
36 | PROMPT = """\
37 | As an expert translator, you accept english text and translate it to $language.
38 |
39 | # Format
40 |
41 | Input: [english text]
42 | Output: [translated text]
43 | ---
44 |
45 | Input: $input
46 | Output: """
47 |
48 | translator = (
49 | rg.get_generator('gpt-3.5-turbo')
50 | .complete(PROMPT)
51 | .with_(stop=["---", "Input:", "\n\n"])
52 | )
53 |
54 | text = "Could you please tell me where the nearest train station is?"
55 |
56 | for language in ["spanish", "french", "german"]:
57 | completion = await translator.apply(
58 | language=language,
59 | input=text
60 | ).run()
61 | print(f"[{language}]: {completion.generated}")
62 |
63 | # [spanish]: ¿Podría decirme por favor dónde está la estación de tren más cercana?
64 | # [french]: Pouvez-vous me dire où se trouve la gare la plus proche, s'il vous plaît ?
65 | # [german]: Könnten Sie mir bitte sagen, wo sich der nächste Bahnhof befindet?
66 | ```
67 |
68 | 1. OpenAI supports the same model IDs for both completions and chats, but other providers might require you to specify a model ID dedicated to text completions.
69 | 2. We use `.with_()` to set stop tokens and prevent the generation from simply continuing until our max tokens are reached. This is a very common and often required pattern when doing completions over chats. Here, we aren't totally sure what the model might generate after our translation, so we use a few different token sequences to be safe.
70 |
71 |
72 | **Using .apply()**
73 |
74 | Text completion is a great place to use the `.apply` method, as we can easily slot in our inputs without having to use `.add` and then append the output section of the prompt ourselves.
75 |
76 |
--------------------------------------------------------------------------------
/rigging/tokenizer/transformers_.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from pydantic import Field
4 | from transformers import AutoTokenizer # type: ignore [import-not-found, unused-ignore]
5 |
6 | from rigging.tokenizer.base import Tokenizer
7 |
8 | if t.TYPE_CHECKING:
9 | from transformers.tokenization_utils import ( # type: ignore [import-not-found, unused-ignore]
10 | PreTrainedTokenizer,
11 | )
12 |
13 | from rigging.chat import Chat
14 |
15 |
16 | class TransformersTokenizer(Tokenizer):
17 | """
18 | A tokenizer implementation using Hugging Face Transformers.
19 |
20 | This class provides tokenization capabilities for chat conversations
21 | using transformers models and their associated tokenizers.
22 | """
23 |
24 | apply_chat_template_kwargs: dict[str, t.Any] = Field(default_factory=dict)
25 | """Additional keyword arguments for applying the chat template."""
26 |
27 | encode_kwargs: dict[str, t.Any] = Field(default_factory=dict)
28 | """Additional keyword arguments for encoding text."""
29 |
30 | decode_kwargs: dict[str, t.Any] = Field(default_factory=dict)
31 | """Additional keyword arguments for decoding tokens."""
32 |
33 | _tokenizer: "PreTrainedTokenizer | None" = None
34 |
35 | @property
36 | def tokenizer(self) -> "PreTrainedTokenizer":
37 | """The underlying `PreTrainedTokenizer` instance."""
38 | if self._tokenizer is None:
39 | self._tokenizer = AutoTokenizer.from_pretrained(self.model) # type: ignore[no-untyped-call] # nosec
40 | return self._tokenizer
41 |
42 | @classmethod
43 | def from_obj(cls, tokenizer: "PreTrainedTokenizer") -> "TransformersTokenizer":
44 | """
45 | Create a new instance of TransformersTokenizer from an already loaded tokenizer.
46 |
47 | Args:
48 | tokenizer: The tokenizer associated with the model.
49 |
50 | Returns:
51 | The TransformersTokenizer instance.
52 | """
53 | return cls(model=str(tokenizer), _tokenizer=tokenizer)
54 |
55 | def encode(self, text: str) -> list[int]:
56 | """
57 | Encodes the given text into a list of tokens.
58 |
59 | Args:
60 | text: The text to encode.
61 |
62 | Returns:
63 | A list of tokens representing the encoded text.
64 | """
65 | return self.tokenizer.encode(text, **self.encode_kwargs) # type: ignore [no-any-return]
66 |
67 | def decode(self, tokens: list[int]) -> str:
68 | decode_kwargs = {
69 | "clean_up_tokenization_spaces": False,
70 | **self.decode_kwargs,
71 | }
72 | return self.tokenizer.decode(tokens, **decode_kwargs) # type: ignore [no-any-return, unused-ignore]
73 |
74 | def format_chat(self, chat: "Chat") -> str:
75 | messages = [m.to_openai(compatibility_flags={"content_as_str"}) for m in chat.all]
76 | tools = (
77 | [tool.model_dump() for tool in chat.params.tools]
78 | if chat.params and chat.params.tools
79 | else None
80 | )
81 |
82 | apply_chat_template_kwargs = {
83 | "tokenize": False,
84 | **self.apply_chat_template_kwargs,
85 | }
86 |
87 | return str(
88 | self.tokenizer.apply_chat_template(
89 | messages,
90 | tools=tools, # type: ignore [arg-type, unused-ignore]
91 | **apply_chat_template_kwargs,
92 | ),
93 | )
94 |
--------------------------------------------------------------------------------
/.github/labels.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Area Labels
3 | - name: area/docs
4 | color: "72CCF3" # Light Blue
5 | description: >-
6 | Changes to documentation and guides
7 |
8 | - name: area/examples
9 | color: "BC9BE3" # Lavender
10 | description: >-
11 | Changes to example code and demonstrations
12 |
13 | - name: area/github
14 | color: "F4D1B7" # Peach
15 | description: >-
16 | Changes made to GitHub Actions
17 |
18 | - name: area/pre-commit
19 | color: "84B6EB" # Steel Blue
20 | description: >-
21 | Changes made to pre-commit hooks
22 |
23 | - name: area/python
24 | color: "7BD7E0" # Turquoise
25 | description: >-
26 | Changes to Python package configuration and dependencies
27 |
28 | - name: area/security
29 | color: "FF6600" # Orange
30 | description: >-
31 | Changes to security policies and configurations
32 |
33 | - name: area/taskfiles
34 | color: "66CCFF" # Sky Blue
35 | description: >-
36 | Changes made to Taskfiles
37 |
38 | - name: area/tests
39 | color: "99CC00" # Lime Green
40 | description: >-
41 | Changes to test files and testing infrastructure
42 |
43 | - name: area/workspace
44 | color: "FF99CC" # Pink
45 | description: >-
46 | Changes to VSCode workspace configuration
47 |
48 | - name: area/assets
49 | color: "FFA07A" # Light Salmon
50 | description: >-
51 | Changes to asset files
52 |
53 | - name: area/templates
54 | color: "DA70D6" # Orchid
55 | description: >-
56 | Changes to templates
57 |
58 | - name: area/scripts
59 | color: "40E0D0" # Turquoise
60 | description: >-
61 | Changes to script files
62 |
63 | - name: area/src
64 | color: "4682B4" # Steel Blue
65 | description: >-
66 | Changes to source code
67 |
68 | - name: area/ci
69 | color: "FF4500" # Orange Red
70 | description: >-
71 | Changes related to CI/CD configurations
72 |
73 | - name: area/shell
74 | color: "556B2F" # Dark Olive Green
75 | description: >-
76 | Changes to shell scripts
77 |
78 | - name: area/dev
79 | color: "CC6699" # Dusty Rose
80 | description: >-
81 | Changes to development tools and assets
82 |
83 | # Renovate Labels
84 | - name: renovate/container
85 | color: "9933CC" # Purple
86 | description: >-
87 | Docker container updates via Renovate
88 |
89 | - name: renovate/github-action
90 | color: "FF3366" # Hot Pink
91 | description: >-
92 | GitHub Action updates via Renovate
93 |
94 | - name: renovate/github-release
95 | color: "3399FF" # Bright Blue
96 | description: >-
97 | GitHub Release updates via Renovate
98 |
99 | # Semantic Type Labels
100 | - name: type/digest
101 | color: "FF66CC" # Bright Pink
102 | description: >-
103 | Dependency digest updates
104 |
105 | - name: type/patch
106 | color: "FFC300" # Golden Yellow
107 | description: >-
108 | Patch changes (fixes, updates)
109 |
110 | - name: type/minor
111 | color: "FFD700" # Gold
112 | description: >-
113 | Minor changes (features, improvements)
114 |
115 | - name: type/major
116 | color: "F6412D" # Red Orange
117 | description: >-
118 | Major changes
119 |
120 | - name: type/break
121 | color: "FF0000" # Bright Red
122 | description: >-
123 | Breaking changes
124 |
125 | # Documentation Labels
126 | - name: type/docs
127 | color: "0075CA" # Documentation Blue
128 | description: >-
129 | Documentation updates and improvements
130 |
131 | - name: type/core
132 | color: "A2EEEF" # Light Blue
133 | description: >-
134 | Changes to core repository files and configurations
135 |
--------------------------------------------------------------------------------
/tests/test_generator_ids.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from rigging.error import InvalidGeneratorError
4 | from rigging.generator import (
5 | GenerateParams,
6 | LiteLLMGenerator,
7 | get_generator,
8 | get_identifier,
9 | register_generator,
10 | )
11 |
12 | from .generators import EchoGenerator
13 |
14 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001
15 |
16 |
17 | @pytest.mark.parametrize("identifier", ["test_model", "litellm!test_model"])
18 | def test_get_generator_default_is_litellm(identifier: str) -> None:
19 | generator = get_generator(identifier)
20 | assert isinstance(generator, LiteLLMGenerator)
21 | assert generator.model == "test_model"
22 |
23 |
24 | @pytest.mark.parametrize("identifier", ["invalid!testing", "no_exist!stuff,args=123"])
25 | def test_get_generator_invalid_provider(identifier: str) -> None:
26 | with pytest.raises(InvalidGeneratorError):
27 | get_generator(identifier)
28 |
29 |
30 | @pytest.mark.parametrize(
31 | ("identifier", "valid_params"),
32 | [
33 | ("litellm!test_model,max_tokens=123,top_p=10", GenerateParams(max_tokens=123, top_p=10)),
34 | ("litellm!test_model,temperature=0.5", GenerateParams(temperature=0.5)),
35 | (
36 | "test_model,temperature=1.0,max_tokens=100",
37 | GenerateParams(max_tokens=100, temperature=1.0),
38 | ),
39 | ],
40 | )
41 | def test_get_generator_with_params(identifier: str, valid_params: GenerateParams) -> None:
42 | generator = get_generator(identifier)
43 | assert isinstance(generator, LiteLLMGenerator)
44 | assert generator.model == "test_model"
45 | assert generator.params == valid_params
46 |
47 |
48 | @pytest.mark.parametrize(
49 | "identifier",
50 | [
51 | ("test_model,max_tokens=1024,top_p=0.1"),
52 | ("custom,temperature=1.0,max_tokens=100,api_base=https://localhost:8000"),
53 | ("many/model/slashes,stop=a;b;c;"),
54 | ("http!with_cls_args"),
55 | ],
56 | )
57 | def test_identifier_roundtrip(identifier: str) -> None:
58 | generator = get_generator(identifier)
59 | assert generator.to_identifier() == identifier
60 |
61 |
62 | def test_get_identifier_no_extra() -> None:
63 | generator = get_generator("testing_model,temperature=0.5")
64 | generator.params.extra = {"abc": 123}
65 | identifier = get_identifier(generator)
66 | assert "extra" not in identifier
67 |
68 |
69 | @pytest.mark.parametrize(
70 | "identifier",
71 | ["litellm:invalid,stuff:test,t1/123", "bad:invalid,stuff:test,t1//;;123:"],
72 | )
73 | def test_get_generator_invalid_structure_format(identifier: str) -> None:
74 | with pytest.raises(InvalidGeneratorError):
75 | get_generator(identifier)
76 |
77 |
78 | @pytest.mark.parametrize(
79 | "identifier",
80 | ["litellm:model,bad_param=123,temperature=1.0", "litellm:model,temperature=True"],
81 | )
82 | def test_get_generator_invalid_params(identifier: str) -> None:
83 | with pytest.raises(InvalidGeneratorError):
84 | get_generator(identifier)
85 |
86 |
87 | def test_register_generator() -> None:
88 | with pytest.raises(InvalidGeneratorError):
89 | get_generator("echo!test")
90 |
91 | register_generator("echo", EchoGenerator)
92 | generator = get_generator("echo!test")
93 | assert isinstance(generator, EchoGenerator)
94 |
95 |
96 | def test_get_generator_b64() -> None:
97 | generator = get_generator("litellm!test_model,api_key=ZXhhbXBsZXRleHQ=")
98 | assert isinstance(generator, LiteLLMGenerator)
99 | assert generator.model == "test_model"
100 |
--------------------------------------------------------------------------------
/docs/topics/tracing.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Tracing"
3 | description: "Trace all rigging internal behaviors with OpenTelemetry"
4 | public: true
5 | ---
6 |
7 | Rigging integrates with the [Logfire](https://logfire.pydantic.dev/docs/) library to expose tracing information about execution. Specifically, we use the `logfire-api` no-op package, so tracing stays optional and adds no overhead if you don't need it.
8 |
9 | Logfire is capable of reporting trace information to any OpenTelemetry-compatible system, and provides convenient abstractions on top of the standard OpenTelemetry SDK. If the `logfire` package is installed and configured, details about pipelines, prompts, and tools will be traced when you use rigging.
10 |
11 | You can configure Logfire to use [alternative backends](https://logfire.pydantic.dev/docs/how-to-guides/alternative-backends/) as needed to integrate with your preferred tracing stack.
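
As a minimal sketch of what that can look like (assuming a local OpenTelemetry collector on the default OTLP HTTP port; the endpoint URL and flags here are illustrative, see the linked guide for the full options):

```python
import os

import logfire

# Standard OpenTelemetry environment variable pointing the exporter at a
# self-hosted collector instead of the Logfire platform.
os.environ.setdefault("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")

# Configure Logfire without sending data to the hosted platform.
logfire.configure(send_to_logfire=False)
```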
12 |
13 | ```python
14 | import rigging as rg
15 | import logfire
16 |
17 | logfire.configure()
18 |
19 | @rg.prompt(generator_id="gpt-4o")
20 | async def summarize(content: str) -> str:
21 | """
22 | Summarize the content into 1-2 sentences then save it
23 | """
24 |
25 | summarize.watch(rg.watchers.write_chats_to_jsonl("chats.jsonl"))
26 |
27 | text = """
28 | Revachol is located on the island of Le Caillou, also called "The Pebble" on the northeast side of the
29 | Insulindian Isola, on the world's largest body of water: the Insulindic. The city itself has a radius
30 | of 80 kilometres and is split by the River Esperance into Revachol East and Revachol West. The north
31 | side of the island is shattered by the delta of the Esperance, and is named La Delta.
32 | """
33 |
34 | await summarize.run_many(3, text)
35 |
36 | # 23:46:31.484 Prompt summarize() (x3)
37 | # 23:46:31.485 Chat with litellm!gpt-4o (x3)
38 | # 23:46:32.874 Watch with rigging.watchers.write_chats_to_jsonl()
39 | ```
40 |
41 |
42 | Rigging will attach call parameters and results for both tools and prompt functions, as well as finalized chat objects at the end of a pipeline. Logfire will serialize these items as JSON values inside attributes, and include a dynamic JSON schema for reference. When using their platform, these items deserialize directly into the web view.
43 |
44 |
45 | 
46 |
47 | ## Inference Tracing
48 |
49 | We've opted to exclude tracing at the generator level in Rigging (for now) and focus on instrumenting higher-order abstractions like pipelines and tools, which are specific to the framework.
50 |
51 | There is a suite of powerful instrumentation libraries (including Logfire) which will add tracing to underlying libraries like LiteLLM (recommended), OpenAI, Anthropic, VertexAI, and others. These snap right into the tracing spans from rigging, and provide insight into the raw inference traffic before it reaches API endpoints and inference libraries.
52 |
53 | - [LiteLLM - Logfire](https://docs.litellm.ai/docs/observability/logfire_integration)
54 | - [LiteLLM - OpenTelemetry](https://docs.litellm.ai/docs/observability/opentelemetry_integration)
55 | - [Logfire - Integrations](https://logfire.pydantic.dev/docs/integrations/)
56 | - [TraceLoop - openllmetry](https://github.com/traceloop/openllmetry)
57 |
58 | Here is an example of adding LiteLLM tracing on top of rigging:
59 |
60 | ```python
61 | import rigging as rg
62 | import logfire
63 | import litellm
64 | import os
65 | logfire.configure()
66 |
67 | os.environ.setdefault("LOGFIRE_TOKEN", "") # (1)!
68 | litellm.callbacks = ["logfire"]
69 |
70 | # ...
71 | ```
72 |
73 | *1. As of publication, LiteLLM requires this environment variable, even if it's empty and Logfire is managing tokens for you.*
74 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | - Using welcoming and inclusive language
18 | - Being respectful of differing viewpoints and experiences
19 | - Gracefully accepting constructive criticism
20 | - Focusing on what is best for the community
21 | - Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | - The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | - Trolling, insulting/derogatory comments, and personal or political attacks
28 | - Public or private harassment
29 | - Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | - Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
75 | version 1.4, available at
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 |
81 |
--------------------------------------------------------------------------------
/rigging/interact.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for interactive chat sessions.
3 | """
4 |
5 | import asyncio
6 | import itertools
7 | import typing as t
8 |
9 | from colorama import Fore, Style
10 |
11 | from rigging.chat import Chat, ChatPipeline
12 | from rigging.generator.base import Generator, get_generator
13 |
14 | # ruff: noqa: T201 (we use print() here for simplicity)
15 |
16 |
17 | async def _animate(
18 | *,
19 | delay: float = 0.5,
20 | chars: list[str] | None = None,
21 | color: str = Fore.BLUE,
22 | ) -> None:
23 | cycle = itertools.cycle(chars or [" ", ". ", ".. ", "..."])
24 | while True:
25 | print(f"{color}{next(cycle)}{Style.RESET_ALL}", end="\r", flush=True)
26 | await asyncio.sleep(delay)
27 |
28 |
29 | async def interact(
30 | entrypoint: ChatPipeline | Generator | str,
31 | *,
32 | reset_callback: t.Callable[[Chat | None], None] | None = None,
33 | ) -> Chat | None:
34 | """
35 | Start an interactive chat session using the given pipeline, generator, or generator id.
36 |
37 | This function allows the user to have a conversation with an assistant by providing input
38 | and receiving responses. The chat session can be controlled using specific commands.
39 |
40 | Args:
41 | entrypoint: A ChatPipeline, Generator, or generator id to use for the chat session.
42 | reset_callback: A callback function to execute when the chat is reset.
43 |
44 | Returns:
45 | The final Chat object, or None if the chat was interrupted before any generation.
46 | """
47 |
48 | print(f"\n{Fore.YELLOW}")
49 | print("Starting interactive chat.")
50 | print()
51 | print(" - Type 'exit' to quit or use Ctrl+C.")
52 | print(" - Type 'reset' or 'restart' to restart the chat.")
53 | print(" - Type 'again' or 'retry' to re-run the last generation.")
54 | print(Style.RESET_ALL)
55 |
56 | base_pipeline = (
57 | entrypoint
58 | if isinstance(entrypoint, ChatPipeline)
59 | else entrypoint.chat()
60 | if isinstance(entrypoint, Generator)
61 | else get_generator(entrypoint).chat()
62 | )
63 |
64 | pipeline = base_pipeline.clone()
65 | chat: Chat | None = None
66 |
67 | while True:
68 | try:
69 | user_input = input(f"\n{Fore.GREEN}User: {Style.RESET_ALL}")
70 | if not user_input:
71 | continue
72 |
73 | if user_input.lower() == "exit":
74 | print(f"\n\n{Fore.YELLOW}Exiting chat.{Style.RESET_ALL}")
75 | break
76 |
77 | if user_input.lower() in ["reset", "restart"]:
78 | print(f"\n{Fore.YELLOW}--- Reset ---{Style.RESET_ALL}")
79 | pipeline = base_pipeline.clone()
80 | if reset_callback:
81 | reset_callback(chat)
82 | continue
83 |
84 | if user_input.lower() in ["again", "retry"]:
85 | print(f"\n{Fore.YELLOW}--- Retry ---{Style.RESET_ALL}")
86 | pipeline.chat.messages = pipeline.chat.messages[:-1]
87 | else:
88 | pipeline.add(user_input)
89 |
90 | print()
91 |
92 | animation_task = asyncio.create_task(_animate())
93 | chat = await pipeline.run()
94 | animation_task.cancel()
95 |
96 | print(f"\r{Fore.BLUE}Assistant: {Style.RESET_ALL}{chat.last.content}")
97 |
98 | pipeline.add(chat.last)
99 |
100 | except KeyboardInterrupt:
101 | print(f"\n\n{Fore.YELLOW}Chat interrupted. Exiting.{Style.RESET_ALL}")
102 | break
103 | except Exception as e: # noqa: BLE001
104 | print(f"\n\n{Fore.RED}An error occurred: {e!s}{Style.RESET_ALL}")
105 | break
106 |
107 | return chat
108 |
--------------------------------------------------------------------------------
/docs/topics/workflow.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Workflow"
3 | description: "How to use Rigging to generate messages."
4 | public: true
5 | ---
6 |
7 | There are two main ways to use Rigging: prompts and pipelines.
8 |
9 | Prompts offer a simple entry point and cover a lot of ground, but pipelines are the core of Rigging’s power—prompts ultimately serve as a way to leverage pipelines.
10 |
11 | ## Using Prompts
12 |
13 | 1. Establish a function decorated with `@prompt` with the inference model you want to use.
14 | 2. Call the function as you normally would and receive structured data back.
15 |
16 | ```python
17 | import rigging as rg
18 |
19 | @rg.prompt(generator_id="claude-3-5-sonnet-latest")
20 | async def get_authors(count: int = 3) -> list[str]:
21 | """Provide famous authors."""
22 |
23 | print(await get_authors())
24 |
25 | # ['William Shakespeare', 'J.K. Rowling', 'Jane Austen']
26 | ```
27 |
28 | Underneath, Rigging will produce a `Generator` with `get_generator("claude-3-5-sonnet-latest")`, prepare a small template that will establish the required context and output structure, pass it into a new `ChatPipeline`, run the generation process, and parse the output into our structured list with `ChatPipeline.then()`.
29 |
30 | If you want to see the resulting `Chat` object, you can set that as your return value and no output parsing will be performed.
31 |
32 | ```python
33 | @rg.prompt(generator_id="claude-3-5-sonnet-latest")
34 | async def get_authors(count: int = 3) -> rg.Chat:
35 | ...
36 | ```
37 |
38 | Now the prompt is only responsible for abstracting the generator, pipeline, and content for you. You can also use a nested object like a `tuple` and include both your structured data and the `Chat` object.
39 |
40 | ```python
41 | @rg.prompt(generator_id="claude-3-5-sonnet-latest")
42 | async def get_authors(count: int = 3) -> tuple[list[str], rg.Chat]:
43 | """Provide famous authors."""
44 | ```
45 |
46 | This will return a tuple with the parsed output as the first element and the raw `Chat` object as the second element.
47 |
48 | You can learn more about the `@prompt` decorator in the [Prompt Functions](/topics/prompt-functions) section.
49 |
50 | ## Using Pipelines
51 |
52 | 1. Get a `Generator` object - usually with `get_generator()`.
53 | 2. Call `generator.chat()` to produce a `ChatPipeline` and ready it for generation.
54 | 3. Call `pipeline.run()` to kick off generation and get your final `Chat` object.
55 |
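Putting those three steps together, a minimal sketch looks like this (the model id is just an example):

```python
import rigging as rg

generator = rg.get_generator("claude-3-5-sonnet-latest")   # 1. get a Generator
pipeline = generator.chat("Name three famous authors.")    # 2. build a ChatPipeline
chat = await pipeline.run()                                # 3. run it to get a Chat

print(chat.last.content)
```
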
56 | `ChatPipeline` objects hold any messages waiting to be delivered to an LLM in exchange for a new response message. These objects are also where most of the power in rigging comes from. You'll build a generation pipeline with options, parsing, callbacks, etc. After preparation, this pipeline is used to make a final `Chat` which holds all messages prior to generation (`.prev`) and after generation (`.next`).
57 |
58 | You should think of `ChatPipeline` objects as the configurable pre-generation step, with calls like `.with_()`, `.apply()`, `.until()`, `.using()`, etc. Once you call one of the many `.run()` functions, the generator is used to produce the next message (or many messages) based on the prior context and any constraints you have in place. Once you have a `Chat` object, the interaction is complete and you can inspect and operate on the messages.
59 |
60 |
61 | Rigging supports both Chat objects (messages with roles in a conversation format) and raw text completions. While we use Chat objects in most of our examples, you can check out the [Completions](/topics/completions) section to learn more about their feature parity.
62 |
63 |
64 | We often use functional-style chaining, as most of our utility functions return the object back to you.
65 |
66 | ```python
67 | chat = (
68 | await
69 | generator.chat(...)
70 | .using(...) # tools
71 | .then(...) # follow up functions
72 | .with_(...) # generation params
73 | .run()
74 | )
75 | ```
76 |
77 | Learn more about the `ChatPipeline` object in the [Pipelines](/topics/pipelines) section.
78 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Testing code
2 | notebooks/
3 |
4 | # Custom parquet storage
5 | *.parquet
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/#use-with-ide
116 | .pdm.toml
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | #.idea/
167 |
168 | # macos
169 | .DS_Store
170 | .AppleDouble
171 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "rigging"
3 | version = "3.3.5"
4 | description = "LLM Interaction Framework"
5 | authors = ["Nick Landers "]
6 | license = "MIT"
7 | repository = "https://github.com/dreadnode/rigging"
8 | readme = "README.md"
9 | packages = [{ include = "rigging" }]
10 |
11 | # Dependencies
12 |
13 | [tool.poetry.dependencies]
14 | python = ">=3.10,<3.14"
15 | pydantic = "^2.7.3"
16 | pydantic-xml = "<=2.17.0"
17 | loguru = "^0.7.2"
18 | litellm = "^1.67.2"
19 | xmltodict = "^0.13.0"
20 | colorama = "^0.4.6"
21 | jsonpath-ng = "^1.7.0"
22 | ruamel-yaml = "^0.18.10"
23 | jsonref = "^1.1.0"
24 | mcp = "^1.5.0"
25 | dreadnode = ">=1.12.0"
26 |
27 | vllm = { version = "^0.5.0", optional = true }
28 | transformers = { version = "^4.41.0", optional = true }
29 | accelerate = { version = "^0.30.1", optional = true }
30 |
31 | asyncssh = { version = "^2.14.2", optional = true }
32 | click = { version = "^8.1.7", optional = true }
33 | httpx = { version = "^0.28.0", optional = true }
34 | aiodocker = { version = "^0.22.2", optional = true }
35 | websockets = { version = "^13.0", optional = true }
36 |
37 | elasticsearch = { version = "^8.13.2", optional = true }
38 | pandas = { version = "^2.2.2", optional = true }
39 |
40 | [tool.poetry.extras]
41 | data = ["pandas", "elasticsearch"]
42 | examples = ["asyncssh", "click", "httpx", "aiodocker", "websockets"]
43 | llm = ["vllm", "transformers", "accelerate"]
44 | all = [
45 | "vllm",
46 | "transformers",
47 | "accelerate",
48 | "asyncssh",
49 | "click",
50 | "httpx",
51 | "aiodocker",
52 | "websockets",
53 | "logfire",
54 | "elasticsearch",
55 | "pandas",
56 | ]
57 |
58 | [tool.poetry.group.dev.dependencies]
59 | ipykernel = "^6.27.1"
60 | mypy = "^1.15.0"
61 | ruff = "^0.10.0"
62 | pytest = "^8.0.0"
63 | pandas-stubs = "^2.2.1.240316"
64 | coverage = "^7.5.1"
65 | ipywidgets = "^8.1.3"
66 | pytest-asyncio = "^1.0.0"
67 | types-colorama = "^0.4.15.20240311"
68 | types-requests = "2.32.4.20250611"
69 | beautifulsoup4 = "^4.13.4"
70 | mkdocstrings = { extras = ["python"], version = "^0.29.1" }
71 | markdown = "^3.8"
72 | markdownify = "^1.1.0"
73 | boto3-stubs = { extras = ["s3"], version = "^1.35.0" }
74 |
75 | # Build
76 |
77 | [build-system]
78 | requires = ["poetry-core"]
79 | build-backend = "poetry.core.masonry.api"
80 |
81 | # Tests / Coverage
82 |
83 | [tool.pytest.ini_options]
84 | asyncio_mode = "auto"
85 | asyncio_default_fixture_loop_scope = "function"
86 | filterwarnings = ["ignore::DeprecationWarning"]
87 |
88 | [tool.coverage.run]
89 | command_line = "-m pytest"
90 |
91 | [tool.coverage.report]
92 | include = ["rigging/*.py"]
93 | show_missing = true
94 |
95 | [tool.coverage.lcov]
96 | output = "lcov.info"
97 |
98 | # Tracing
99 |
100 | [tool.logfire]
101 | ignore_no_config = true
102 |
103 | # Security
104 |
105 | [tool.bandit]
106 | exclude_dirs = ["examples/*", ".github/*", ".hooks/*"]
107 |
108 | # Type Checking
109 |
110 | [tool.mypy]
111 | plugins = "pydantic.mypy"
112 | strict = true
113 |
114 | # Formatting / Linting
115 |
116 | [tool.ruff]
117 | target-version = "py310"
118 | line-length = 100
119 | extend-exclude = [
120 | "*.ipynb", # jupyter notebooks
121 | "examples/*", # example files
122 | ".github/*", # github files
123 | ".hooks/*", # git hooks
124 | ]
125 |
126 | [tool.ruff.lint]
127 | select = ["ALL"]
128 | ignore = [
129 | "E501", # line too long (we make best effort)
130 | "TRY003", # long messages in exception classes
131 | "EM", # picky message construction for exceptions
132 | "C90", # mccabe complexity
133 | "A002", # shadowing built-in
134 | "D", # docstrings
135 | "ANN", # annotations (handled by mypy)
136 | "PLR0913", # too many arguments
137 | "ERA001", # commented out code
138 | "FIX002", # contains todo, consider fixing
139 | "TD002", # TODO
140 | "TD003", # TODO
141 | "PLR0911", # too many return statements
142 | "FBT003", # boolean positional in function call
143 | "COM812", # missing trailing comma in function call
144 | ]
145 |
146 | [tool.ruff.format]
147 | skip-magic-trailing-comma = false
148 |
--------------------------------------------------------------------------------
/.secrets.baseline:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.5.0",
3 | "plugins_used": [
4 | {
5 | "name": "ArtifactoryDetector"
6 | },
7 | {
8 | "name": "AWSKeyDetector"
9 | },
10 | {
11 | "name": "AzureStorageKeyDetector"
12 | },
13 | {
14 | "name": "Base64HighEntropyString",
15 | "limit": 4.5
16 | },
17 | {
18 | "name": "BasicAuthDetector"
19 | },
20 | {
21 | "name": "CloudantDetector"
22 | },
23 | {
24 | "name": "DiscordBotTokenDetector"
25 | },
26 | {
27 | "name": "GitHubTokenDetector"
28 | },
29 | {
30 | "name": "GitLabTokenDetector"
31 | },
32 | {
33 | "name": "HexHighEntropyString",
34 | "limit": 3.0
35 | },
36 | {
37 | "name": "IbmCloudIamDetector"
38 | },
39 | {
40 | "name": "IbmCosHmacDetector"
41 | },
42 | {
43 | "name": "IPPublicDetector"
44 | },
45 | {
46 | "name": "JwtTokenDetector"
47 | },
48 | {
49 | "name": "KeywordDetector",
50 | "keyword_exclude": ""
51 | },
52 | {
53 | "name": "MailchimpDetector"
54 | },
55 | {
56 | "name": "NpmDetector"
57 | },
58 | {
59 | "name": "OpenAIDetector"
60 | },
61 | {
62 | "name": "PrivateKeyDetector"
63 | },
64 | {
65 | "name": "PypiTokenDetector"
66 | },
67 | {
68 | "name": "SendGridDetector"
69 | },
70 | {
71 | "name": "SlackDetector"
72 | },
73 | {
74 | "name": "SoftlayerDetector"
75 | },
76 | {
77 | "name": "SquareOAuthDetector"
78 | },
79 | {
80 | "name": "StripeDetector"
81 | },
82 | {
83 | "name": "TelegramBotTokenDetector"
84 | },
85 | {
86 | "name": "TwilioKeyDetector"
87 | }
88 | ],
89 | "filters_used": [
90 | {
91 | "path": "detect_secrets.filters.allowlist.is_line_allowlisted"
92 | },
93 | {
94 | "path": "detect_secrets.filters.common.is_baseline_file",
95 | "filename": ".secrets.baseline"
96 | },
97 | {
98 | "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
99 | "min_level": 2
100 | },
101 | {
102 | "path": "detect_secrets.filters.heuristic.is_indirect_reference"
103 | },
104 | {
105 | "path": "detect_secrets.filters.heuristic.is_likely_id_string"
106 | },
107 | {
108 | "path": "detect_secrets.filters.heuristic.is_lock_file"
109 | },
110 | {
111 | "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
112 | },
113 | {
114 | "path": "detect_secrets.filters.heuristic.is_potential_uuid"
115 | },
116 | {
117 | "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
118 | },
119 | {
120 | "path": "detect_secrets.filters.heuristic.is_sequential_string"
121 | },
122 | {
123 | "path": "detect_secrets.filters.heuristic.is_swagger_file"
124 | },
125 | {
126 | "path": "detect_secrets.filters.heuristic.is_templated_secret"
127 | },
128 | {
129 | "path": "detect_secrets.filters.regex.should_exclude_file",
130 | "pattern": [
131 | "examples/*"
132 | ]
133 | }
134 | ],
135 | "results": {
136 | "docs/topics/generators.mdx": [
137 | {
138 | "type": "Secret Keyword",
139 | "filename": "docs/topics/generators.mdx",
140 | "hashed_secret": "ef5225a03e4f9cc953ab3c4dd41f5c4db7dc2e5b",
141 | "is_verified": false,
142 | "line_number": 360
143 | },
144 | {
145 | "type": "Secret Keyword",
146 | "filename": "docs/topics/generators.mdx",
147 | "hashed_secret": "eb6256c862c356b375aafa760fa1851e33aa62a9",
148 | "is_verified": false,
149 | "line_number": 384
150 | }
151 | ],
152 | "tests/test_http_spec.py": [
153 | {
154 | "type": "Secret Keyword",
155 | "filename": "tests/test_http_spec.py",
156 | "hashed_secret": "3acfb2c2b433c0ea7ff107e33df91b18e52f960f",
157 | "is_verified": false,
158 | "line_number": 19
159 | }
160 | ]
161 | },
162 | "generated_at": "2025-09-03T21:56:13Z"
163 | }
164 |
--------------------------------------------------------------------------------
/docs/api/interact.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: rigging.interact
3 | ---
4 |
5 | {/*
6 | ::: rigging.interact
7 | */}
8 |
9 | Utility functions for interactive chat sessions.
10 |
11 | interact
12 | --------
13 |
14 | ```python
15 | interact(
16 | entrypoint: ChatPipeline | Generator | str,
17 | *,
18 | reset_callback: Callable[[Chat | None], None]
19 | | None = None,
20 | ) -> Chat | None
21 | ```
22 |
23 | Start an interactive chat session using the given pipeline, generator, or generator id.
24 |
25 | This function allows the user to have a conversation with an assistant by providing input
26 | and receiving responses. The chat session can be controlled using specific commands.
27 |
28 | **Parameters:**
29 |
30 | * **`entrypoint`**
31 | (`ChatPipeline | Generator | str`)
32 | –A ChatPipeline, Generator, or generator id to use for the chat session.
33 | * **`reset_callback`**
34 | (`Callable[[Chat | None], None] | None`, default:
35 | `None`
36 | )
37 | –A callback function to execute when the chat is reset.
38 |
39 | **Returns:**
40 |
41 | * `Chat | None`
42 | –The final Chat object, or None if the chat was interrupted before any generation.
43 |
44 |
45 | ```python
46 | async def interact(
47 | entrypoint: ChatPipeline | Generator | str,
48 | *,
49 | reset_callback: t.Callable[[Chat | None], None] | None = None,
50 | ) -> Chat | None:
51 | """
52 | Start an interactive chat session using the given pipeline, generator, or generator id.
53 |
54 | This function allows the user to have a conversation with an assistant by providing input
55 | and receiving responses. The chat session can be controlled using specific commands.
56 |
57 | Args:
58 | entrypoint: A ChatPipeline, Generator, or generator id to use for the chat session.
59 | reset_callback: A callback function to execute when the chat is reset.
60 |
61 | Returns:
62 | The final Chat object, or None if the chat was interrupted before any generation.
63 | """
64 |
65 | print(f"\n{Fore.YELLOW}")
66 | print("Starting interactive chat.")
67 | print()
68 | print(" - Type 'exit' to quit or use Ctrl+C.")
69 | print(" - Type 'reset' or 'restart' to restart the chat.")
70 | print(" - Type 'again' or 'retry' to re-run the last generation.")
71 | print(Style.RESET_ALL)
72 |
73 | base_pipeline = (
74 | entrypoint
75 | if isinstance(entrypoint, ChatPipeline)
76 | else entrypoint.chat()
77 | if isinstance(entrypoint, Generator)
78 | else get_generator(entrypoint).chat()
79 | )
80 |
81 | pipeline = base_pipeline.clone()
82 | chat: Chat | None = None
83 |
84 | while True:
85 | try:
86 | user_input = input(f"\n{Fore.GREEN}User: {Style.RESET_ALL}")
87 | if not user_input:
88 | continue
89 |
90 | if user_input.lower() == "exit":
91 | print(f"\n\n{Fore.YELLOW}Exiting chat.{Style.RESET_ALL}")
92 | break
93 |
94 | if user_input.lower() in ["reset", "restart"]:
95 | print(f"\n{Fore.YELLOW}--- Reset ---{Style.RESET_ALL}")
96 | pipeline = base_pipeline.clone()
97 | if reset_callback:
98 | reset_callback(chat)
99 | continue
100 |
101 | if user_input.lower() in ["again", "retry"]:
102 | print(f"\n{Fore.YELLOW}--- Retry ---{Style.RESET_ALL}")
103 | pipeline.chat.messages = pipeline.chat.messages[:-1]
104 | else:
105 | pipeline.add(user_input)
106 |
107 | print()
108 |
109 | animation_task = asyncio.create_task(_animate())
110 | chat = await pipeline.run()
111 | animation_task.cancel()
112 |
113 | print(f"\r{Fore.BLUE}Assistant: {Style.RESET_ALL}{chat.last.content}")
114 |
115 | pipeline.add(chat.last)
116 |
117 | except KeyboardInterrupt:
118 | print(f"\n\n{Fore.YELLOW}Chat interrupted. Exiting.{Style.RESET_ALL}")
119 | break
120 | except Exception as e: # noqa: BLE001
121 | print(f"\n\n{Fore.RED}An error occurred: {e!s}{Style.RESET_ALL}")
122 | break
123 |
124 | return chat
125 | ```
126 |
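A minimal usage sketch (assuming `interact` is imported from `rigging.interact` as documented above; the generator id is only illustrative):

```python
import asyncio

from rigging.interact import interact


async def main() -> None:
    # Starts the REPL-style loop; 'exit', 'reset'/'restart', and 'again'/'retry'
    # behave as described above.
    chat = await interact("gpt-4o-mini")
    if chat is not None:
        print(chat.conversation)


asyncio.run(main())
```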
127 |
128 |
--------------------------------------------------------------------------------
/rigging/parsing.py:
--------------------------------------------------------------------------------
1 | """
2 | Parsing helpers for extracting rigging models from text
3 | """
4 |
5 | import typing as t
6 |
7 | from rigging.error import MissingModelError
8 |
9 | if t.TYPE_CHECKING:
10 | from rigging.model import ModelT
11 |
12 |
13 | def parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice]:
14 | """
15 | Parses a single model from text.
16 |
17 | Args:
18 | text: The content to parse.
19 | model_type: The type of model to parse.
20 |
21 | Returns:
22 | The parsed model and the slice of the source text where it was found.
23 |
24 | Raises:
25 | MissingModelError: If no models of the given type are found.
26 | """
27 | return try_parse_many(text, model_type, fail_on_missing=True)[0]
28 |
29 |
30 | def try_parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice] | None:
31 | """
32 | Tries to parse a model from text.
33 |
34 | Args:
35 | text: The content to parse.
36 | model_type: The type of model to search for.
37 |
38 | Returns:
39 | The first matching model and its slice, or None if no match is found.
40 | """
41 | return next(iter(try_parse_many(text, model_type)), None)
42 |
43 |
44 | def parse_set(
45 | text: str,
46 | model_type: type["ModelT"],
47 | *,
48 | minimum: int | None = None,
49 | ) -> list[tuple["ModelT", slice]]:
50 | """
51 | Parses a set of models with the specified identical type from text.
52 |
53 | Args:
54 | text: The content to parse.
55 | model_type: The type of models to parse.
56 | minimum: The minimum number of models required.
57 |
58 | Returns:
59 | A list of parsed models.
60 |
61 | Raises:
62 | MissingModelError: If the minimum number of models is not met.
63 | """
64 | return try_parse_set(text, model_type, minimum=minimum, fail_on_missing=True)
65 |
66 |
67 | def try_parse_set(
68 | text: str,
69 | model_type: type["ModelT"],
70 | *,
71 | minimum: int | None = None,
72 | fail_on_missing: bool = False,
73 | ) -> list[tuple["ModelT", slice]]:
74 | """
75 | Tries to parse a set of models with the specified identical type from text.
76 |
77 | Args:
78 | text: The content to parse.
79 | model_type: The type of model to parse.
80 | minimum: The minimum number of models expected.
81 | fail_on_missing: Whether to raise an exception if models are missing.
82 |
83 | Returns:
84 | The parsed models.
85 |
86 | Raises:
87 | MissingModelError: If the number of parsed models is less than the minimum required.
88 | """
89 | models = try_parse_many(text, model_type, fail_on_missing=fail_on_missing)
90 | if minimum is not None and len(models) < minimum:
91 | raise MissingModelError(f"Expected at least {minimum} {model_type.__name__} in message")
92 | return models
93 |
94 |
95 | def parse_many(text: str, *types: type["ModelT"]) -> list[tuple["ModelT", slice]]:
96 | """
97 | Parses multiple models of the specified non-identical types from text.
98 |
99 | Args:
100 | text: The content to parse.
101 | *types: The types of models to parse.
102 |
103 | Returns:
104 | A list of parsed models.
105 |
106 | Raises:
107 | MissingModelError: If any of the models are missing.
108 | """
109 | return try_parse_many(text, *types, fail_on_missing=True)
110 |
111 |
112 | def try_parse_many(
113 | text: str,
114 | *types: type["ModelT"],
115 | fail_on_missing: bool = False,
116 | ) -> list[tuple["ModelT", slice]]:
117 | """
118 | Tries to parse multiple models of the specified non-identical types from text.
119 |
120 | Args:
121 | text: The content to parse.
122 | *types: The types of models to parse.
123 | fail_on_missing: Whether to raise an exception if a model type is missing.
124 |
125 | Returns:
126 | A list of parsed models.
127 |
128 | Raises:
129 | MissingModelError: If a model type is missing and `fail_on_missing` is True.
130 | Exception: If the model is malformed and `fail_on_missing` is True.
131 | """
132 | model: ModelT
133 | parsed: list[tuple[ModelT, slice]] = []
134 |
135 | try:
136 | for model_class in types:
137 | for model, slice_ in model_class.from_text(text):
138 | parsed.append((model, slice_))
139 | except Exception:
140 | if fail_on_missing:
141 | raise
142 |
143 | return sorted(parsed, key=lambda x: x[1].start)
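
# Illustrative usage sketch (the `Answer` and `Summary` models are hypothetical):
#
#   answer, span = parse(text, Answer)              # raises MissingModelError if absent
#   maybe = try_parse(text, Answer)                 # None if absent
#   found = try_parse_many(text, Answer, Summary)   # mixed types, sorted by position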
144 |
--------------------------------------------------------------------------------
/.hooks/check_pinned_hash_dependencies.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import re
3 | import sys
4 | from pathlib import Path
5 |
6 |
7 | class GitHubActionChecker:
8 | def __init__(self) -> None:
9 | # Pattern for actions with SHA-1 hashes (pinned)
10 | self.pinned_pattern = re.compile(r"uses:\s+([^@\s]+)@([a-f0-9]{40})")
11 |
12 | # Pattern for actions with version tags (unpinned)
13 | self.unpinned_pattern = re.compile(
14 | r"uses:\s+([^@\s]+)@(v\d+(?:\.\d+)*(?:-[a-zA-Z0-9]+(?:\.\d+)*)?)",
15 | )
16 |
17 | # Pattern for all uses statements
18 | self.all_uses_pattern = re.compile(r"uses:\s+([^@\s]+)@([^\s\n]+)")
19 |
20 | def format_terminal_link(self, file_path: str, line_number: int) -> str:
21 | """Format a terminal link to a file and line number.
22 |
23 | Args:
24 | file_path: Path to the file
25 | line_number: Line number in the file
26 |
27 | Returns:
28 | str: Formatted string with file path and line number
29 | """
30 | return f"{file_path}:{line_number}"
31 |
32 | def get_line_numbers(self, content: str, pattern: re.Pattern[str]) -> list[tuple[str, int]]:
33 | """Find matches with their line numbers."""
34 | matches = []
35 | matches.extend(
36 | (match.group(0), i)
37 | for i, line in enumerate(content.splitlines(), 1)
38 | for match in pattern.finditer(line)
39 | )
40 | return matches
41 |
42 | def check_file(self, file_path: str) -> bool:
43 | """Check a single file for unpinned dependencies."""
44 | try:
45 | content = Path(file_path).read_text()
46 | except (FileNotFoundError, PermissionError, IsADirectoryError, OSError) as e:
47 | print(f"\033[91mError reading file {file_path}: {e}\033[0m")
48 | return False
49 |
50 | # Get matches with line numbers
51 | pinned_matches = self.get_line_numbers(content, self.pinned_pattern)
52 | unpinned_matches = self.get_line_numbers(content, self.unpinned_pattern)
53 | all_matches = self.get_line_numbers(content, self.all_uses_pattern)
54 |
55 | print(f"\n\033[1m[=] Checking file: {file_path}\033[0m")
56 |
57 | # Print pinned dependencies
58 | if pinned_matches:
59 | print("\033[92m[+] Pinned:\033[0m")
60 | for match, line_num in pinned_matches:
61 | print(f" |- {match} \033[90m({file_path}:{line_num})\033[0m")
62 |
63 | # Track all found actions for validation
64 | found_actions = set()
65 | for match, _ in pinned_matches + unpinned_matches:
66 | action_name = self.pinned_pattern.match(match) or self.unpinned_pattern.match(match)
67 | if action_name:
68 | found_actions.add(action_name.group(1))
69 |
70 | has_errors = False
71 |
72 | # Check for unpinned dependencies
73 | if unpinned_matches:
74 | has_errors = True
75 | print("\033[93m[!] Unpinned (using version tags):\033[0m")
76 | for match, line_num in unpinned_matches:
77 | print(f" |- {match} \033[90m({file_path}:{line_num})\033[0m")
78 |
79 | # Check for completely unpinned dependencies (no SHA or version)
80 | unpinned_without_hash = [
81 | (match, line_num)
82 | for match, line_num in all_matches
83 | if not any(match in pinned[0] for pinned in pinned_matches)
84 | and not any(match in unpinned[0] for unpinned in unpinned_matches)
85 | ]
86 |
87 | if unpinned_without_hash:
88 | has_errors = True
89 | print("\033[91m[!] Completely unpinned (no SHA or version):\033[0m")
90 | for match, line_num in unpinned_without_hash:
91 | print(
92 | f" |- {match} \033[90m({self.format_terminal_link(file_path, line_num)})\033[0m",
93 | )
94 |
95 | # Print summary
96 | total_actions = len(pinned_matches) + len(unpinned_matches) + len(unpinned_without_hash)
97 | if total_actions == 0:
98 | print("\033[93m[!] No GitHub Actions found in this file\033[0m")
99 | else:
100 | print("\n\033[1mSummary:\033[0m")
101 | print(f"Total actions: {total_actions}")
102 | print(f"Pinned: {len(pinned_matches)}")
103 | print(f"Unpinned with version: {len(unpinned_matches)}")
104 | print(f"Completely unpinned: {len(unpinned_without_hash)}")
105 |
106 | return not has_errors
107 |
108 |
109 | def main() -> None:
110 | checker = GitHubActionChecker()
111 | files_to_check = sys.argv[1:]
112 |
113 | if not files_to_check:
114 | print("\033[91mError: No files provided to check\033[0m")
115 | print("Usage: python script.py <file1> [<file2> ...]")
116 | sys.exit(1)
117 |
118 | results = {file: checker.check_file(file) for file in files_to_check}
119 |
120 | # Print final summary
121 | print("\n\033[1mFinal Results:\033[0m")
122 | for file, passed in results.items():
123 | status = "\033[92m✓ Passed\033[0m" if passed else "\033[91m✗ Failed\033[0m"
124 | print(f"{status} {file}")
125 |
126 | if not all(results.values()):
127 | sys.exit(1)
128 |
129 |
130 | if __name__ == "__main__":
131 | main()
132 |
--------------------------------------------------------------------------------
/docs/topics/serialization.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Serialization"
3 | description: "Serialization and Deserialization of Rigging objects"
4 | public: true
5 | ---
6 |
7 | The following objects in Rigging have great serialization support for storage and retrieval:
8 |
9 | - `Chat`
10 | - `Completion`
11 | - `Generator`
12 | - `Model`
13 |
14 | Most of this stems from our use of Pydantic for core models, and we've included some helpful fields for reconstructing Chats and Completions.
15 |
16 | ## JSON Serialization
17 |
18 | Let's build a joke pipeline and serialize the final chat into JSON.
19 |
20 |
21 | ```python Serialization
22 | import rigging as rg
23 |
24 | class Joke(rg.Model):
25 | content: str
26 |
27 | chat = (
28 | await
29 | rg.get_generator("gpt-3.5-turbo")
30 | .chat(f"Provide 3 jokes each between {Joke.xml_tags()} tags.")
31 | .meta(tags=['joke'])
32 | .with_(temperature=1.25)
33 | .run()
34 | )
35 |
36 | chat.last.parse_set(Joke)
37 |
38 | serialized = chat.model_dump_json(indent=2)
39 | print(serialized)
40 | ```
41 |
42 | ```json Serialized JSON
43 | {
44 | "uuid": "891c3834-2588-4652-8371-e9746086fd46",
45 | "timestamp": "2024-05-10T11:44:15.501326",
46 | "messages": [
47 | {
48 | "role": "user",
49 | "parts": [],
50 | "content": "Provide 3 jokes each between tags."
51 | }
52 | ],
53 | "generated": [
54 | {
55 | "role": "assistant",
56 | "parts": [
57 | {
58 | "model": {
59 | "content": " Why was the math book sad? Because it had too many problems. "
60 | },
61 | "slice_": [
62 | 0,
63 | 75
64 | ]
65 | },
66 | {
67 | "model": {
68 | "content": " I told my wife she should embrace her mistakes. She gave me a hug. "
69 | },
70 | "slice_": [
71 | 76,
72 | 157
73 | ]
74 | },
75 | {
76 | "model": {
77 | "content": " Why did the scarecrow win an award? Because he was outstanding in his field. "
78 | },
79 | "slice_": [
80 | 158,
81 | 249
82 | ]
83 | }
84 | ],
85 | "content": " Why was the math book sad? Because it had too many problems. \n I told my wife she should embrace her mistakes. She gave me a hug. \n Why did the scarecrow win an award? Because he was outstanding in his field. "
86 | }
87 | ],
88 | "metadata": {
89 | "tags": [
90 | "joke"
91 | ]
92 | },
93 | "generator_id": "litellm!gpt-3.5-turbo,temperature=1.25"
94 | }
95 | ```
96 |
97 |
98 | You'll notice that every Chat gets a unique `uuid` field to help track it in a datastore like Elastic or Pandas, along with a `timestamp` recording when the generation took place. We're also taking advantage of `.meta()` (`rigging.chat.ChatPipeline.meta`) to add a tracking tag for filtering later.
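
As a small illustration of why those fields are useful, tagged chats can be filtered back out of a collection by their metadata (the `all_chats` list here is hypothetical):

```python
# Keep only the chats we tagged with 'joke' earlier (all_chats is hypothetical)
joke_chats = [c for c in all_chats if "joke" in c.metadata.get("tags", [])]
```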
99 |
100 | ## JSON Deserialization
101 |
102 | The JSON has everything required to reconstruct a Chat including a `generator_id` dynamically constructed to preserve the parameters used to create the generated message(s). We can now deserialize a chat from a datastore, and immediately step back into a `ChatPipeline` for exploration.
103 |
104 | ```python
105 | chat = rg.Chat.model_validate_json(serialized)
106 | print(chat.conversation)
107 | # [user]: Provide 3 jokes each between tags.
108 |
109 | # [assistant]:
110 | # Why was the math book sad? Because it had too many problems.
111 | # I told my wife she should embrace her mistakes. She gave me a hug.
112 | # Why did the scarecrow win an award? Because he was outstanding in his field.
113 |
114 | continued = chat.continue_("Please explain the first joke to me.").run()
115 | print(continued.last)
116 | # [assistant]: In the first joke, the pun is based on the double meaning of the word "problems."
117 | # The math book is described as being sad because it has "too many problems," which could be
118 | # interpreted as having both mathematical problems (equations to solve) and emotional difficulties.
119 | # This play on words adds humor to the joke.
120 | ```
121 |
122 | ## Pandas DataFrames
123 |
124 | Rigging also has helpers in the `rigging.data` module for performing conversions between Chat objects and other storage formats like Pandas. In `chats_to_df` the messages are flattened and stored with a `chat_id` column for grouping. `df_to_chats` allows you to reconstruct a list of Chat objects back from a DataFrame.
125 |
126 | ```python
127 | import rigging as rg
128 |
129 | chats = (
130 | await
131 | rg.get_generator("claude-3-haiku-20240307")
132 | .chat("Write me a haiku.")
133 | .run_many(3)
134 | )
135 |
136 | df = rg.data.chats_to_df(chats)
137 | # or
138 | df = chats.to_df()
139 |
140 | print(df.info())
141 |
142 | # RangeIndex: 6 entries, 0 to 5
143 | # Data columns (total 9 columns):
144 | # # Column Non-Null Count Dtype
145 | # --- ------ -------------- -----
146 | # 0 chat_id 6 non-null string
147 | # 1 chat_metadata 6 non-null string
148 | # 2 chat_generator_id 6 non-null string
149 | # 3 chat_timestamp 6 non-null datetime64[ms]
150 | # 4 generated 6 non-null bool
151 | # 5 role 6 non-null category
152 | # 6 parts 6 non-null string
153 | # 7 content 6 non-null string
154 | # 8 message_id 6 non-null string
155 | # dtypes: bool(1), category(1), datetime64[ms](1), string(6)
156 |
157 | df.content.apply(lambda x: len(x)).mean()
158 |
159 | # 60.166666666666664
160 |
161 | back = rg.data.df_to_chats(df)
162 | print(back[0].conversation)
163 |
164 | # [user]: Write me a haiku.
165 | #
166 | # [assistant]: Here's a haiku for you:
167 | #
168 | # Gentle breeze whispers,
169 | # Flowers bloom in vibrant hues,
170 | # Nature's simple bliss.
171 | ```
172 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
11 |
12 | Flexible LLM library for code and agents
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | Rigging is a lightweight LLM framework to make using language models in production code as simple and effective as possible. Here are the highlights:
25 |
26 | - **Structured Pydantic models** can be used interchangeably with unstructured text output.
27 | - LiteLLM as the default generator giving you **instant access to a huge array of models**.
28 | - Define prompts as python functions with **type hints and docstrings**.
29 | - Simple **tool use**, even for models which don't support it at the API level.
30 | - Store different models and configs as **simple connection strings** just like databases.
31 | - Integrated tracing support with [Logfire](https://logfire.pydantic.dev/docs/).
32 | - Chat templating, forking, continuations, generation parameter overloads, stripping segments, etc.
33 | - Async batching and fast iterations for **large scale generation**.
34 | - Metadata, callbacks, and data format conversions.
35 | - Modern python with type hints, async support, pydantic validation, serialization, etc.
36 |
37 | ```py
38 | import rigging as rg
39 |
40 | @rg.prompt(generator_id="gpt-4")
41 | async def get_authors(count: int = 3) -> list[str]:
42 | """Provide famous authors."""
43 |
44 | print(await get_authors())
45 |
46 | # ['William Shakespeare', 'J.K. Rowling', 'Jane Austen']
47 | ```
48 |
49 | Rigging is built by [**dreadnode**](https://dreadnode.io) where we use it daily.
50 |
51 | ## Installation
52 |
53 | We publish every version to PyPI:
54 | ```bash
55 | pip install rigging
56 | ```
57 |
58 | If you want to build from source:
59 | ```bash
60 | cd rigging/
61 | poetry install
62 | ```
63 |
64 | ## Supported LLMs
65 |
66 | Rigging will run just about any language model:
67 |
68 | - Any model from [**LiteLLM**](https://litellm.vercel.app/docs/providers)
69 | - Any model from [**vLLM**](https://docs.vllm.ai/en/latest/models/supported_models.html)
70 | - Any model from [**transformers**](https://huggingface.co/docs/transformers/)
71 |
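A rough sketch of pointing the same code at different backends (model names are illustrative, and the `vllm!`/`transformers!` provider prefixes are assumptions based on rigging's provider naming; install the matching extras first):

```py
import rigging as rg

# Hosted model through LiteLLM (the default provider)
api_gen = rg.get_generator("gpt-4o")

# Local models (provider prefixes assumed; see the generator docs)
vllm_gen = rg.get_generator("vllm!mistralai/Mistral-7B-Instruct-v0.2")
hf_gen = rg.get_generator("transformers!microsoft/Phi-3-mini-4k-instruct")
```
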
72 | ### API Keys
73 |
74 | Pass the `api_key` in a generator id or use standard environment variables.
75 |
76 | ```py
77 | rg.get_generator("gpt-4-turbo,api_key=...")
78 | ```
79 |
80 | ```bash
81 | export OPENAI_API_KEY=...
82 | export MISTRAL_API_KEY=...
83 | export ANTHROPIC_API_KEY=...
84 | ...
85 | ```
86 |
87 | Check out [the docs](https://docs.dreadnode.io/open-source/rigging/topics/generators#api-keys) for more.
88 |
89 | ## Getting Started
90 |
91 | **Check out the guide [in the docs](https://docs.dreadnode.io/open-source/rigging/intro#getting-started)**
92 |
93 | 1. **Get a generator** using a connection string.
94 | 2. Build a **chat** or **completion** pipeline
95 | 3. **Run** the pipeline and get the output.
96 |
97 | ```py
98 | import rigging as rg
99 | import asyncio
100 |
101 | async def main():
102 | # 1 - Get a generator
103 | generator = rg.get_generator("claude-3-sonnet-20240229")
104 |
105 | # 2 - Build a chat pipeline
106 | pipeline = generator.chat(
107 | [
108 | {"role": "system", "content": "Talk like a pirate."},
109 | {"role": "user", "content": "Say hello!"},
110 | ]
111 | )
112 |
113 | # 3 - Run the pipeline
114 | chat = await pipeline.run()
115 | print(chat.conversation)
116 |
117 | # Run the main function
118 | asyncio.run(main())
119 |
120 | # [system]: Talk like a pirate.
121 | # [user]: Say hello!
122 | # [assistant]: Ahoy, matey! Here be the salty sea dog ready to trade greetings wit' ye. Arrr!
123 | ```
124 |
125 | Want more?
126 |
127 | - Use [structured pydantic parsing](https://docs.dreadnode.io/open-source/rigging/topics/chats-and-messages#parsed-parts)
128 | - Check out [raw completions](https://docs.dreadnode.io/open-source/rigging/topics/completions/)
129 | - Give the LLM [access to tools](https://docs.dreadnode.io/open-source/rigging/topics/tools/)
130 | - Track behavior with [tracing](https://docs.dreadnode.io/open-source/rigging/topics/tracing/)
131 | - Play with [generation params](https://docs.dreadnode.io/open-source/rigging/topics/generators/#overload-generation-params)
132 | - Use [callbacks in the pipeline](https://docs.dreadnode.io/open-source/rigging/topics/callbacks-and-mapping/)
133 | - Scale up with [iterating and batching](https://docs.dreadnode.io/open-source/rigging/topics/iterating-and-batching/)
134 | - Save your work with [serialization](https://docs.dreadnode.io/open-source/rigging/topics/serialization/)
135 |
136 | ## Examples
137 |
138 | - Basic interactive chat: [**chat.py**](examples/chat.py)
139 | - Jupyter code interpreter: [**jupyter.py**](examples/jupyter.py)
140 | - OverTheWire Bandit Agent: [**bandit.py**](examples/bandit.py)
141 | - Damn Vulnerable Restaurant Agent: [**dvra.py**](examples/dvra.py)
142 | - RAG Pipeline: [**rag.py**](examples/rag.py) (from [kyleavery](https://github.com/kyleavery/))
143 | - Integrating dreadnode-owned [robopages](https://github.com/dreadnode/robopages-cli) as a tool server (basic nmap scan example): [**rigging_example.py**](https://github.com/dreadnode/robopages-cli/blob/main/examples/rigging_example.py)
144 |
145 | ## Documentation
146 |
147 | **[docs.dreadnode.io](https://docs.dreadnode.io/open-source/rigging/intro)** has everything you need.
148 |
149 | ## Star History
150 |
151 | [](https://star-history.com/#dreadnode/rigging&Date)
152 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 | from textwrap import dedent
3 | from xml.sax.saxutils import escape
4 |
5 | import pytest
6 |
7 | from rigging.model import Model, attr, element
8 |
9 | # mypy: disable-error-code=empty-body
10 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001
11 |
12 |
13 | class SimpleModel(Model):
14 | """A simple model to test basic element generation."""
15 |
16 | content: str = element(examples=["Hello, World!"])
17 | """The main content of the model."""
18 |
19 |
20 | class NoExampleModel(Model):
21 | """A model to test fallback to an empty element when no example is given."""
22 |
23 | name: str
24 | """The name of the entity."""
25 |
26 |
27 | class AttrAndElementModel(Model):
28 | """Tests a model with both an attribute and a child element."""
29 |
30 | id: int = attr(examples=[123])
31 | """The unique identifier (attribute)."""
32 | value: str = element(examples=["Some value"])
33 | """The primary value (element)."""
34 |
35 |
36 | class DocstringDescriptionModel(Model):
37 | """Tests that field docstrings are correctly used as descriptions."""
38 |
39 | field1: str = element(examples=["val1"])
40 | """This is the description for field1."""
41 | field2: bool = element(examples=[True])
42 | """This is the description for field2."""
43 |
44 |
45 | class ParameterDescriptionModel(Model):
46 | """Tests that the `description` parameter overrides a field's docstring."""
47 |
48 | param: str = element(
49 | examples=["override"], description="This description is from the `description` parameter."
50 | )
51 | """This docstring should be ignored in the XML example."""
52 |
53 |
54 | class SpecialCharsModel(Model):
55 | """Tests proper escaping of special XML characters in examples and comments."""
56 |
57 | comment: str = element(examples=["ok"])
58 | """This comment contains < and > & special characters."""
59 | data: str = element(examples=["&'"])
60 | """This element's example contains special XML characters."""
61 |
62 |
63 | # This class definition is based on the one you provided in the prompt.
64 | class Analysis(Model, tag="analysis"):
65 | """A model to validate the exact output requested in the prompt."""
66 |
67 | priority: t.Literal["low", "medium", "high", "critical"] = element(examples=["medium"])
68 | """Triage priority for human follow-up."""
69 | tags: str = element("tags", examples=["admin panel, error message, legacy"])
70 | """A list of specific areas within the screenshot that are noteworthy or require further examination."""
71 | summary: str = element()
72 | """A markdown summary explaining *why* the screenshot is interesting and what a human should investigate next."""
73 |
74 |
75 | @pytest.mark.parametrize(
76 | ("model_cls", "expected_xml"),
77 | [
78 | pytest.param(
79 | SimpleModel,
80 | """
81 |
82 |
83 | Hello, World!
84 |
85 | """,
86 | id="simple_model",
87 | ),
88 | pytest.param(
89 | NoExampleModel,
90 | """
91 |
92 | """,
93 | id="model_with_no_example",
94 | ),
95 | pytest.param(
96 | AttrAndElementModel,
97 | """
98 |
99 |
100 | Some value
101 |
102 | """,
103 | id="model_with_attribute_and_element",
104 | ),
105 | pytest.param(
106 | DocstringDescriptionModel,
107 | """
108 |
109 |
110 | val1
111 |
112 | True
113 |
114 | """,
115 | id="descriptions_from_docstrings",
116 | ),
117 | pytest.param(
118 | ParameterDescriptionModel,
119 | """
120 |
121 |
122 | override
123 |
124 | """,
125 | id="description_from_parameter_overrides_docstring",
126 | ),
127 | pytest.param(
128 | SpecialCharsModel,
129 | f"""
130 |
131 |
132 | ok
133 |
134 | {escape("&'")}
135 |
136 | """,
137 | id="escaping_of_special_characters",
138 | ),
139 | pytest.param(
140 | Analysis,
141 | """
142 |
143 |
144 | medium
145 |
146 | admin panel, error message, legacy
147 |
148 |
149 |
150 | """,
151 | id="user_provided_analysis_model",
152 | ),
153 | ],
154 | )
155 | def test_xml_example_generation(model_cls: type[Model], expected_xml: str) -> None:
156 | """
157 | Validates that the `xml_example()` class method produces the correct
158 | pretty-printed XML with examples and descriptions as comments.
159 | """
160 | actual_xml = model_cls.xml_example()
161 | assert dedent(actual_xml).strip() == dedent(expected_xml).strip()
162 |
--------------------------------------------------------------------------------
/tests/test_prompt.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from textwrap import dedent
3 | from typing import Annotated
4 |
5 | import pytest
6 |
7 | import rigging as rg
8 | from rigging.chat import Chat
9 |
10 | # mypy: disable-error-code=empty-body
11 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001
12 |
13 |
14 | def test_prompt_render_docstring_parse() -> None:
15 | @rg.prompt
16 | async def foo(name: str) -> str:
17 | """Say hello."""
18 |
19 | assert foo.docstring == "Say hello."
20 |
21 | @rg.prompt
22 | async def bar(name: str) -> str:
23 | """
24 | Say hello."""
25 |
26 | assert bar.docstring == "Say hello."
27 |
28 | @rg.prompt
29 | async def baz(name: str) -> str:
30 | """
31 | Say \
32 | hello.
33 |
34 | """
35 |
36 | assert baz.docstring == "Say hello."
37 |
38 |
39 | def test_basic_prompt_render() -> None:
40 | @rg.prompt
41 | async def hello(name: str) -> str:
42 | """Say hello."""
43 |
44 | rendered = hello.render("Alice")
45 | assert rendered == dedent(
46 | """\
47 | Say hello.
48 |
49 | Alice
50 |
51 | Produce the following output (use xml tags):
52 |
53 |
54 | """,
55 | )
56 |
57 |
58 | def test_prompt_render_with_docstring_variables() -> None:
59 | @rg.prompt
60 | async def greet(name: str, greeting: str = "Hello") -> str:
61 | """Say '{{ greeting }}' to {{ name }}."""
62 |
63 | rendered = greet.render("Bob")
64 | assert rendered == dedent(
65 | """\
66 | Say 'Hello' to Bob.
67 |
68 | Produce the following output (use xml tags):
69 |
70 |
71 | """,
72 | )
73 |
74 |
75 | def test_prompt_render_with_model_output() -> None:
76 | class Person(rg.Model):
77 | name: str = rg.element()
78 | age: int = rg.element()
79 |
80 | @rg.prompt
81 | async def create_person(name: str, age: int) -> Person:
82 | """Create a person."""
83 |
84 | rendered = create_person.render("Alice", 30)
85 | assert rendered == dedent(
86 | """\
87 | Create a person.
88 |
89 | Alice
90 |
91 | 30
92 |
93 | Produce the following output (use xml tags):
94 |
95 |
96 |
97 |
98 |
99 | """,
100 | )
101 |
102 |
103 | def test_prompt_render_with_list_output() -> None:
104 | @rg.prompt
105 | async def generate_numbers(count: int) -> list[int]:
106 | """Generate a list of numbers."""
107 |
108 | rendered = generate_numbers.render(5)
109 | assert rendered == dedent(
110 | """\
111 | Generate a list of numbers.
112 |
113 | 5
114 |
115 | Produce the following output for each item (use xml tags):
116 |
117 |
118 | """,
119 | )
120 |
121 |
122 | def test_prompt_render_with_tuple_output() -> None:
123 | @rg.prompt
124 | async def create_user(username: str) -> tuple[str, int]:
125 | """Create a new user."""
126 |
127 | rendered = create_user.render("johndoe")
128 | assert rendered == dedent(
129 | """\
130 | Create a new user.
131 |
132 | johndoe
133 |
134 | Produce the following outputs (use xml tags):
135 |
136 |
137 |
138 |
139 | """,
140 | )
141 |
142 |
143 | def test_prompt_render_with_tuple_output_ctx() -> None:
144 | @rg.prompt
145 | async def create_user(username: str) -> tuple[Annotated[str, rg.Ctx(tag="id")], int]:
146 | """Create a new user."""
147 |
148 | rendered = create_user.render("johndoe")
149 | assert rendered == dedent(
150 | """\
151 | Create a new user.
152 |
153 | johndoe
154 |
155 | Produce the following outputs (use xml tags):
156 |
157 |
158 |
159 |
160 | """,
161 | )
162 |
163 |
164 | def test_prompt_render_with_dataclass_output() -> None:
165 | @dataclass
166 | class User:
167 | username: str
168 | email: str
169 | age: int
170 |
171 | @rg.prompt
172 | async def register_user(username: str, email: str, age: int) -> User:
173 | """Register a new user: {{ username}}."""
174 |
175 | rendered = register_user.render("johndoe", "johndoe@example.com", 25)
176 | assert rendered == dedent(
177 | """\
178 | Register a new user: johndoe.
179 |
180 | johndoe@example.com
181 |
182 | 25
183 |
184 | Produce the following outputs (use xml tags):
185 |
186 |
187 |
188 |
189 |
190 |
191 | """,
192 | )
193 |
194 |
195 | def test_prompt_render_with_chat_return() -> None:
196 | @rg.prompt
197 | async def foo(input_: str) -> Chat:
198 | """Do something."""
199 |
200 | rendered = foo.render("bar")
201 | assert rendered == dedent(
202 | """\
203 | Do something.
204 |
205 | bar
206 | """,
207 | )
208 |
209 |
210 | def test_prompt_render_ctx_in_dataclass() -> None:
211 | @dataclass
212 | class User:
213 | username: str
214 | email: Annotated[str, rg.Ctx(prefix="The user email:", example="[test@email.com]")]
215 | age: Annotated[int, rg.Ctx(tag="override")]
216 |
217 | @rg.prompt
218 | async def register_user(username: str, email: str, age: int) -> User:
219 | """Register a new user: {{ username }}."""
220 |
221 | rendered = register_user.render("johndoe", "john@email.com", 30)
222 | assert rendered == dedent(
223 | """\
224 | Register a new user: johndoe.
225 |
226 | john@email.com
227 |
228 | 30
229 |
230 | Produce the following outputs (use xml tags):
231 |
232 |
233 |
234 | The user email:
235 | [test@email.com]
236 |
237 |
238 | """,
239 | )
240 |
241 |
242 | def test_prompt_parse_fail_nested_input() -> None:
243 | async def foo(arg: list[list[str]]) -> Chat: ...
244 |
245 | with pytest.raises(TypeError):
246 | rg.prompt(foo)
247 |
248 | async def bar(arg: tuple[int, str, tuple[str]]) -> Chat: ...
249 |
250 | with pytest.raises(TypeError):
251 | rg.prompt(bar)
252 |
253 |
254 | def test_prompt_parse_fail_unique_ouput() -> None:
255 | async def foo(arg: int) -> tuple[str, str]: ...
256 |
257 | with pytest.raises(TypeError):
258 | rg.prompt(foo)
259 |
--------------------------------------------------------------------------------
/rigging/error.py:
--------------------------------------------------------------------------------
1 | """
2 | We try to avoid creating custom exceptions unless they are necessary.
3 |
4 | We use the built-in and pydantic exceptions as much as possible.
5 | """
6 |
7 | import functools
8 | import typing as t
9 |
10 | import typing_extensions as te
11 |
12 | if t.TYPE_CHECKING:
13 | from rigging.chat import PipelineStep
14 | from rigging.message import Message
15 |
16 |
17 | # User Throwable Exceptions
18 |
19 |
20 | class Stop(Exception): # noqa: N818
21 | """
22 | Raise inside a pipeline to indicate a stopping condition.
23 |
24 | Example:
25 | ```
26 | import rigging as rg
27 |
28 | async def read_file(path: str) -> str:
29 | "Read the contents of a file."
30 |
31 | if no_more_files(path):
32 | raise rg.Stop("There are no more files to read.")
33 |
34 | ...
35 |
36 | chat = await pipeline.using(read_file).run()
37 | ```
38 | """
39 |
40 | def __init__(self, message: str):
41 | super().__init__(message)
42 | self.message = message
43 | """The message associated with the stop."""
44 |
45 |
46 | # Warnings
47 |
48 |
49 | class PipelineWarning(Warning):
50 | """
51 | Base class for all pipeline warnings.
52 |
53 | This is used to indicate that something unexpected happened during the pipeline execution,
54 | but it is not critical enough to stop the execution.
55 | """
56 |
57 |
58 | class ToolWarning(Warning):
59 | """
60 | Base class for all tool warnings.
61 |
62 | This is used to indicate that something unexpected happened during the tool execution,
63 | but it is not critical enough to stop the execution.
64 | """
65 |
66 |
67 | class MessageWarning(Warning):
68 | """
69 | Base class for all message warnings.
70 |
71 | This is used to indicate that something unexpected happened during the message processing,
72 | but it is not critical enough to stop the execution.
73 | """
74 |
75 |
76 | class TokenizerWarning(Warning):
77 | """
78 | Base class for all tokenization warnings.
79 |
80 | This is used to indicate that something unexpected happened during the tokenization process,
81 | but it is not critical enough to stop the execution.
82 | """
83 |
84 |
85 | class GeneratorWarning(Warning):
86 | """
87 | Base class for all generator warnings.
88 |
89 | This is used to indicate that something unexpected happened during the generator execution,
90 | but it is not critical enough to stop the execution.
91 | """
92 |
93 |
94 | # System Exceptions
95 |
96 |
97 | class UnknownToolError(Exception):
98 | """
99 | Raised when an API tool call is made for an unknown tool.
100 | """
101 |
102 | def __init__(self, tool_name: str):
103 | super().__init__(f"Unknown tool call was requested for '{tool_name}'")
104 | self.tool_name = tool_name
105 | """The name of the tool which was unknown."""
106 |
107 |
108 | class ToolDefinitionError(Exception):
109 | """
110 | Raised when a tool cannot be properly defined.
111 | """
112 |
113 | def __init__(self, message: str):
114 | super().__init__(message)
115 |
116 |
117 | class ExhaustedMaxRoundsError(Exception):
118 | """
119 | Raised when the maximum number of rounds is exceeded while generating.
120 | """
121 |
122 | def __init__(self, max_rounds: int):
123 | super().__init__(f"Exhausted max rounds ({max_rounds}) while generating")
124 | self.max_rounds = max_rounds
125 | """The number of rounds which was exceeded."""
126 |
127 |
128 | class MessagesExhaustedMaxRoundsError(ExhaustedMaxRoundsError):
129 | """
130 | Raised when the maximum number of rounds is exceeded while generating messages.
131 | """
132 |
133 | def __init__(self, max_rounds: int, messages: list["Message"]):
134 | super().__init__(max_rounds)
135 | self.messages = messages
136 | """The messages which were being generated when the exception occurred."""
137 |
138 |
139 | class CompletionExhaustedMaxRoundsError(ExhaustedMaxRoundsError):
140 | """
141 | Raised when the maximum number of rounds is exceeded while generating completions.
142 | """
143 |
144 | def __init__(self, max_rounds: int, completion: str):
145 | super().__init__(max_rounds)
146 | self.completion = completion
147 | """The completion which was being generated when the exception occurred."""
148 |
149 |
150 | class MaxDepthError(Exception):
151 | """
152 | Raised when the maximum depth is exceeded while generating.
153 | """
154 |
155 | def __init__(self, max_depth: int, step: "PipelineStep", reference: str):
156 | super().__init__(f"Exceeded max depth ({max_depth}) while generating ('{reference}')")
157 | self.max_depth = max_depth
158 | """The maximum depth of nested pipeline generations which was exceeded."""
159 | self.step = step
160 | """The pipeline step which cause the depth error."""
161 |
162 |
163 | class InvalidGeneratorError(Exception):
164 | """
165 | Raised when an invalid identifier is specified when getting a generator.
166 | """
167 |
168 | def __init__(self, model: str):
169 | super().__init__(f"Invalid model specified: {model}")
170 |
171 |
172 | class InvalidTokenizerError(Exception):
173 | """
174 | Raised when an invalid tokenizer is specified.
175 | """
176 |
177 | def __init__(self, tokenizer: str):
178 | super().__init__(f"Invalid tokenizer specified: {tokenizer}")
179 | self.tokenizer = tokenizer
180 | """The name of the tokenizer which was invalid."""
181 |
182 |
183 | class MissingModelError(Exception):
184 | """
185 | Raised when a model is missing when parsing a message.
186 | """
187 |
188 | def __init__(self, content: str):
189 | super().__init__(content)
190 |
191 |
192 | class ProcessingError(Exception):
193 | """
194 | Raised when an error occurs during internal generator processing.
195 | """
196 |
197 | def __init__(self, content: str):
198 | super().__init__(content)
199 |
200 |
201 | P = te.ParamSpec("P")
202 | R = t.TypeVar("R")
203 |
204 |
205 | def raise_as(
206 | error_type: type[Exception],
207 | message: str,
208 | ) -> t.Callable[[t.Callable[P, R]], t.Callable[P, R]]:
209 | "When the wrapped function raises an exception, `raise ... from` with the new error type."
210 |
211 | def _raise_as(func: t.Callable[P, R]) -> t.Callable[P, R]:
212 | @functools.wraps(func)
213 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
214 | try:
215 | return func(*args, **kwargs)
216 | except Exception as e:
217 | error = error_type(message)
218 | raise error from e
219 |
220 | if wrapper.__doc__ is None:
221 | wrapper.__doc__ = ""
222 |
223 | wrapper.__doc__ += f"\n\nRaises:\n {error_type.__name__}{': ' + message}"
224 |
225 | return wrapper
226 |
227 | return _raise_as
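
# Illustrative usage sketch (the wrapped function is hypothetical): any exception
# raised inside `load_settings` is re-raised as a ProcessingError with a stable message.
#
#   @raise_as(ProcessingError, "Failed to load settings")
#   def load_settings(path: str) -> dict:
#       ...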
228 |
--------------------------------------------------------------------------------
/rigging/generator/vllm_.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import inspect
3 | import typing as t
4 |
5 | import torch # type: ignore [import-not-found, import-untyped, unused-ignore]
6 | import vllm # type: ignore [import-not-found,import-untyped, unused-ignore]
7 |
8 | from rigging.generator.base import (
9 | GeneratedMessage,
10 | GeneratedText,
11 | GenerateParams,
12 | Generator,
13 | trace_messages,
14 | trace_str,
15 | )
16 |
17 | if t.TYPE_CHECKING:
18 | from rigging.message import Message
19 |
20 | # Any batch over this size will trigger a dedicated
21 | # cache warmup step
22 | CACHE_TRIGGER = 8
23 |
24 |
25 | class VLLMGenerator(Generator):
26 | """
27 | Generator backed by the vLLM library for local model loading.
28 |
29 | Find more information about supported models and formats [in their docs.](https://docs.vllm.ai/en/latest/index.html)
30 |
31 | Warning:
32 | The use of VLLM requires the `vllm` package to be installed directly or by
33 | installing rigging as `rigging[all]`.
34 |
35 | Note:
36 | This generator doesn't leverage any async capabilities.
37 |
38 | Note:
39 | The model will be loaded into memory lazily when the first generation is requested.
40 | If you want to force this to happen earlier, you can use the
41 | [`.load()`][rigging.generator.Generator.load] method.
42 |
43 | To unload, call [`.unload()`][rigging.generator.Generator.unload].
44 | """
45 |
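# Illustrative load/unload sketch (the generator id below is an assumption based
# on rigging's provider naming, not taken from this file):
#
#   generator = rg.get_generator("vllm!mistralai/Mistral-7B-Instruct-v0.2")
#   generator.load()    # force the model into GPU memory now
#   ...
#   generator.unload()  # release it when finished
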
46 | dtype: str = "auto"
47 | """Tensor dtype passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)"""
48 | quantization: str | None = None
49 | """Quantiziation passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)"""
50 | gpu_memory_utilization: float = 0.9
51 | """Memory utilization passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)"""
52 | enforce_eager: bool = False
53 | """Eager enforcement passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)"""
54 | trust_remote_code: bool = False
55 | """Trust remote code passed to [`vllm.LLM`](https://docs.vllm.ai/en/latest/offline_inference/llm.html)"""
56 |
57 | # TODO: We should look at leveraging the AsyncLLMEngine or an
58 | # async alternative to the LLM class to allow for async generation
59 |
60 | _llm: vllm.LLM | None = None
61 |
62 | @property
63 | def llm(self) -> vllm.LLM:
64 | """The underlying [`vLLM model`](https://docs.vllm.ai/en/latest/offline_inference/llm.html) instance."""
65 | # Lazy initialization
66 | if self._llm is None:
67 | self._llm = vllm.LLM(
68 | self.model,
69 | dtype=self.dtype,
70 | quantization=self.quantization,
71 | gpu_memory_utilization=self.gpu_memory_utilization,
72 | enforce_eager=self.enforce_eager,
73 | trust_remote_code=self.trust_remote_code,
74 | )
75 | return self._llm
76 |
77 | @classmethod
78 | def from_obj(
79 | cls,
80 | model: str,
81 | llm: vllm.LLM,
82 | *,
83 | params: GenerateParams | None = None,
84 | ) -> "VLLMGenerator":
85 | """Create a generator from an existing vLLM instance.
86 |
87 | Args:
88 | llm: The vLLM instance to create the generator from.
89 |
90 | Returns:
91 | The VLLMGenerator instance.
92 | """
93 | generator = cls(model=model, params=params or GenerateParams(), api_key=None)
94 | generator._llm = llm
95 | return generator
96 |
97 | def load(self) -> "VLLMGenerator":
98 | _ = self.llm
99 | return self
100 |
101 | def unload(self) -> "VLLMGenerator":
102 | del self._llm
103 | gc.collect()
104 | torch.cuda.empty_cache()
105 | return self
106 |
107 | def _generate(
108 | self,
109 | texts: list[str],
110 | params: t.Sequence[GenerateParams],
111 | ) -> list[GeneratedText]:
112 | sampling_params_args = list(
113 | inspect.signature(vllm.SamplingParams.__init__).parameters.keys(),
114 | )
115 | sampling_params = (
116 | [
117 | vllm.SamplingParams(
118 | **{
119 | k: v
120 | for k, v in self.params.merge_with(p).to_dict().items()
121 | if k in sampling_params_args
122 | },
123 | )
124 | for p in params
125 | ]
126 | if params
127 | else None
128 | )
129 |
130 | # Do a cache warmup step if we have a lot of texts
131 | if len(texts) > CACHE_TRIGGER:
132 | self.llm.generate(
133 | prompts=min(texts, key=len),
134 | use_tqdm=False,
135 | )
136 |
137 | outputs = self.llm.generate(
138 | prompts=texts,
139 | sampling_params=sampling_params,
140 | use_tqdm=False,
141 | )
142 | return [
143 | GeneratedText(
144 | text=o.outputs[-1].text,
145 | stop_reason=o.outputs[-1].finish_reason or "unknown",
146 | extra={
147 | "request_id": o.request_id,
148 | "metrics": o.metrics,
149 | "stop_token": o.outputs[-1].stop_reason,
150 | },
151 | )
152 | for o in outputs
153 | ]
154 |
155 | async def generate_messages(
156 | self,
157 | messages: t.Sequence[t.Sequence["Message"]],
158 | params: t.Sequence[GenerateParams],
159 | ) -> t.Sequence[GeneratedMessage]:
160 | message_dicts = [[m.to_openai() for m in _messages] for _messages in messages]
161 | tokenizer = self.llm.get_tokenizer()
162 | if not hasattr(tokenizer, "apply_chat_template"):
163 | raise RuntimeError(
164 | "The tokenizer does not support the apply_chat_template method.",
165 | )
166 |
167 | texts = tokenizer.apply_chat_template(
168 | message_dicts,
169 | add_generation_prompt=True,
170 | tokenize=False,
171 | )
172 | generated_texts = self._generate(t.cast("list[str]", texts), params=params)
173 | generated = [g.to_generated_message() for g in generated_texts]
174 |
175 | for i, (in_messages, out_message) in enumerate(zip(messages, generated, strict=False)):
176 | trace_messages(in_messages, f"Messages {i + 1}/{len(messages)}")
177 | trace_messages([out_message], f"Response {i + 1}/{len(messages)}")
178 |
179 | return generated
180 |
181 | async def generate_texts(
182 | self,
183 | texts: t.Sequence[str],
184 | params: t.Sequence[GenerateParams],
185 | ) -> t.Sequence[GeneratedText]:
186 | generated = self._generate(list(texts), params=params)
187 |
188 | for i, (text, response) in enumerate(zip(texts, generated, strict=False)):
189 | trace_str(text, f"Text {i + 1}/{len(texts)}")
190 | trace_str(response, f"Generated {i + 1}/{len(texts)}")
191 |
192 | return generated
193 |
--------------------------------------------------------------------------------
/tests/test_watchers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import typing as t
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | import rigging as rg
8 | from rigging.chat import Chat
9 | from rigging.message import Message
10 |
11 | # ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001, FBT001, FBT002, N803
12 |
13 |
14 | @pytest.fixture
15 | def sample_chats() -> list[Chat]:
16 | chat1 = Chat(
17 | messages=[
18 | Message(role="user", content="Hello"),
19 | Message(role="assistant", content="Hi there!"),
20 | ],
21 | )
22 | chat2 = Chat(
23 | messages=[
24 | Message(role="user", content="How are you?"),
25 | Message(role="assistant", content="I'm doing well!"),
26 | ],
27 | )
28 | return [chat1, chat2]
29 |
30 |
31 | @pytest.mark.asyncio
32 | async def test_write_chats_to_jsonl(tmp_path: Path, sample_chats: list[Chat]) -> None:
33 | output_file = tmp_path / "chats.jsonl"
34 | watcher = rg.watchers.write_chats_to_jsonl(output_file)
35 |
36 | await watcher(sample_chats)
37 |
38 | assert output_file.exists()
39 |
40 | # read and verify contents
41 | with output_file.open() as f:
42 | lines = f.readlines()
43 | assert len(lines) == 2
44 |
45 | # verify each chat was written correctly
46 | for i, line in enumerate(lines):
47 | saved_chat = json.loads(line)
48 | original_chat = sample_chats[i]
49 | for k, message in enumerate(saved_chat["messages"]):
50 | assert message["role"] == original_chat.messages[k].role
51 | assert message["content"] == [
52 | {"text": original_chat.messages[k].content, "type": "text"},
53 | ]
54 |
55 |
56 | @pytest.mark.asyncio
57 | async def test_write_chats_to_jsonl_append(tmp_path: Path, sample_chats: list[Chat]) -> None:
58 | output_file = tmp_path / "chats.jsonl"
59 | watcher = rg.watchers.write_chats_to_jsonl(output_file)
60 |
61 | # write first batch
62 | await watcher(sample_chats[:1])
63 |
64 | # write second batch
65 | await watcher(sample_chats[1:])
66 |
67 | with output_file.open() as f:
68 | lines = f.readlines()
69 | assert len(lines) == 2
70 |
71 |
72 | @pytest.mark.asyncio
73 | async def test_write_chats_to_jsonl_replace(tmp_path: Path, sample_chats: list[Chat]) -> None:
74 | output_file = tmp_path / "chats.jsonl"
75 |
76 | # write initial content
77 | watcher1 = rg.watchers.write_chats_to_jsonl(output_file)
78 | await watcher1(sample_chats)
79 |
80 | # create new watcher with replace=True
81 | watcher2 = rg.watchers.write_chats_to_jsonl(output_file, replace=True)
82 |
83 | # write only one chat - should replace previous content
84 | await watcher2(sample_chats[:1])
85 |
86 | with output_file.open() as f:
87 | lines = f.readlines()
88 | assert len(lines) == 1
89 | saved_chat = json.loads(lines[0])
90 | original_chat = sample_chats[0]
91 | for k, message in enumerate(saved_chat["messages"]):
92 | assert message["role"] == original_chat.messages[k].role
93 | assert message["content"] == [
94 | {"text": original_chat.messages[k].content, "type": "text"},
95 | ]
96 |
97 | # write another chat - should append since already replaced once
98 | await watcher2(sample_chats[1:2])
99 |
100 | with output_file.open() as f:
101 | lines = f.readlines()
102 | assert len(lines) == 2
103 |
104 | # verify both chats were written correctly
105 | for i, line in enumerate(lines):
106 | saved_chat = json.loads(line)
107 | original_chat = sample_chats[i]
108 | for j, message in enumerate(saved_chat["messages"]):
109 | assert message["role"] == original_chat.messages[j].role
110 | assert message["content"] == [
111 | {"text": original_chat.messages[j].content, "type": "text"},
112 | ]
113 |
114 |
115 | class MockS3Client:
116 | class exceptions: # noqa: N801
117 | class ClientError(Exception):
118 | def __init__(self, code: str):
119 | self.response = {"Error": {"Code": code}}
120 |
121 | class Body:
122 | def __init__(self, content: str):
123 | self.content = content
124 |
125 | def read(self) -> bytes:
126 | return self.content.encode()
127 |
128 | def __init__(self) -> None:
129 | self.buckets: dict[str, t.Any] = {"test-bucket": {}}
130 |
131 | def head_object(self, Bucket: str, Key: str) -> t.Any:
132 | if Bucket not in self.buckets:
133 | raise self.exceptions.ClientError("404")
134 | if Key not in self.buckets[Bucket]:
135 | raise self.exceptions.ClientError("404")
136 | return self.buckets[Bucket][Key]
137 |
138 | def get_object(self, Bucket: str, Key: str) -> t.Any:
139 | if Bucket not in self.buckets:
140 | raise self.exceptions.ClientError("404")
141 | if Key not in self.buckets[Bucket]:
142 | raise self.exceptions.ClientError("404")
143 | return {"Body": MockS3Client.Body(self.buckets[Bucket][Key])}
144 |
145 | def delete_object(self, Bucket: str, Key: str) -> None:
146 | if Bucket not in self.buckets:
147 | raise self.exceptions.ClientError("404")
148 | if Key not in self.buckets[Bucket]:
149 | raise self.exceptions.ClientError("404")
150 | del self.buckets[Bucket][Key]
151 |
152 | def put_object(self, Bucket: str, Key: str, Body: str) -> None:
153 | self.buckets[Bucket][Key] = Body
154 |
155 |
156 | @pytest.mark.asyncio
157 | async def test_write_chats_to_s3(sample_chats: list[Chat]) -> None:
158 | s3_mock_client = MockS3Client()
159 |
160 | bucket = "test-bucket"
161 | key = "test/chats.jsonl"
162 |
163 | expected_content = ""
164 | for chat in sample_chats[:1]:
165 | expected_content += chat.model_dump_json() + "\n"
166 |
167 | watcher = rg.watchers.write_chats_to_s3(s3_mock_client, bucket, key) # type: ignore [arg-type]
168 |
169 | # write first batch
170 | await watcher(sample_chats[:1])
171 |
172 | got = s3_mock_client.get_object(Bucket=bucket, Key=key)
173 | assert got["Body"].read() == expected_content.encode()
174 |
175 | # write second batch
176 | await watcher(sample_chats[1:])
177 |
178 | expected_content = ""
179 | for chat in sample_chats:
180 | expected_content += chat.model_dump_json() + "\n"
181 |
182 | got = s3_mock_client.get_object(Bucket=bucket, Key=key)
183 | assert got["Body"].read() == expected_content.encode()
184 |
185 | # create a new watcher with replace=True
186 | watcher = rg.watchers.write_chats_to_s3(s3_mock_client, bucket, key, replace=True) # type: ignore [arg-type]
187 |
188 | # write a single chat
189 | await watcher(sample_chats[:1])
190 |
191 | expected_content = ""
192 | for chat in sample_chats[:1]:
193 | expected_content += chat.model_dump_json() + "\n"
194 |
195 | # verify it's been replaced
196 | got = s3_mock_client.get_object(Bucket=bucket, Key=key)
197 | assert got["Body"].read() == expected_content.encode()
198 |
199 | # write second batch
200 | await watcher(sample_chats[1:])
201 |
202 | expected_content = ""
203 | for chat in sample_chats:
204 | expected_content += chat.model_dump_json() + "\n"
205 |
206 | # verify replace happens only once
207 | got = s3_mock_client.get_object(Bucket=bucket, Key=key)
208 | assert got["Body"].read() == expected_content.encode()
209 |
--------------------------------------------------------------------------------
/docs/topics/iterating-and-batching.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Iterating and Batching"
3 | description: "Iterating over messages, params, and generators, as well as batching of requests."
4 | public: true
5 | ---
6 |
7 | Rigging has good support for iterating over messages, parameters, and generators, as well as batching large numbers of requests. How efficiently these mechanisms operate depends on the underlying generator being used, but Rigging has been developed with scale in mind.
8 |
9 | ## Multiple Generations
10 |
11 | The `run_many` functions let you scale out generation N times with the same inputs:
12 |
13 | - `ChatPipeline.run_many()`
14 | - `CompletionPipeline.run_many()`
15 | - `Prompt.run_many()`
16 |
17 |
18 | ```python Run Many Code
19 | import rigging as rg
20 |
21 | async def check_animal(chats: list[rg.Chat]) -> list[rg.Chat]:
22 | return [
23 | await chat.continue_("Why did you pick that animal?").meta(questioned=True).run()
24 | if any(a in chat.last.content.lower() for a in ["cat", "dog", "cow", "mouse"])
25 | else chat
26 | for chat in chats
27 | ]
28 |
29 | chats = (
30 | await
31 | rg.get_generator("gpt-3.5-turbo")
32 | .chat("Tell me a joke about an animal.")
33 | .map(check_animal)
34 | .run_many(3)
35 | )
36 |
37 | for i, chat in enumerate(chats):
38 | questioned = chat.metadata.get("questioned", False)
39 | print(f"--- Chat {i+1} (?: {questioned}) ---")
40 | print(chat.conversation)
41 | print()
42 | ```
43 |
44 | ```text Outputs
45 | --- Chat 1 (?: False) ---
46 | [user]: Tell me a joke about an animal.
47 |
48 | [assistant]: Why did the spider go to the computer?
49 |
50 | To check his website!
51 |
52 | --- Chat 2 (?: False) ---
53 | [user]: Tell me a joke about an animal.
54 |
55 | [assistant]: Why did the chicken join a band? Because it had the drumsticks!
56 |
57 | --- Chat 3 (?: True) ---
58 | [user]: Tell me a joke about an animal.
59 |
60 | [assistant]: Why don't elephants use computers?
61 |
62 | Because they're afraid of the mouse!
63 |
64 | [user]: Why did you pick that animal?
65 |
66 | [assistant]: I chose an elephant because they are known for their intelligence and gentle nature, making them a popular subject for jokes and humorous anecdotes. Plus, imagining an elephant trying to use a computer and being scared of a tiny mouse is a funny visual image!
67 | ```
68 |
69 |
70 | ## Batching Inputs
71 |
72 | The `run_batch` functions let you batch across a set of inputs:
73 |
74 | - `ChatPipeline.run_batch()`
75 | - `CompletionPipeline.run_batch()`
76 |
77 | As processing proceeds through steps like `.then()` or `.map()`, the chats resolve individually and are collapsed into the final results.
78 |
79 |
80 | ```python Batching Inputs
81 | import rigging as rg
82 | from rigging.model import CommaDelimitedAnswer
83 |
84 | pipeline = (
85 | rg.get_generator('gpt-4-turbo')
86 | .chat({
87 | "role": "system",
88 | "content": f"Always respond with {CommaDelimitedAnswer.xml_tags()} tags."}
89 | )
90 | .until_parsed_as(CommaDelimitedAnswer, attempt_recovery=True)
91 | )
92 |
93 | many = [f"Give me 3 famous {thing}" for thing in ["authors", "painters", "musicians", "hackers"]]
94 |
95 | chats = await pipeline.run_batch(many, on_failed='skip')
96 |
97 | for i, chat in enumerate(chats):
98 | print(f"--- Chat {i+1} ({len(chat)}) ---")
99 | print(chat.last.parse(CommaDelimitedAnswer).items)
100 | print()
101 | ```
102 |
103 | ```text Output
104 | --- Chat 1 (2) ---
105 | ['Leonardo da Vinci', 'Vincent van Gogh', 'Pablo Picasso']
106 |
107 | --- Chat 2 (2) ---
108 | ['Michael Jackson', 'Beyonce', 'The Beatles']
109 | ```
110 |
111 |
112 |
113 | **Skipping failed results**
114 |
115 | Passing `on_failed='skip'` to `.run_batch`, or configuring a pipeline with `.catch(..., on_failed='skip')`, causes parsing errors like `ExhaustedMaxRoundsError` to be ignored, so only successful chats are returned.
116 |
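As a rough sketch (not part of the original example), the pipeline-level form could look like the following; the model and prompts are purely illustrative:

```python Pipeline Catch Sketch
import rigging as rg
from rigging.error import ExhaustedMaxRoundsError
from rigging.model import CommaDelimitedAnswer

pipeline = (
    rg.get_generator("gpt-4-turbo")
    .chat(f"Always respond with {CommaDelimitedAnswer.xml_tags()} tags.")
    .until_parsed_as(CommaDelimitedAnswer, attempt_recovery=True)
    # Chats that still fail to parse are dropped from the results
    .catch(ExhaustedMaxRoundsError, on_failed="skip")
)

chats = await pipeline.run_batch(
    [f"Give me 3 famous {thing}" for thing in ["authors", "painters"]]
)
```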
117 |
118 | ## Batching Parameters
119 |
120 | In addition to batching across input messages or strings, you can fix a single input and build a batch across a set of generation parameters. If either the messages or the generation parameters passed to `.run_batch` is a single item, it will be scaled to match the length of the other.
121 |
122 |
123 | ```python Batching
124 | import rigging as rg
125 |
126 | pipeline = rg.get_generator("gpt-3.5-turbo").chat()
127 |
128 | chats = await pipeline.run_batch(
129 | ["Tell me a short fact about an japanese city."],
130 | [rg.GenerateParams(temperature=t) for t in [0.6, 0.9, 1.2, 1.5, 1.8]]
131 | )
132 |
133 | for i, chat in enumerate(chats):
134 | print(f"--- Chat {i+1} ---")
135 | print(chat.generator_id)
136 | print()
137 | print(chat.conversation)
138 | print()
139 | ```
140 |
141 | ```text Output
142 | --- Chat 1 ---
143 | litellm!gpt-3.5-turbo,temperature=0.6
144 |
145 | [assistant]: Tokyo, the capital city of Japan, is the most populous
146 | metropolitan area in the world, with over 37 million residents.
147 |
148 | --- Chat 2 ---
149 | litellm!gpt-3.5-turbo,temperature=0.9
150 |
151 | [assistant]: Tokyo is the largest metropolitan area in the world,
152 | with a population of over 37 million people.
153 |
154 | --- Chat 3 ---
155 | litellm!gpt-3.5-turbo,temperature=1.2
156 |
157 | [assistant]: Kyoto, a city in Japan known for its historic temples
158 | and gardens, was once the capital of Japan for over 1,000 years from
159 | 794 until the capital was moved to Tokyo in 1869.
160 |
161 | --- Chat 4 ---
162 | litellm!gpt-3.5-turbo,temperature=1.5
163 |
164 | [assistant]: Nagoya, Japan is known for being one of the leading
165 | manufacturing and industrial regions in the country, with a strong
166 | automotive presence including major factories for Toyota, Honda, and Mitsubishi.
167 |
168 | --- Chat 5 ---
169 | litellm!gpt-3.5-turbo,temperature=1.8
170 |
171 | [assistant]: Sendai is the largest city in the Tohoku region of
172 | Japan and is known for its incredible natural scenery, such as the
173 | nearby Sendai Bay and Zuihoden mausoleum.
174 | ```
175 |
176 |
177 | ## Iterating over Models
178 |
179 | The `run_over` functions let you execute generation over a set of generators:
180 |
181 | - `ChatPipeline.run_over()`
182 | - `CompletionPipeline.run_over()`
183 | - `Prompt.run_over()`
184 |
185 | Generators can be passed as string identifiers or as full `Generator` instances. By default, the original generator associated with the `ChatPipeline` is included in the iteration; this is configurable with the `include_original` parameter (see the sketch at the end of this section).
186 |
187 | Much like the `run_many` and `run_batch` functions, you can control the handling of failures with the `on_failed` parameter.
188 |
189 |
190 | ```python Run Over
191 | import rigging as rg
192 | from rigging.model import Answer
193 |
194 | QUESTION = "What is the capital of France?"
195 | ANSWER = "paris"
196 |
197 | async def score_output(chats: list[rg.Chat]) -> list[rg.Chat]:
198 | return [
199 | chat.meta(correct=chat.last.parse(Answer).content.lower() == ANSWER)
200 | for chat in chats
201 | ]
202 |
203 | chats = (
204 | await
205 | rg.get_generator("gpt-3.5-turbo")
206 | .chat([
207 | {"role": "system", "content": f"Always respond in one word between {Answer.xml_tags()} tags."},
208 | {"role": "user", "content": QUESTION}
209 | ])
210 | .until_parsed_as(Answer, max_rounds=3)
211 | .map(score_output)
212 | .run_over("gpt-4-turbo", "claude-3-haiku-20240307,temperature=0.5", "claude-3-sonnet-20240229")
213 | )
214 |
215 | for chat in chats:
216 | print("Model: ", chat.generator.model)
217 | print("Msg: ", chat.last.content)
218 | print("Meta: ", chat.metadata)
219 | print()
220 | ```
221 |
222 | ```text Outputs
223 | Model: gpt-4-turbo
224 | Msg: Paris
225 | Meta: {'correct': True}
226 |
227 | Model: claude-3-haiku-20240307
228 | Msg: Paris
229 | Meta: {'correct': True}
230 |
231 | Model: claude-3-sonnet-20240229
232 | Msg: Paris
233 | Meta: {'correct': True}
234 |
235 | Model: openai/gpt-3.5-turbo
236 | Msg: Paris
237 | Meta: {'correct': True}
238 | ```
239 |
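If you only want results from the generators you pass in, the original generator can be excluded. A minimal sketch, assuming `include_original` is accepted as a keyword argument by `run_over`:

```python Run Over Without Original
chats = await (
    rg.get_generator("gpt-3.5-turbo")
    .chat("What is the capital of France?")
    .run_over(
        "gpt-4-turbo",
        "claude-3-haiku-20240307",
        include_original=False,  # assumption: skip the pipeline's own gpt-3.5-turbo
    )
)
```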
240 |
--------------------------------------------------------------------------------
/examples/tokenize.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "8f675f31",
6 | "metadata": {},
7 | "source": [
8 | "### Load a math dataset\n"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "f679b928",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import typing as t\n",
19 | "\n",
20 | "import datasets\n",
21 | "import contextlib\n",
22 | "import rigging as rg\n",
23 | "\n",
24 | "def is_basic_question(sample: dict[str, t.Any]) -> bool:\n",
25 | " with contextlib.suppress(ValueError):\n",
26 | " float(sample[\"answer\"])\n",
27 | " return True\n",
28 | " return False\n",
29 | " \n",
30 | "dataset = [\n",
31 | " rg.Message(\"user\", sample[\"problem\"], metadata={**sample})\n",
32 | " for sample in datasets.load_dataset(\"HuggingFaceH4/MATH-500\", split=\"test\").filter(is_basic_question)\n",
33 | "]\n",
34 | "\n",
35 | "print(f\"Loaded {len(dataset)} basic questions from MATH-500.\")"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "id": "33fe0a3f",
41 | "metadata": {},
42 | "source": [
43 | "### Define tools\n"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "28abc318",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "import rigging as rg\n",
54 | "from io import StringIO\n",
55 | "from contextlib import redirect_stdout\n",
56 | "\n",
57 | "# Dedicated calculator tool\n",
58 | "\n",
59 | "class Calculator:\n",
60 | " @rg.tool_method(catch=True)\n",
61 | " def add(self, x: float, y: float) -> float:\n",
62 | " \"\"\"Adds two numbers.\"\"\"\n",
63 | " return x + y\n",
64 | "\n",
65 | " @rg.tool_method(catch=True)\n",
66 | " def subtract(self, x: float, y: float) -> float:\n",
67 | " \"\"\"Subtracts the second number from the first.\"\"\"\n",
68 | " return x - y\n",
69 | "\n",
70 | " @rg.tool_method(catch=True)\n",
71 | " def multiply(self, x: float, y: float) -> float:\n",
72 | " \"\"\"Multiplies two numbers.\"\"\"\n",
73 | " return x * y\n",
74 | "\n",
75 | " @rg.tool_method(catch=True)\n",
76 | " def divide(self, x: float, y: float) -> float:\n",
77 | " \"\"\"Divides the first number by the second.\"\"\"\n",
78 | " if y == 0:\n",
79 | " raise ValueError(\"Cannot divide by zero.\")\n",
80 | " return x / y\n",
81 | "\n",
82 | "calculator = Calculator()\n",
83 | "\n",
84 | "# Python execution tool\n",
85 | "\n",
86 | "@rg.tool(catch=True)\n",
87 | "def execute_python(code: str) -> str:\n",
88 | " \"\"\"\n",
89 | " Executes Python code and returns stdout from the execution.\n",
90 | " \n",
91 | " - Use print() to output results.\n",
92 | " - Be thoughtful with indentation and syntax.\n",
93 | " \"\"\"\n",
94 | " output = StringIO()\n",
95 | " with redirect_stdout(output):\n",
96 | " exec(code)\n",
97 | " return output.getvalue()\n",
98 | " "
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "id": "d1ea856c",
104 | "metadata": {},
105 | "source": [
106 | "### Define our agents\n"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "2962c1ac",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "import contextlib\n",
117 | "\n",
118 | "# Define a model for parsing the agent answer\n",
119 | "\n",
120 | "class Answer(rg.Model):\n",
121 | " value: str\n",
122 | "\n",
123 | "# Callback to inspect results after the agent execution\n",
124 | "\n",
125 | "async def inspect_results(chat: rg.Chat) -> rg.Chat:\n",
126 | " used_tools = any(msg.tool_calls for msg in chat.messages if msg.role == \"assistant\")\n",
127 | " answer = chat.message_metadata[\"answer\"]\n",
128 | " agent_answer = chat.last.try_parse(Answer)\n",
129 | "\n",
130 | " correct = False\n",
131 | " with contextlib.suppress(ValueError):\n",
132 | " correct = agent_answer is not None and float(agent_answer.value) == float(answer)\n",
133 | " \n",
134 | " return chat.meta(\n",
135 | " correct=correct,\n",
136 | " gave_answer=answer is not None,\n",
137 | " agent_answer=agent_answer.value if agent_answer else None,\n",
138 | " used_tools=used_tools,\n",
139 | " true_answer=answer\n",
140 | " )\n",
141 | "\n",
142 | "\n",
143 | "# Build our core pipeline\n",
144 | "\n",
145 | "pipeline = (\n",
146 | " rg.get_generator(\"groq/meta-llama/llama-4-maverick-17b-128e-instruct\")\n",
147 | " .chat(\"Answer math questions and return basic floats between tags.\")\n",
148 | " .until_parsed_as(Answer)\n",
149 | " .then(inspect_results)\n",
150 | " .catch(Exception, on_failed=\"include\")\n",
151 | ")\n",
152 | "\n",
153 | "# Define 3 agents with different capabilities\n",
154 | "\n",
155 | "agent_no_tools = pipeline.clone().meta(variant=\"no_tools\")\n",
156 | "agent_with_calculator = pipeline.clone().using(calculator, mode=\"xml\").meta(variant=\"with_calculator\")\n",
157 | "agent_with_python = pipeline.clone().using(execute_python, mode=\"xml\").meta(variant=\"with_python\")\n"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "id": "e0dae7e4",
163 | "metadata": {},
164 | "source": [
165 | "### Run 10 samples through our 3 agents\n"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "id": "39bf862f",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "import random\n",
176 | "\n",
177 | "samples = random.sample(dataset, 10)\n",
178 | "\n",
179 | "chats_no_tools = await agent_no_tools.run_batch(samples)\n",
180 | "chats_with_calculator = await agent_with_calculator.run_batch(samples)\n",
181 | "chats_with_python = await agent_with_python.run_batch(samples)"
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "id": "b8305557",
187 | "metadata": {},
188 | "source": [
189 | "### Calculate success rates\n"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "id": "436079a5",
196 | "metadata": {},
197 | "outputs": [],
198 | "source": [
199 | "def get_success_rate(chats: t.List[rg.Chat]) -> float:\n",
200 | " return sum(chat.metadata.get(\"correct\", False) for chat in chats) / len(chats)\n",
201 | "\n",
202 | "no_tools_success_rate = get_success_rate(chats_no_tools)\n",
203 | "with_calculator_success_rate = get_success_rate(chats_with_calculator)\n",
204 | "with_python_success_rate = get_success_rate(chats_with_python)\n",
205 | "\n",
206 | "print(f\"Success rate without tools: {no_tools_success_rate:.2%}\")\n",
207 | "print(f\"Success rate with calculator: {with_calculator_success_rate:.2%}\")\n",
208 | "print(f\"Success rate with Python: {with_python_success_rate:.2%}\")"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "id": "12655a0c",
214 | "metadata": {},
215 | "source": [
216 | "### Tokenize the chat with a tokenizer\n"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "id": "3a7ef6dc",
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "tokenized = await chats_with_python.to_tokens('Qwen/Qwen2.5-1.5B-Instruct', transform=\"json-with-tag\")\n",
227 | "\n",
228 | "print(tokenized[0].metadata)\n",
229 | "print(tokenized[0].slices)"
230 | ]
231 | }
232 | ],
233 | "metadata": {
234 | "kernelspec": {
235 | "display_name": ".venv",
236 | "language": "python",
237 | "name": "python3"
238 | },
239 | "language_info": {
240 | "codemirror_mode": {
241 | "name": "ipython",
242 | "version": 3
243 | },
244 | "file_extension": ".py",
245 | "mimetype": "text/x-python",
246 | "name": "python",
247 | "nbconvert_exporter": "python",
248 | "pygments_lexer": "ipython3",
249 | "version": "3.10.14"
250 | }
251 | },
252 | "nbformat": 4,
253 | "nbformat_minor": 5
254 | }
255 |
--------------------------------------------------------------------------------
/.hooks/generate_docs.py:
--------------------------------------------------------------------------------
1 | import argparse # noqa: INP001
2 | import re
3 | import typing as t
4 | from pathlib import Path
5 |
6 | from markdown import Markdown # type: ignore [import-untyped]
7 | from markdownify import MarkdownConverter # type: ignore [import-untyped]
8 | from markupsafe import Markup
9 | from mkdocstrings_handlers.python._internal.config import PythonConfig
10 | from mkdocstrings_handlers.python._internal.handler import (
11 | PythonHandler,
12 | )
13 |
14 | # ruff: noqa: T201
15 |
16 |
17 | class CustomMarkdownConverter(MarkdownConverter): # type: ignore [misc]
18 | # Strip extra whitespace from code blocks
19 | def convert_pre(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any:
20 | return super().convert_pre(el, text.strip(), parent_tags)
21 |
22 | # bold items with doc-section-title in a span class
23 | def convert_span(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: # noqa: ARG002
24 | if "doc-section-title" in el.get("class", []):
25 | return f"**{text.strip()}**"
26 | return text
27 |
28 | # Remove the div wrapper for inline descriptions
29 | def convert_div(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any:
30 | if "doc-md-description" in el.get("class", []):
31 | return text.strip()
32 | return super().convert_div(el, text, parent_tags)
33 |
34 | # Map mkdocstrings details classes to Mintlify callouts
35 | def convert_details(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any: # noqa: ARG002
36 | classes = el.get("class", [])
37 |
38 | # Handle source code details specially
39 | if "quote" in classes:
40 | summary = el.find("summary")
41 | if summary:
42 | file_path = summary.get_text().replace("Source code in ", "").strip()
43 | content = text[text.find("```") :]
44 | return f'\n\n{content}\n\n'
45 |
46 | callout_map = {
47 | "note": "Note",
48 | "warning": "Warning",
49 | "info": "Info",
50 | "tip": "Tip",
51 | }
52 |
53 | callout_type = None
54 | for cls in classes:
55 | if cls in callout_map:
56 | callout_type = callout_map[cls]
57 | break
58 |
59 | if not callout_type:
60 | return text
61 |
62 | content = text.strip()
63 | if content.startswith(callout_type):
64 | content = content[len(callout_type) :].strip()
65 |
66 | return f"\n<{callout_type}>\n{content}\n{callout_type}>\n"
67 |
68 | def convert_table(self, el: t.Any, text: str, parent_tags: t.Any) -> t.Any:
69 | # Check if this is a highlighttable (source code with line numbers)
70 | if "highlighttable" in el.get("class", []):
71 | code_cells = el.find_all("td", class_="code")
72 | if code_cells:
73 | code = code_cells[0].get_text()
74 | code = code.strip()
75 | code = code.replace("```", "~~~")
76 | return f"\n```python\n{code}\n```\n"
77 |
78 | return super().convert_table(el, text, parent_tags)
79 |
80 |
81 | class AutoDocGenerator:
82 | def __init__(self, source_paths: list[str], theme: str = "material", **options: t.Any) -> None:
83 | self.source_paths = source_paths
84 | self.theme = theme
85 | self.handler = PythonHandler(PythonConfig.from_data(), base_dir=Path.cwd())
86 | self.options = options
87 |
88 | self.handler._update_env( # noqa: SLF001
89 | Markdown(),
90 | config={"mdx": ["toc"]},
91 | )
92 |
93 | md = Markdown(extensions=["fenced_code"])
94 |
95 | def simple_convert_markdown(
96 | text: str,
97 | heading_level: int,
98 | html_id: str = "",
99 | **kwargs: t.Any,
100 | ) -> t.Any:
101 | return Markup(md.convert(text) if text else "") # noqa: S704 # nosec
102 |
103 | self.handler.env.filters["convert_markdown"] = simple_convert_markdown
104 |
105 | def generate_docs_for_module(
106 | self,
107 | module_path: str,
108 | ) -> str:
109 | options = self.handler.get_options(
110 | {
111 | "docstring_section_style": "list",
112 | "merge_init_into_class": True,
113 | "show_signature_annotations": True,
114 | "separate_signature": True,
115 | "show_source": True,
116 | "show_labels": False,
117 | "show_bases": False,
118 | **self.options,
119 | },
120 | )
121 |
122 | module_data = self.handler.collect(module_path, options)
123 | html = self.handler.render(module_data, options)
124 |
125 | return str(
126 | CustomMarkdownConverter(
127 | code_language="python",
128 | ).convert(html),
129 | )
130 |
131 | def process_mdx_file(self, file_path: Path) -> bool:
132 | content = file_path.read_text(encoding="utf-8")
133 | original_content = content
134 |
135 | # Find the header comment block
136 | header_match = re.search(
137 | r"\{\s*/\*\s*((?:::.*?\n?)*)\s*\*/\s*\}",
138 | content,
139 | re.MULTILINE | re.DOTALL,
140 | )
141 |
142 | if not header_match:
143 | return False
144 |
145 | header = header_match.group(0)
146 | module_lines = header_match.group(1).strip().split("\n")
147 |
148 | # Generate content for each module
149 | markdown_blocks = []
150 | for line in module_lines:
151 | if line.startswith(":::"):
152 | module_path = line.strip()[3:].strip()
153 | if module_path:
154 | markdown = self.generate_docs_for_module(module_path)
155 | markdown_blocks.append(markdown)
156 |
157 | keep_end = content.find(header) + len(header)
158 | new_content = content[:keep_end] + "\n\n" + "\n".join(markdown_blocks)
159 |
160 | # Write back if changed
161 | if new_content != original_content:
162 | file_path.write_text(new_content, encoding="utf-8")
163 | print(f"[+] Updated: {file_path}")
164 | return True
165 |
166 | return False
167 |
168 | def process_directory(self, directory: Path, pattern: str = "**/*.mdx") -> int:
169 | if not directory.exists():
170 | print(f"[!] Directory does not exist: {directory}")
171 | return 0
172 |
173 | files_processed = 0
174 | files_modified = 0
175 |
176 | for mdx_file in directory.glob(pattern):
177 | if mdx_file.is_file():
178 | files_processed += 1
179 | if self.process_mdx_file(mdx_file):
180 | files_modified += 1
181 |
182 | return files_modified
183 |
184 |
185 | def main() -> None:
186 | """Main entry point for the script."""
187 |
188 | parser = argparse.ArgumentParser(description="Generate auto-docs for MDX files")
189 | parser.add_argument("--directory", help="Directory containing MDX files", default="docs")
190 | parser.add_argument("--pattern", default="**/*.mdx", help="File pattern to match")
191 | parser.add_argument(
192 | "--source-paths",
193 | nargs="+",
194 | default=["dreadnode"],
195 | help="Python source paths for module discovery",
196 | )
197 | parser.add_argument(
198 | "--show-if-no-docstring",
199 | type=bool,
200 | default=False,
201 | help="Show module/class/function even if no docstring is present",
202 | )
203 | parser.add_argument("--theme", default="material", help="Theme to use for rendering")
204 |
205 | args = parser.parse_args()
206 |
207 | # Create generator
208 | generator = AutoDocGenerator(
209 | source_paths=args.source_paths,
210 | theme=args.theme,
211 | show_if_no_docstring=args.show_if_no_docstring,
212 | )
213 |
214 | # Process directory
215 | directory = Path(args.directory)
216 | modified_count = generator.process_directory(directory, args.pattern)
217 |
218 | print(f"\n[+] Auto-doc generation complete. {modified_count} files were updated.")
219 |
220 |
221 | if __name__ == "__main__":
222 | main()
223 |
--------------------------------------------------------------------------------
/rigging/generator/transformers_.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import typing as t
3 |
4 | import torch # type: ignore [import-not-found, import-untyped, unused-ignore]
5 | import transformers # type: ignore [import-not-found, import-untyped, unused-ignore]
6 | from transformers import ( # type: ignore [import-not-found, attr-defined, unused-ignore]
7 | AutoModelForCausalLM,
8 | AutoTokenizer,
9 | TextGenerationPipeline,
10 | )
11 |
12 | from rigging.generator.base import (
13 | GeneratedMessage,
14 | GeneratedText,
15 | GenerateParams,
16 | Generator,
17 | trace_messages,
18 | trace_str,
19 | )
20 |
21 | if t.TYPE_CHECKING:
22 | from rigging.message import Message
23 |
24 | DEFAULT_MAX_TOKENS = 1024
25 | """Lifting the default max tokens from transformers"""
26 |
27 |
28 | class TransformersGenerator(Generator):
29 | """
30 | Generator backed by the Transformers library for local model loading.
31 |
32 | Warning:
33 | The use of Transformers requires the `transformers` package to be installed directly or by
34 | installing rigging as `rigging[all]`.
35 |
36 | Warning:
37 | The `transformers` library is expansive with many different models, tokenizers,
38 | options, constructors, etc. We do our best to implement a consistent interface,
39 | but there may be limitations. Where needed, use
40 | [`.from_obj()`][rigging.generator.transformers_.TransformersGenerator.from_obj].
41 |
42 | Note:
43 | This generator doesn't leverage any async capabilities.
44 |
45 | Note:
46 |         The model will be loaded into memory lazily when the first generation is requested.
47 |         If you want to force this to happen earlier, you can use the
48 | [`.load()`][rigging.generator.Generator.load] method.
49 |
50 | To unload, call [`.unload()`][rigging.generator.Generator.unload].
51 | """
52 |
53 | torch_dtype: str = "auto"
54 | """Torch dtype passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
55 | device_map: str = "auto"
56 | """Device map passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
57 | trust_remote_code: bool = False
58 | """Trust remote code passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
59 | load_in_8bit: bool = False
60 | """Load in 8 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
61 | load_in_4bit: bool = False
62 | """Load in 4 bit passed to [`AutoModelForCausalLM.from_pretrained`](https://huggingface.co/docs/transformers/v4.41.0/en/model_doc/auto)"""
63 |
64 | _llm: AutoModelForCausalLM | None = None
65 | _tokenizer: AutoTokenizer | None = None
66 | _pipeline: TextGenerationPipeline | None = None
67 |
68 | @property
69 | def llm(self) -> AutoModelForCausalLM:
70 | """The underlying `AutoModelForCausalLM` instance."""
71 | # Lazy initialization
72 | if self._llm is None:
73 | llm_kwargs = self.model_dump(
74 | exclude_unset=True,
75 | include={
76 | "torch_dtype",
77 | "device_map",
78 | "trust_remote_code",
79 | "load_in_8bit",
80 | "load_in_4bit",
81 | },
82 | )
83 | self._llm = AutoModelForCausalLM.from_pretrained(self.model, **llm_kwargs) # type: ignore [no-untyped-call, assignment, unused-ignore] # nosec
84 | if self._llm is None:
85 | raise ValueError(f"Failed to load model '{self.model}'")
86 | return self._llm
87 |
88 | @property
89 | def tokenizer(self) -> AutoTokenizer:
90 | """The underlying `AutoTokenizer` instance."""
91 | if self._tokenizer is None:
92 | self._tokenizer = AutoTokenizer.from_pretrained(self.model) # type: ignore [no-untyped-call, unused-ignore] # nosec
93 | return self._tokenizer
94 |
95 | @property
96 | def pipeline(self) -> TextGenerationPipeline:
97 | """The underlying `TextGenerationPipeline` instance."""
98 | if self._pipeline is None:
99 | self._pipeline = transformers.pipeline( # type: ignore [attr-defined, call-overload, assignment, unused-ignore]
100 | "text-generation",
101 | return_full_text=False,
102 | model=self.llm, # type: ignore [arg-type, unused-ignore]
103 | tokenizer=self.tokenizer, # type: ignore [arg-type, unused-ignore]
104 | )
105 | return self._pipeline # type: ignore [return-value, unused-ignore]
106 |
107 | @classmethod
108 | def from_obj(
109 | cls,
110 | model: t.Any,
111 | tokenizer: AutoTokenizer,
112 | *,
113 | pipeline: TextGenerationPipeline | None = None,
114 | params: GenerateParams | None = None,
115 | ) -> "TransformersGenerator":
116 | """
117 | Create a new instance of TransformersGenerator from an already loaded model and tokenizer.
118 |
119 | Args:
120 | model: The loaded model for text generation.
121 | tokenizer: The tokenizer associated with the model.
122 | pipeline: The text generation pipeline. Defaults to None.
123 | params: Generation parameters. Defaults to None.
124 |
125 | Returns:
126 | The TransformersGenerator instance.
127 | """
128 | instance = cls(model=model, params=params or GenerateParams(), api_key=None)
129 | instance._llm = model
130 | instance._tokenizer = tokenizer
131 | instance._pipeline = pipeline
132 | return instance
133 |
134 | def load(self) -> "TransformersGenerator":
135 | _ = self.pipeline
136 | return self
137 |
138 | def unload(self) -> "TransformersGenerator":
139 | del self._pipeline
140 | del self._llm
141 | del self._tokenizer
142 | gc.collect()
143 | torch.cuda.empty_cache()
144 | return self
145 |
146 | def _generate(
147 | self,
148 | inputs: t.Sequence[str] | t.Sequence[t.Sequence[dict[str, str]]],
149 | params: t.Sequence[GenerateParams],
150 | ) -> list[GeneratedText]:
151 | param_set = {p.model_dump_json() for p in params}
152 | if len(param_set) != 1:
153 | raise ValueError("All GenerateParams must be identical for this generator")
154 |
155 | # Generation Args + Fixups
156 | if self.params.max_tokens is None:
157 | self.params.max_tokens = DEFAULT_MAX_TOKENS
158 |
159 | kwargs = self.params.merge_with(params[0]).to_dict()
160 | if "max_tokens" in kwargs:
161 | kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
162 | if any(k in kwargs for k in ["temperature", "top_k", "top_p"]):
163 | kwargs["do_sample"] = True
164 |
165 | outputs = self.pipeline(inputs, **kwargs) # type: ignore [call-overload]
166 |
167 | # TODO: We do strip() here as it's often needed, but I think
168 | # we should return and standardize this behavior.
169 | return [GeneratedText(text=o["generated_text"].strip()) for o in outputs]
170 |
171 | async def generate_messages(
172 | self,
173 | messages: t.Sequence[t.Sequence["Message"]],
174 | params: t.Sequence[GenerateParams],
175 | ) -> t.Sequence[GeneratedMessage]:
176 | message_dicts = [
177 | [m.to_openai(compatibility_flags={"content_as_str"}) for m in _messages]
178 | for _messages in messages
179 | ]
180 | outputs = self._generate(message_dicts, params)
181 | generated = [o.to_generated_message() for o in outputs]
182 |
183 | for i, (in_messages, out_message) in enumerate(zip(messages, generated, strict=False)):
184 |             trace_messages(in_messages, f"Messages {i + 1}/{len(messages)}")
185 |             trace_messages([out_message], f"Response {i + 1}/{len(messages)}")
186 |
187 | return generated
188 |
189 | async def generate_texts(
190 | self,
191 | texts: t.Sequence[str],
192 | params: t.Sequence[GenerateParams],
193 | ) -> t.Sequence[GeneratedText]:
194 | generated = self._generate(texts, params)
195 |
196 | for i, (text, response) in enumerate(zip(texts, generated, strict=False)):
197 | trace_str(text, f"Text {i + 1}/{len(texts)}")
198 | trace_str(response, f"Generated {i + 1}/{len(texts)}")
199 |
200 | return generated
201 |
--------------------------------------------------------------------------------
/docs/api/parsing.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: rigging.parsing
3 | ---
4 |
5 | {/*
6 | ::: rigging.parsing
7 | */}
8 |
9 | Parsing helpers for extracting rigging models from text
10 |
11 | parse
12 | -----
13 |
14 | ```python
15 | parse(
16 | text: str, model_type: type[ModelT]
17 | ) -> tuple[ModelT, slice]
18 | ```
19 |
20 | Parses a single model from text.
21 |
22 | **Parameters:**
23 |
24 | * **`text`**
25 | (`str`)
26 | –The content to parse.
27 | * **`model_type`**
28 | (`type[ModelT]`)
29 | –The type of model to parse.
30 |
31 | **Returns:**
32 |
33 | * `tuple[ModelT, slice]`
34 | –The parsed model.
35 |
36 | **Raises:**
37 |
38 | * `ValueError`
39 | –If no models of the given type are found and `fail_on_missing` is set to `True`.
40 |
41 |
42 | ```python
43 | def parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice]:
44 | """
45 | Parses a single model from text.
46 |
47 | Args:
48 | text: The content to parse.
49 | model_type: The type of model to parse.
50 |
51 | Returns:
52 | The parsed model.
53 |
54 | Raises:
55 | ValueError: If no models of the given type are found and `fail_on_missing` is set to `True`.
56 | """
57 | return try_parse_many(text, model_type, fail_on_missing=True)[0]
58 | ```
59 |
60 |
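A minimal usage sketch (illustrative, not from the library docs; it assumes the built-in `Answer` model parses from `<answer>...</answer>` tags):

```python
from rigging.model import Answer
from rigging.parsing import parse

text = "The capital is <answer>Paris</answer>."

# Returns the parsed model and the slice of `text` it was found in
model, span = parse(text, Answer)
print(model.content)  # "Paris"
print(text[span])     # the matched region of the original text
```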
61 |
62 |
63 | parse\_many
64 | -----------
65 |
66 | ```python
67 | parse_many(
68 | text: str, *types: type[ModelT]
69 | ) -> list[tuple[ModelT, slice]]
70 | ```
71 |
72 | Parses multiple models of the specified non-identical types from text.
73 |
74 | **Parameters:**
75 |
76 | * **`text`**
77 | (`str`)
78 | –The content to parse.
79 | * **`*types`**
80 | (`type[ModelT]`, default:
81 | `()`
82 | )
83 | –The types of models to parse.
84 |
85 | **Returns:**
86 |
87 | * `list[tuple[ModelT, slice]]`
88 | –A list of parsed models.
89 |
90 | **Raises:**
91 |
92 | * `MissingModelError`
93 | –If any of the models are missing.
94 |
95 |
96 | ```python
97 | def parse_many(text: str, *types: type["ModelT"]) -> list[tuple["ModelT", slice]]:
98 | """
99 | Parses multiple models of the specified non-identical types from text.
100 |
101 | Args:
102 | text: The content to parse.
103 | *types: The types of models to parse.
104 |
105 | Returns:
106 | A list of parsed models.
107 |
108 | Raises:
109 | MissingModelError: If any of the models are missing.
110 | """
111 | return try_parse_many(text, *types, fail_on_missing=True)
112 | ```
113 |
114 |
115 |
116 |
117 | parse\_set
118 | ----------
119 |
120 | ```python
121 | parse_set(
122 | text: str,
123 | model_type: type[ModelT],
124 | *,
125 | minimum: int | None = None,
126 | ) -> list[tuple[ModelT, slice]]
127 | ```
128 |
129 | Parses a set of models with the specified identical type from text.
130 |
131 | **Parameters:**
132 |
133 | * **`text`**
134 | (`str`)
135 | –The content to parse.
136 | * **`model_type`**
137 | (`type[ModelT]`)
138 | –The type of models to parse.
139 | * **`minimum`**
140 | (`int | None`, default:
141 | `None`
142 | )
143 | –The minimum number of models required.
144 |
145 | **Returns:**
146 |
147 | * `list[tuple[ModelT, slice]]`
148 | –A list of parsed models.
149 |
150 | **Raises:**
151 |
152 | * `MissingModelError`
153 | –If the minimum number of models is not met.
154 |
155 |
156 | ```python
157 | def parse_set(
158 | text: str,
159 | model_type: type["ModelT"],
160 | *,
161 | minimum: int | None = None,
162 | ) -> list[tuple["ModelT", slice]]:
163 | """
164 | Parses a set of models with the specified identical type from text.
165 |
166 | Args:
167 | text: The content to parse.
168 | model_type: The type of models to parse.
169 | minimum: The minimum number of models required.
170 |
171 | Returns:
172 | A list of parsed models.
173 |
174 | Raises:
175 | MissingModelError: If the minimum number of models is not met.
176 | """
177 | return try_parse_set(text, model_type, minimum=minimum, fail_on_missing=True)
178 | ```
179 |
180 |
181 |
182 |
183 | try\_parse
184 | ----------
185 |
186 | ```python
187 | try_parse(
188 | text: str, model_type: type[ModelT]
189 | ) -> tuple[ModelT, slice] | None
190 | ```
191 |
192 | Tries to parse a model from text.
193 |
194 | **Parameters:**
195 |
196 | * **`text`**
197 | (`str`)
198 | –The content to parse.
199 | * **`model_type`**
200 | (`type[ModelT]`)
201 | –The type of model to search for.
202 |
203 | **Returns:**
204 |
205 | * `tuple[ModelT, slice] | None`
206 | –The first model that matches the given model type, or None if no match is found.
207 |
208 |
209 | ```python
210 | def try_parse(text: str, model_type: type["ModelT"]) -> tuple["ModelT", slice] | None:
211 | """
212 | Tries to parse a model from text.
213 |
214 | Args:
215 | text: The content to parse.
216 | model_type: The type of model to search for.
217 |
218 | Returns:
219 | The first model that matches the given model type, or None if no match is found.
220 | """
221 | return next(iter(try_parse_many(text, model_type)), None)
222 | ```
223 |
224 |
225 |
226 |
227 | try\_parse\_many
228 | ----------------
229 |
230 | ```python
231 | try_parse_many(
232 | text: str,
233 | *types: type[ModelT],
234 | fail_on_missing: bool = False,
235 | ) -> list[tuple[ModelT, slice]]
236 | ```
237 |
238 | Tries to parse multiple models of the specified non-identical types from text.
239 |
240 | **Parameters:**
241 |
242 | * **`text`**
243 | (`str`)
244 | –The content to parse.
245 | * **`*types`**
246 | (`type[ModelT]`, default:
247 | `()`
248 | )
249 | –The types of models to parse.
250 | * **`fail_on_missing`**
251 | (`bool`, default:
252 | `False`
253 | )
254 | –Whether to raise an exception if a model type is missing.
255 |
256 | **Returns:**
257 |
258 | * `list[tuple[ModelT, slice]]`
259 | –A list of parsed models.
260 |
261 | **Raises:**
262 |
263 | * `MissingModelError`
264 | –If a model type is missing and `fail_on_missing` is True.
265 | * `Exception`
266 | –If the model is malformed and `fail_on_missing` is True.
267 |
268 |
269 | ```python
270 | def try_parse_many(
271 | text: str,
272 | *types: type["ModelT"],
273 | fail_on_missing: bool = False,
274 | ) -> list[tuple["ModelT", slice]]:
275 | """
276 |     Tries to parse multiple models of the specified non-identical types from text.
277 |
278 | Args:
279 | text: The content to parse.
280 | *types: The types of models to parse.
281 | fail_on_missing: Whether to raise an exception if a model type is missing.
282 |
283 | Returns:
284 | A list of parsed models.
285 |
286 | Raises:
287 | MissingModelError: If a model type is missing and `fail_on_missing` is True.
288 | Exception: If the model is malformed and `fail_on_missing` is True.
289 | """
290 | model: ModelT
291 | parsed: list[tuple[ModelT, slice]] = []
292 |
293 | try:
294 | for model_class in types:
295 | for model, slice_ in model_class.from_text(text):
296 | parsed.append((model, slice_))
297 | except Exception:
298 | if fail_on_missing:
299 | raise
300 |
301 | return sorted(parsed, key=lambda x: x[1].start)
302 | ```
303 |
304 |
305 |
306 |
307 | try\_parse\_set
308 | ---------------
309 |
310 | ```python
311 | try_parse_set(
312 | text: str,
313 | model_type: type[ModelT],
314 | *,
315 | minimum: int | None = None,
316 | fail_on_missing: bool = False,
317 | ) -> list[tuple[ModelT, slice]]
318 | ```
319 |
320 | Tries to parse a set of models with the specified identical type from text.
321 |
322 | **Parameters:**
323 |
324 | * **`text`**
325 | (`str`)
326 | –The content to parse.
327 | * **`model_type`**
328 | (`type[ModelT]`)
329 | –The type of model to parse.
330 | * **`minimum`**
331 | (`int | None`, default:
332 | `None`
333 | )
334 | –The minimum number of models expected.
335 | * **`fail_on_missing`**
336 | (`bool`, default:
337 | `False`
338 | )
339 | –Whether to raise an exception if models are missing.
340 |
341 | **Returns:**
342 |
343 | * `list[tuple[ModelT, slice]]`
344 | –The parsed models.
345 |
346 | **Raises:**
347 |
348 | * `MissingModelError`
349 | –If the number of parsed models is less than the minimum required.
350 |
351 |
352 | ```python
353 | def try_parse_set(
354 | text: str,
355 | model_type: type["ModelT"],
356 | *,
357 | minimum: int | None = None,
358 | fail_on_missing: bool = False,
359 | ) -> list[tuple["ModelT", slice]]:
360 | """
361 | Tries to parse a set of models with the specified identical type from text.
362 |
363 | Args:
364 | text: The content to parse.
365 | model_type: The type of model to parse.
366 | minimum: The minimum number of models expected.
367 | fail_on_missing: Whether to raise an exception if models are missing.
368 |
369 | Returns:
370 | The parsed models.
371 |
372 | Raises:
373 | MissingModelError: If the number of parsed models is less than the minimum required.
374 | """
375 | models = try_parse_many(text, model_type, fail_on_missing=fail_on_missing)
376 | if minimum is not None and len(models) < minimum:
377 | raise MissingModelError(f"Expected at least {minimum} {model_type.__name__} in message")
378 | return models
379 | ```
380 |
381 |
382 |
--------------------------------------------------------------------------------
/docs/api/error.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | title: rigging.error
3 | ---
4 |
5 | {/*
6 | ::: rigging.error
7 | */}
8 |
9 | We try to avoid creating custom exceptions unless they are necessary.
10 |
11 | We use the built-in and pydantic exceptions as much as possible.
12 |
13 | CompletionExhaustedMaxRoundsError
14 | ---------------------------------
15 |
16 | ```python
17 | CompletionExhaustedMaxRoundsError(
18 | max_rounds: int, completion: str
19 | )
20 | ```
21 |
22 | Raised when the maximum number of rounds is exceeded while generating completions.
23 |
24 |
25 | ```python
26 | def __init__(self, max_rounds: int, completion: str):
27 | super().__init__(max_rounds)
28 | self.completion = completion
29 | """The completion which was being generated when the exception occurred."""
30 | ```
31 |
32 |
33 |
34 |
35 | ### completion
36 |
37 | ```python
38 | completion = completion
39 | ```
40 |
41 | The completion which was being generated when the exception occurred.
42 |
43 | ExhaustedMaxRoundsError
44 | -----------------------
45 |
46 | ```python
47 | ExhaustedMaxRoundsError(max_rounds: int)
48 | ```
49 |
50 | Raised when the maximum number of rounds is exceeded while generating.
51 |
52 |
53 | ```python
54 | def __init__(self, max_rounds: int):
55 | super().__init__(f"Exhausted max rounds ({max_rounds}) while generating")
56 | self.max_rounds = max_rounds
57 | """The number of rounds which was exceeded."""
58 | ```
59 |
60 |
61 |
62 |
63 | ### max\_rounds
64 |
65 | ```python
66 | max_rounds = max_rounds
67 | ```
68 |
69 | The number of rounds which was exceeded.
70 |
71 | GeneratorWarning
72 | ----------------
73 |
74 | Base class for all generator warnings.
75 |
76 | This is used to indicate that something unexpected happened during the generator execution,
77 | but it is not critical enough to stop the execution.
78 |
79 | InvalidGeneratorError
80 | ---------------------
81 |
82 | ```python
83 | InvalidGeneratorError(model: str)
84 | ```
85 |
86 | Raised when an invalid identifier is specified when getting a generator.
87 |
88 |
89 | ```python
90 | def __init__(self, model: str):
91 | super().__init__(f"Invalid model specified: {model}")
92 | ```
93 |
94 |
95 |
96 |
97 | InvalidTokenizerError
98 | ---------------------
99 |
100 | ```python
101 | InvalidTokenizerError(tokenizer: str)
102 | ```
103 |
104 | Raised when an invalid tokenizer is specified.
105 |
106 |
107 | ```python
108 | def __init__(self, tokenizer: str):
109 | super().__init__(f"Invalid tokenizer specified: {tokenizer}")
110 | self.tokenizer = tokenizer
111 | """The name of the tokenizer which was invalid."""
112 | ```
113 |
114 |
115 |
116 |
117 | ### tokenizer
118 |
119 | ```python
120 | tokenizer = tokenizer
121 | ```
122 |
123 | The name of the tokenizer which was invalid.
124 |
125 | MaxDepthError
126 | -------------
127 |
128 | ```python
129 | MaxDepthError(
130 | max_depth: int, step: PipelineStep, reference: str
131 | )
132 | ```
133 |
134 | Raised when the maximum depth is exceeded while generating.
135 |
136 |
137 | ```python
138 | def __init__(self, max_depth: int, step: "PipelineStep", reference: str):
139 | super().__init__(f"Exceeded max depth ({max_depth}) while generating ('{reference}')")
140 | self.max_depth = max_depth
141 | """The maximum depth of nested pipeline generations which was exceeded."""
142 | self.step = step
143 | """The pipeline step which cause the depth error."""
144 | ```
145 |
146 |
147 |
148 |
149 | ### max\_depth
150 |
151 | ```python
152 | max_depth = max_depth
153 | ```
154 |
155 | The maximum depth of nested pipeline generations which was exceeded.
156 |
157 | ### step
158 |
159 | ```python
160 | step = step
161 | ```
162 |
163 | The pipeline step which caused the depth error.
164 |
165 | MessageWarning
166 | --------------
167 |
168 | Base class for all message warnings.
169 |
170 | This is used to indicate that something unexpected happened during the message processing,
171 | but it is not critical enough to stop the execution.
172 |
173 | MessagesExhaustedMaxRoundsError
174 | -------------------------------
175 |
176 | ```python
177 | MessagesExhaustedMaxRoundsError(
178 | max_rounds: int, messages: list[Message]
179 | )
180 | ```
181 |
182 | Raised when the maximum number of rounds is exceeded while generating messages.
183 |
184 |
185 | ```python
186 | def __init__(self, max_rounds: int, messages: list["Message"]):
187 | super().__init__(max_rounds)
188 | self.messages = messages
189 | """The messages which were being generated when the exception occurred."""
190 | ```
191 |
192 |
193 |
194 |
195 | ### messages
196 |
197 | ```python
198 | messages = messages
199 | ```
200 |
201 | The messages which were being generated when the exception occurred.
202 |
203 | MissingModelError
204 | -----------------
205 |
206 | ```python
207 | MissingModelError(content: str)
208 | ```
209 |
210 | Raised when a model is missing while parsing a message.
211 |
212 |
213 | ```python
214 | def __init__(self, content: str):
215 | super().__init__(content)
216 | ```
217 |
218 |
219 |
220 |
221 | PipelineWarning
222 | ---------------
223 |
224 | Base class for all pipeline warnings.
225 |
226 | This is used to indicate that something unexpected happened during the pipeline execution,
227 | but it is not critical enough to stop the execution.
228 |
229 | ProcessingError
230 | ---------------
231 |
232 | ```python
233 | ProcessingError(content: str)
234 | ```
235 |
236 | Raised when an error occurs during internal generator processing.
237 |
238 |
239 | ```python
240 | def __init__(self, content: str):
241 | super().__init__(content)
242 | ```
243 |
244 |
245 |
246 |
247 | Stop
248 | ----
249 |
250 | ```python
251 | Stop(message: str)
252 | ```
253 |
254 | Raise inside a pipeline to indicate a stopping condition.
255 |
256 | Example
257 |
258 | ```python
259 | import rigging as rg
260 |
261 | async def read_file(path: str) -> str:
262 | "Read the contents of a file."
263 |
264 | if no_more_files(path):
265 | raise rg.Stop("There are no more files to read.")
266 |
267 | ...
268 |
269 | chat = await pipeline.using(read_file).run()
270 | ```
271 |
272 |
273 |
274 | ```python
275 | def __init__(self, message: str):
276 | super().__init__(message)
277 | self.message = message
278 | """The message associated with the stop."""
279 | ```
280 |
281 |
282 |
283 |
284 | ### message
285 |
286 | ```python
287 | message = message
288 | ```
289 |
290 | The message associated with the stop.
291 |
292 | TokenizerWarning
293 | ----------------
294 |
295 | Base class for all tokenization warnings.
296 |
297 | This is used to indicate that something unexpected happened during the tokenization process,
298 | but it is not critical enough to stop the execution.
299 |
300 | ToolDefinitionError
301 | -------------------
302 |
303 | ```python
304 | ToolDefinitionError(message: str)
305 | ```
306 |
307 | Raised when a tool cannot be properly defined.
308 |
309 |
310 | ```python
311 | def __init__(self, message: str):
312 | super().__init__(message)
313 | ```
314 |
315 |
316 |
317 |
318 | ToolWarning
319 | -----------
320 |
321 | Base class for all tool warnings.
322 |
323 | This is used to indicate that something unexpected happened during the tool execution,
324 | but it is not critical enough to stop the execution.
325 |
326 | UnknownToolError
327 | ----------------
328 |
329 | ```python
330 | UnknownToolError(tool_name: str)
331 | ```
332 |
333 | Raised when an API tool call is made for an unknown tool.
334 |
335 |
336 | ```python
337 | def __init__(self, tool_name: str):
338 | super().__init__(f"Unknown tool call was requested for '{tool_name}'")
339 | self.tool_name = tool_name
340 | """The name of the tool which was unknown."""
341 | ```
342 |
343 |
344 |
345 |
346 | ### tool\_name
347 |
348 | ```python
349 | tool_name = tool_name
350 | ```
351 |
352 | The name of the tool which was unknown.
353 |
354 | raise\_as
355 | ---------
356 |
357 | ```python
358 | raise_as(
359 | error_type: type[Exception], message: str
360 | ) -> t.Callable[[t.Callable[P, R]], t.Callable[P, R]]
361 | ```
362 |
363 | When the wrapped function raises an exception, `raise ... from` with the new error type.
364 |
365 |
366 | ```python
367 | def raise_as(
368 | error_type: type[Exception],
369 | message: str,
370 | ) -> t.Callable[[t.Callable[P, R]], t.Callable[P, R]]:
371 | "When the wrapped function raises an exception, `raise ... from` with the new error type."
372 |
373 | def _raise_as(func: t.Callable[P, R]) -> t.Callable[P, R]:
374 | @functools.wraps(func)
375 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
376 | try:
377 | return func(*args, **kwargs)
378 | except Exception as e:
379 | error = error_type(message)
380 | raise error from e
381 |
382 | if wrapper.__doc__ is None:
383 | wrapper.__doc__ = ""
384 |
385 | wrapper.__doc__ += f"\n\nRaises:\n {error_type.__name__}{': ' + message}"
386 |
387 | return wrapper
388 |
389 | return _raise_as
390 | ```
391 |
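A minimal usage sketch (illustrative; `ConfigError` and `load_config` are hypothetical and not part of the library):

```python
import json

from rigging.error import raise_as

class ConfigError(Exception):
    pass

# Any exception raised inside load_config is re-raised as ConfigError,
# chained with `raise ... from` so the original cause is preserved.
@raise_as(ConfigError, "Failed to load the configuration")
def load_config(path: str) -> dict:
    with open(path) as f:
        return json.load(f)
```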
392 |
393 |
--------------------------------------------------------------------------------