├── colabs
└── images
│ ├── paligemma_fox.png
│ ├── paligemma_puffin.png
│ ├── README.md
│ └── image_provenance.py
├── onetwo
├── backends
│ ├── testdata
│ │ └── bird.jpg
│ ├── onetwo_api_test_script.sh
│ ├── run_model_server.py
│ ├── formatters_test.py
│ ├── formatters.py
│ ├── backends_base.py
│ ├── onetwo_api_manual_test.py
│ ├── openai_mock.py
│ ├── model_server.py
│ └── onetwo_api.py
├── version.py
├── __init__.py
├── core
│ ├── constants.py
│ ├── updating_test.py
│ ├── core_test_utils.py
│ ├── sampling_test.py
│ ├── updating.py
│ ├── routing.py
│ ├── executing_with_context_test.py
│ ├── executing_impl_test.py
│ └── executing_impl.py
├── ot.py
├── builtins
│ ├── callbacks_test.py
│ ├── tool_use_test.py
│ ├── tool_use.py
│ ├── composables.py
│ ├── prompt_templating.py
│ └── llm_utils.py
├── stdlib
│ ├── ensembling
│ │ ├── distribution_metrics.py
│ │ └── distribution_metrics_test.py
│ └── code_execution
│ │ ├── python_execution_test.py
│ │ ├── python_execution_test_utils.py
│ │ └── python_execution_utils.py
└── agents
│ ├── agents_test_utils_test.py
│ ├── agents_test_utils.py
│ ├── critics_test.py
│ ├── tasks
│ ├── game_of_24_test.py
│ └── game_of_24.py
│ ├── distribution_test.py
│ └── critics.py
├── .gitignore
├── CONTRIBUTING.md
├── pyproject.toml
├── docs
├── faq.md
└── basics.md
└── README.md
/colabs/images/paligemma_fox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/onetwo/HEAD/colabs/images/paligemma_fox.png
--------------------------------------------------------------------------------
/colabs/images/paligemma_puffin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/onetwo/HEAD/colabs/images/paligemma_puffin.png
--------------------------------------------------------------------------------
/onetwo/backends/testdata/bird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/onetwo/HEAD/onetwo/backends/testdata/bird.jpg
--------------------------------------------------------------------------------
/colabs/images/README.md:
--------------------------------------------------------------------------------
1 | # Image Assets
2 |
3 | This directory contains various image assets used in the Colab tutorial.
4 |
5 | The complete source and licensing information for every image file is documented
6 | programmatically in `image_provenance.py` within this directory.
7 |
--------------------------------------------------------------------------------
/onetwo/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Current OneTwo version at head on GitHub."""
16 |
17 | # A new GitHub release will be pushed every time `__version__` is increased.
18 | # When changing this, also update the CHANGELOG.md.
19 | __version__ = '0.3.0'
20 |
--------------------------------------------------------------------------------
/onetwo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """OneTwo API."""
16 |
17 | # Do NOT add anything here !!
18 | # Indeed, top-level `__init__.py` makes it hard to import a specific sub-module
19 | # without triggering a full import of the codebase.
20 | # Instead, the public API is exposed in `ot.py`.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 | *.sw[op]
3 |
4 | # C extensions
5 | *.so
6 |
7 | # Packages
8 | *.egg
9 | *.egg-info
10 | dist
11 | build
12 | eggs
13 | .eggs
14 | parts
15 | bin
16 | var
17 | sdist
18 | develop-eggs
19 | .installed.cfg
20 | lib
21 | lib64
22 | __pycache__
23 |
24 | # Installer logs
25 | pip-log.txt
26 |
27 | # Unit test / coverage reports
28 | .coverage
29 | .nox
30 | .cache
31 | .pytest_cache
32 |
33 |
34 | # Mac
35 | .DS_Store
36 |
37 | # JetBrains
38 | .idea
39 |
40 | # VS Code
41 | .vscode
42 |
43 | # emacs
44 | *~
45 |
46 | # Built documentation
47 | docs/_build
48 | bigquery/docs/generated
49 | docs.metadata
50 |
51 | # Virtual environment
52 | env/
53 | venv/
54 |
55 | # Test logs
56 | coverage.xml
57 | *sponge_log.xml
58 |
59 | # System test environment variables.
60 | system_tests/local_test_setup
61 |
62 | # Make sure a generated file isn't accidentally committed.
63 | pylintrc
64 | pylintrc.test
65 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | At this time we do not plan to accept non-trivial contributions.
4 | We encourage forking the repository and continued development (as permitted by
5 | the license).
6 |
7 | ## Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a Contributor License
10 | Agreement. You (or your employer) retain the copyright to your contribution,
11 | this simply gives us permission to use and redistribute your contributions as
12 | part of the project. Head over to <https://cla.developers.google.com/> to see
13 | your current agreements on file or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows [Google's Open Source Community
29 | Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/onetwo/core/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Some constants used throughout the onetwo codebase."""
16 |
17 | # Used when errors occur.
18 | ERROR_STRING = '#ERROR#'
19 |
20 | # Field in the caching key where the name of the cached function is stored.
21 | CACHING_FUNCTION_NAME_KEY = '_destination'
22 |
23 | # Prompt prefix.
24 | PROMPT_PREFIX = 'prefix'
25 |
26 | # The field of the Jinja context that contains variables.
27 | CONTEXT_VARS = '__vars__'
28 |
29 | # The context variable containing the execution result for a Jinja template.
30 | RESULT_VAR = '_result'
31 |
32 | # The fields in the Jinja context variables with results of `choose` command.
33 | CHOICES_VAR = 'choices'
34 | SCORES_VAR = 'scores'
35 |
--------------------------------------------------------------------------------
/onetwo/backends/onetwo_api_test_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # NOTE: the shebang must be the very first line of the file for the kernel to
17 | # honor it; previously it appeared after the license header, where it is only
18 | # a comment and the script would run under the invoking (default) shell.
19 |
20 | # Run from the onetwo root dir that contains README.md and other files.
21 | ONETWO_ROOT_DIR="."
22 | START_SERVER_CMD="python3 ${ONETWO_ROOT_DIR}/onetwo/backends/run_model_server.py"
23 | START_CLIENT_CMD="python3 ${ONETWO_ROOT_DIR}/onetwo/backends/onetwo_api_manual_test.py"
24 | CACHE_DIR="${ONETWO_ROOT_DIR}/tmp"
25 | # Requires portpicker to be installed.
26 | PORT=`python3 -m portpicker $$`
27 |
28 | # Kill the model server and exit with an error code.
29 | function clean_fail() {
30 |   kill "${SERVER_PID}";
31 |   exit 1
32 | }
33 |
34 | echo "Using port ${PORT}"
35 |
36 | set -o xtrace
37 |
38 | # Start model_server and capture its PID immediately: `$!` is overwritten as
39 | # soon as another job is backgrounded (the first client run below), so the
40 | # cleanup code must not rely on `$!` still pointing at the server.
41 | ${START_SERVER_CMD} --port="${PORT}" &
42 | SERVER_PID=$!
43 | sleep 10
44 |
45 | # Start client and run test: from scratch.
46 | ${START_CLIENT_CMD} \
47 |   --endpoint="http://localhost:${PORT}" \
48 |   --cache_dir=${CACHE_DIR} & sleep 3 || clean_fail
49 |
50 | # Start client and run test: cached replies.
51 | ${START_CLIENT_CMD} \
52 |   --endpoint="http://localhost:${PORT}" --cache_dir=${CACHE_DIR} \
53 |   --load_cache_file=True || clean_fail
54 |
55 | # Stop model_server.
56 | kill "${SERVER_PID}"
--------------------------------------------------------------------------------
/onetwo/ot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Entry point to the OneTwo library.
16 |
17 | ```python
18 | from onetwo import ot
19 | ```
20 | """
21 |
22 | from onetwo import version
23 | from onetwo.core import composing
24 | from onetwo.core import executing
25 | from onetwo.core import results
26 | from onetwo.core import routing
27 | from onetwo.core import sampling
28 | from onetwo.evaluation import evaluation
29 |
30 | __version__: str = version.__version__
31 |
32 | compare_with_critic = evaluation.compare_with_critic
33 | copy_registry = routing.copy_registry
34 | evaluate = evaluation.evaluate
35 | Executable = executing.Executable
36 | function_call = routing.function_registry  # NOTE(review): bound to `function_registry` — confirm this alias is intentional and not meant to be `routing.function_call`.
37 | function_registry = routing.function_registry
38 | HTMLRenderer = results.HTMLRenderer
39 | make_composable = composing.make_composable
40 | make_executable = executing.make_executable
41 | naive_comparison_critic = evaluation.naive_comparison_critic
42 | naive_evaluation_critic = evaluation.naive_evaluation_critic
43 | naive_fuzzy_evaluation_critic = evaluation.naive_fuzzy_evaluation_critic
44 | par_iter = executing.par_iter
45 | parallel = executing.parallel
46 | RegistryContext = routing.RegistryContext
47 | repeat = sampling.repeat
48 | run = executing.run
49 | safe_stream = executing.safe_stream
50 | set_registry = routing.set_registry
51 | stream_updates = executing.stream_updates
52 | stream_with_callback = executing.stream_with_callback
53 | with_current_registry = routing.with_current_registry
54 | with_registry = routing.with_registry
55 |
56 |
--------------------------------------------------------------------------------
/colabs/images/image_provenance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Stores provenance information for images used in OneTwo Colab.
16 |
17 | This module contains metadata about images, including their external URLs and
18 | licensing information, primarily used for image provenance tracking.
19 | """
20 |
21 | import dataclasses
22 | from typing import Dict
23 |
24 |
25 | @dataclasses.dataclass
26 | class ImageSource:
27 | """Stores provenance and licensing information for an image asset."""
28 | # The external URL where the image was originally found.
29 | external_url: str = ''
30 | # The local path where the image is stored in the repository.
31 | local_path: str = ''
32 | # The license under which the image is distributed.
33 | license: str = ''
34 |
35 |
36 | # A constant mapping image names (keys) to their ImageSource metadata (values).
37 | image_sources: Dict[str, ImageSource] = {
38 | 'paligemma_fox': ImageSource(
39 | external_url='https://big-vision-paligemma.hf.space/file=/tmp/gradio/4aa2d3fd01a6308961397f68e043b2015bc91493/image.png',
40 | local_path='paligemma_fox.png',
41 | license=(
42 | 'CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)'
43 | ),
44 | ),
45 | 'paligemma_puffin': ImageSource(
46 | external_url='https://big-vision-paligemma.hf.space/file=/tmp/gradio/78f93b49088f8d72ee546d656387403d647b413f/image.png',
47 | local_path='paligemma_puffin.png',
48 | license=(
49 | 'CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)'
50 | ),
51 | ),
52 | }
53 |
--------------------------------------------------------------------------------
/onetwo/backends/run_model_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Binary that starts a OneTwo Model Server (serves as an example/test)."""
16 |
17 | import argparse
18 | import importlib
19 | import json
20 | from typing import Final
21 |
22 | import uvicorn
23 |
24 |
25 | _DEFAULT_PORT = 8000
26 | _DEFAULT_BACKEND_MODULE: Final[str] = 'onetwo.backends.test_utils'
27 | _DEFAULT_BACKEND_CLASS: Final[str] = 'LLMForTest'
28 | _DEFAULT_BACKEND_ARGS: Final[str] = '{"default_reply": "Test reply"}'
29 |
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser('OneTwo Model Server.')
33 | parser.add_argument(
34 | '--port', type=int, default=_DEFAULT_PORT, help='Port to listen on.'
35 | )
36 | parser.add_argument(
37 | '--backend_module',
38 | type=str,
39 | default=_DEFAULT_BACKEND_MODULE,
40 | help='Backend module to load.',
41 | )
42 | parser.add_argument(
43 | '--backend_class',
44 | type=str,
45 | default=_DEFAULT_BACKEND_CLASS,
46 | help='Backend class to instantiate.',
47 | )
48 | parser.add_argument(
49 | '--backend_args',
50 | type=str,
51 | default=_DEFAULT_BACKEND_ARGS,
52 | help='Arguments for the backend class constructor (in JSON format).',
53 | )
54 | args = parser.parse_args()
55 | backend_module = importlib.import_module(args.backend_module)
56 | backend_class = getattr(backend_module, args.backend_class)
57 | backend_args = json.loads(args.backend_args)
58 | backend = backend_class(**backend_args)
59 | backend.register()
60 |
61 | uvicorn.run(
62 | 'onetwo.backends.model_server:ModelServer',
63 | host='0.0.0.0',
64 | port=args.port,
65 | factory=True,
66 | )
67 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "onetwo"
3 | description = "LLM Orchestration Library by Google DeepMind."
4 | readme = "README.md"
5 | requires-python = ">=3.10"
6 | license = {file = "LICENSE"}
7 | authors = [{name = "OneTwo Authors", email="no-reply@google.com"}]
8 | classifiers = [
9 | "Programming Language :: Python :: 3",
10 | "Programming Language :: Python :: 3 :: Only",
11 | "License :: OSI Approved :: Apache Software License",
12 | "Intended Audience :: Science/Research",
13 | ]
14 | keywords = []
15 |
16 | # pip dependencies of the project.
17 | dependencies = [
18 | "absl-py",
19 | "aenum",
20 | "dataclasses-json",
21 | "fastapi",
22 | "freezegun",
23 | # Note: there is a PYPI `gemma` library, which is not what we want.
24 | "gemma@git+https://github.com/google-deepmind/gemma.git",
25 | "google-cloud-aiplatform",
26 | "google-generativeai",
27 | "html5lib",
28 | "immutabledict",
29 | "jinja2",
30 | "numpy",
31 | "openai",
32 | "pillow",
33 | "pytest",
34 | "portpicker", # For `backends/onetwo_api_test_script.sh`.
35 | "pyyaml", # For `import yaml`.
36 | "termcolor",
37 | "tqdm",
38 | "typing_extensions",
39 | "uvicorn",
40 | ]
41 |
42 | # This is set automatically by flit using `onetwo.__version__`.
43 | dynamic = ["version"]
44 |
45 | [project.urls]
46 | homepage = "https://github.com/google-deepmind/onetwo"
47 | repository = "https://github.com/google-deepmind/onetwo"
48 |
49 | [project.optional-dependencies]
50 | # Installed through `pip install '.[dev]'`.
51 | dev = [
52 | "pytest-xdist",
53 | "pylint>=2.6.0",
54 | "pyink",
55 | ]
56 |
57 | # Installed through `pip install '.[docs]'`.
58 | docs = [
59 | # Install `apitree` with all extensions (sphinx, theme,...)
60 | "sphinx-apitree[ext]",
61 | ]
62 |
63 | [tool.pytest.ini_options]
64 | addopts = [
65 | "--import-mode=importlib",
66 | ]
67 |
68 | [tool.pyink]
69 | # Formatting configuration to follow Google style-guide.
70 | line-length = 80
71 | preview = true
72 | pyink-indentation = 2
73 | pyink-use-majority-quotes = true
74 |
75 | [build-system]
76 | # Build system specify which backend is used to build/install the project (flit,
77 | # poetry, setuptools, ...). All backends are supported by `pip install`.
78 | requires = ["flit_core >=3.8,<4"]
79 | build-backend = "flit_core.buildapi"
80 |
81 | [tool.flit.sdist]
82 | # Flit specific options (files to exclude from the PyPI package).
83 | # If using another build backend (setuptools, poetry), you can remove this
84 | # section.
85 | exclude = [
86 | # Do not release tests files on PyPI.
87 | "**/*_test.py",
88 | ]
89 |
--------------------------------------------------------------------------------
/onetwo/backends/formatters_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for formatting."""
16 | import asyncio
17 | from collections.abc import Sequence
18 | from typing import TypeAlias
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | # Necessary for the FormatterName enum to be populated.
23 | from onetwo.backends import formatters # pylint: disable=unused-import
24 | from onetwo.builtins import formatting
25 | from onetwo.core import content as content_lib
26 |
27 |
28 | _Message: TypeAlias = content_lib.Message
29 | _PredefinedRole: TypeAlias = content_lib.PredefinedRole
30 |
31 |
32 | class FormattersTest(parameterized.TestCase):
33 |
34 | @parameterized.named_parameters(
35 | (
36 | 'gemma_user_only',
37 | formatting.FormatterName.GEMMA,
38 | [_Message(role=_PredefinedRole.USER, content='Hello')],
39 | 'user\nHello',
40 | ),
41 | (
42 | 'gemma_user_and_model',
43 | formatting.FormatterName.GEMMA,
44 | [
45 | _Message(role=_PredefinedRole.USER, content='Hello'),
46 | _Message(role=_PredefinedRole.MODEL, content='What'),
47 | ],
48 | 'user\nHello\nmodel\nWhat',
49 | ),
50 | (
51 | 'gemma_user_and_empty_model',
52 | formatting.FormatterName.GEMMA,
53 | [
54 | _Message(role=_PredefinedRole.USER, content='Hello'),
55 | _Message(role=_PredefinedRole.MODEL, content=''),
56 | ],
57 | 'user\nHello\nmodel',
58 | ),
59 | )
60 | def test_format(
61 | self,
62 | formatter: formatting.FormatterName,
63 | messages: Sequence[_Message],
64 | expected: str,
65 | ):
66 | async def wrapper():
67 | formatter_class = formatting.FORMATTER_CLASS_BY_NAME[formatter]
68 | formatter_instance = formatter_class() # pytype: disable=not-instantiable
69 | result = formatter_instance.format(messages)
70 | return result
71 |
72 | result = asyncio.run(wrapper())
73 | self.assertEqual(str(result), expected)
74 |
75 |
76 | if __name__ == '__main__':
77 | absltest.main()
78 |
--------------------------------------------------------------------------------
/onetwo/builtins/callbacks_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Sequence
16 | from typing import TypeVar
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from onetwo.builtins import callbacks
21 | from onetwo.builtins import llm
22 | from onetwo.core import content as content_lib
23 | from onetwo.core import executing
24 | from onetwo.core import templating
25 |
26 |
27 | _T = TypeVar('_T')
28 |
29 |
30 | Chunk = content_lib.Chunk
31 | ChunkList = content_lib.ChunkList
32 |
33 |
34 | # TODO: For now all the tests are in prompt_templating_test.py.
35 | # We should move them here.
36 | class CallbacksTest(parameterized.TestCase):
37 |
38 | def setUp(self):
39 | super().setUp()
40 |
41 | # This class tests various `llm` builtins. In case `import llm` is not
42 | # executed (this may happen when running `pytest` with multiple tests that
43 | # import `llm` module) various builtins from `llm` may be already configured
44 | # elsewhere in unexpected ways. We manually reset all the default builtin
45 | # implementations to make sure they are set properly.
46 | llm.reset_defaults()
47 |
48 | def generate(
49 | prompt: str | ChunkList,
50 | *,
51 | temperature: float | None = None,
52 | max_tokens: int | None = None,
53 | stop: Sequence[str] | None = None,
54 | top_k: int | None = None,
55 | top_p: float | None = None,
56 | ) -> str:
57 | del prompt, temperature, max_tokens, stop, top_k, top_p
58 | return ' done'
59 |
60 | def score(
61 | prompt: str | ChunkList, targets: Sequence[str]
62 | ) -> Sequence[float]:
63 | del prompt
64 | # We score by the length of the target.
65 | return [float(len(target)) for target in targets]
66 |
67 | llm.generate_text.configure(generate)
68 | llm.score_text.configure(score)
69 |
70 | def test_generate_text(self):
71 | tpl = templating.JinjaTemplate(text='{{ generate_text() }}')
72 | tpl.register_callback(
73 | 'generate_text', callbacks.generate_text, pass_context=True
74 | )
75 | res = executing.run(tpl.render())
76 | self.assertEqual(res['prefix'], ' done')
77 |
78 |
79 | if __name__ == '__main__':
80 | absltest.main()
81 |
--------------------------------------------------------------------------------
/onetwo/stdlib/ensembling/distribution_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Metric functions taking a distribution of model predictions as input.
16 |
17 | These can be used, for example, for evaluating the outputs of a SelfConsistency
18 | strategy.
19 | """
20 |
21 | from collections.abc import Sequence
22 | import dataclasses
23 | from typing import Generic, TypeVar
24 | from onetwo.core import executing
25 | from onetwo.core import tracing
26 | from onetwo.core import utils
27 | from onetwo.evaluation import agent_evaluation
28 |
29 | # Type representing a strategy's output.
30 | _O = TypeVar('_O')
31 |
32 |
33 | @dataclasses.dataclass
34 | class AccuracyAtK(Generic[_O]):
35 | """Returns the maximum accuracy among the first k predictions.
36 |
37 | Attributes:
38 | k: Number of candidate answers to consider for accuracy. If None, then will
39 | consider all predictions (even those with probability 0). If non-None,
40 | then will only consider the first k.
41 | base_metric: Metric to use for calculating the accuracy of an individual
42 | prediction vs. the target. Defaults to exact-match accuracy. Bigger is
43 | assumed to mean better.
44 | """
45 |
46 | k: int | None = 1
47 | base_metric: agent_evaluation.MetricFunction = lambda t, p: 1.0 * (t == p)
48 |
49 | @executing.make_executable(copy_self=False)
50 | @tracing.trace(name=utils.FROM_INSTANCE_CLASS_NAME)
51 | async def __call__(
52 | self, target: _O, prediction: Sequence[tuple[_O, float]]
53 | ) -> float:
54 | """Returns accuracy at k (value from 0.0 to 1.0).
55 |
56 | Args:
57 | target: The target (single answer).
58 | prediction: The predicted distribution over possible answers.
59 | """
60 | k = self.k if self.k is not None else len(prediction)
61 | if k < 0:
62 | raise ValueError(f'k must be non-negative, got {k}')
63 |
64 | # For simplicity, we are currently calling `self.base_metric` sequentially
65 | # for each of the predictions. In the case where `self.base_metric` is an
66 | # `async` function, though, this could potentially be optimized in the
67 | # future by wrapping the calls to the base metric with `executing.parallel`.
68 | base_metric_values = [
69 | await utils.call_and_maybe_await(self.base_metric, target, x)
70 | for x, _ in prediction[:k]
71 | ]
72 | return max(base_metric_values, default=0.0)
73 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | # OneTwo FAQ
2 |
3 | ## Is there a way to see actual formatted prompts that are sent to the backends?
4 |
5 | While some of the builtin functions (e.g., `llm.generate_text`) apply little to
6 | no modification to the prompt before it is sent to the model, others (e.g.,
7 | `llm.instruct` or `llm.chat`) may apply a lot of formatting. For example,
8 | one natural way of implementing `llm.instruct` for a model that is only
9 | pre-trained (PT) but not instruction-tuned (IT) is to first format the
10 | user-provided task, e.g. `'Write me a short poem'`, into a longer prompt similar
11 | to `'Here is the task: Write me a short poem. Here is the answer:'` and then
12 | send it to pure completion with `llm.generate_text`. Indeed, this is precisely
13 | how the default implementation of `llm.instruct`
14 | (`onetwo.builtins.default_instruct`) works.
15 |
16 | In cases like this, where a non-trivial formatting of prompts takes place, the
17 | user may naturally want to see the actual fully formatted prompt that is sent to
18 | a model. A simple way to do it, which often comes in handy when debugging, is to
19 | configure (mock) `llm.generate_text` with a fake implementation that simply
20 | returns the prompt (for convenience we provide such an implementation in
21 | `onetwo.builtins.echo_generate_text`):
22 |
23 | ```python
24 | import ot
25 | from onetwo.builtins import llm
26 | backend = ...
27 | # Assume this backend only configures `llm.generate_text`, i.e. `llm.instruct`
28 | # is configured to use the default OneTwo implementation.
29 | backend.register()
30 | print(ot.run(llm.generate_text('Once upon a')))
31 | print(ot.run(llm.instruct('Name three cities in France.')))
32 |
33 | def fake_generate_text(prompt: str | content_lib.ChunkList, **kwargs):
34 | return prompt
35 |
36 | # Alternatively, use `onetwo.builtins.echo_generate_text`.
37 | llm.generate_text.configure(fake_generate_text)
38 | # Now `llm.generate_text` simply returns its input.
39 |
40 | assert ot.run(llm.generate_text('Once upon a')) == 'Once upon a'
41 | # `llm.instruct` formats the prompt and sends it to `llm.generate_text`,
42 | # which returns the formatted prompt.
43 | print(ot.run(llm.instruct('Name three cities in France.')))
44 | # We should get something like 'Task: Name three cities in France.\n Answer:'.
45 |
46 | backend.register()
47 | # Now `llm.generate_text` points again to the backend implementation.
48 | ```
49 |
50 | This approach assumes that you know exactly where the first call to the
51 | *external model API* happens (e.g., `llm.generate_text` in the example above) so
52 | that you can mock it.
53 |
54 | In the future we plan to introduce a more principled and unified way of doing
55 | this.
56 | Likely we will base it on the `onetwo.core.tracing.trace` decorator that is
57 | already available in OneTwo (refer to the "Agents and Tool Use" section of our
58 | [Colab](https://colab.research.google.com/github/google-deepmind/onetwo/blob/main/colabs/tutorial.ipynb)
59 | for more details on tracing with OneTwo).
60 |
--------------------------------------------------------------------------------
/onetwo/stdlib/code_execution/python_execution_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 | from onetwo.stdlib.code_execution import python_execution
18 |
19 | _ExecutionStatus = python_execution.ExecutionStatus
20 | _SandboxResult = python_execution.SandboxResult
21 |
22 |
class SandboxResultTest(parameterized.TestCase):
  """Verifies `str(SandboxResult)` for every combination of outputs."""

  @parameterized.named_parameters(
      (
          'success_no_output',
          python_execution.SandboxResult(),
          'None',
      ),
      (
          'success_stdout_only',
          python_execution.SandboxResult(stdout='Hi'),
          'Hi',
      ),
      (
          'success_final_expression_value_only',
          python_execution.SandboxResult(final_expression_value=2),
          '2',
      ),
      (
          'success_stdout_and_final_expression_value',
          python_execution.SandboxResult(final_expression_value=2, stdout='Hi'),
          'Hi\n2',
      ),
      (
          'error_no_output',
          python_execution.SandboxResult(
              execution_status=_ExecutionStatus.EXECUTION_ERROR,
              status_message='Error message.',
          ),
          'EXECUTION_ERROR: Error message.',
      ),
      (
          'error_stdout_only',
          python_execution.SandboxResult(
              stdout='Hi',
              execution_status=_ExecutionStatus.EXECUTION_ERROR,
              status_message='Error message.',
          ),
          'Hi\nEXECUTION_ERROR: Error message.',
      ),
      (
          'error_final_expression_value_only',
          python_execution.SandboxResult(
              final_expression_value=2,
              execution_status=_ExecutionStatus.EXECUTION_ERROR,
              status_message='Error message.',
          ),
          '2\nEXECUTION_ERROR: Error message.',
      ),
      (
          'error_stdout_and_final_expression_value',
          python_execution.SandboxResult(
              final_expression_value=2,
              stdout='Hi',
              execution_status=_ExecutionStatus.EXECUTION_ERROR,
              status_message='Error message.',
          ),
          'Hi\n2\nEXECUTION_ERROR: Error message.',
      ),
  )
  def test_to_string(self, sandbox_result, expected):
    # The rendering concatenates stdout, the final expression value, and any
    # error status line, separated by newlines (as exercised by the cases
    # above).
    self.assertEqual(str(sandbox_result), expected)


if __name__ == '__main__':
  absltest.main()
89 |
--------------------------------------------------------------------------------
/onetwo/builtins/tool_use_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from onetwo.builtins import tool_use
20 | from onetwo.core import executing
21 | from onetwo.core import routing
22 |
23 |
24 | def _add_sync(arg1: Any, arg2: Any) -> Any:
25 | return arg1 + arg2
26 |
27 |
28 | async def _add_async(arg1: Any, arg2: Any) -> Any:
29 | return arg1 + arg2
30 |
31 |
@executing.make_executable  # pytype: disable=wrong-arg-types
def _add_executable(arg1: Any, arg2: Any) -> Any:
  """Executable-wrapped tool function used as a test fixture."""
  combined = arg1 + arg2
  return combined
35 |
36 |
class ToolUseTest(parameterized.TestCase):
  """Tests for the default implementation of the `run_tool` builtin."""

  def setUp(self):
    super().setUp()

    # This class tests routing.function_registry. In case `import routing` is
    # not executed (this may happen when running `pytest` with multiple tests
    # that import `llm` module) the `function_registry` may be already filled
    # with various functions elsewhere in unexpected ways. We manually remove
    # all the keys to make sure it is empty.
    routing.function_registry.clear()
    # Unfortunately, we also removed all the builtins configured when importing
    # tool_use. Let's re-set them.
    # TODO: Replace this clear-and-reset dance with a proper fixture (e.g. a
    # registry snapshot/restore helper or such) for better control of
    # reproducibility.
    tool_use.reset_defaults()

  @parameterized.named_parameters(
      ('sync', _add_sync),
      ('async', _add_async),
      ('executable', _add_executable),
  )
  def test_default_run_tool_function_types(self, add_function):
    """Checks that sync, async, and executable tools can all be invoked."""
    routing.function_registry['add'] = add_function
    result = executing.run(
        tool_use.run_tool(  # pytype: disable=wrong-keyword-args
            tool_name='add', tool_args=['a', 'b'], tool_kwargs={}
        )
    )
    self.assertEqual('ab', result)

  @parameterized.named_parameters(
      ('positional_args_numeric', [1, 2], {}, 3),
      ('keyword_args_numeric', [], {'arg1': 1, 'arg2': 2}, 3),
      ('positional_args_string', ['1', '2'], {}, '12'),
  )
  def test_default_run_tool_args_types(self, args, kwargs, expected_result):
    """Checks positional and keyword argument passing to the tool."""
    def add(arg1: Any, arg2: Any) -> Any:
      return arg1 + arg2

    routing.function_registry['add'] = add
    result = executing.run(
        tool_use.run_tool(  # pytype: disable=wrong-keyword-args
            tool_name='add', tool_args=args, tool_kwargs=kwargs
        )
    )
    self.assertEqual(expected_result, result)

if __name__ == '__main__':
  absltest.main()
87 |
--------------------------------------------------------------------------------
/onetwo/stdlib/ensembling/distribution_metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 | from onetwo.core import executing
18 | from onetwo.stdlib.ensembling import distribution_metrics
19 |
20 |
# Simple custom base metric that gives full credit if the correct answer is
# included anywhere in the prediction (e.g., if target is `12` and the
# prediction is 'The answer is 12.')
def answer_included_sync(target: str, prediction: str):
  """Returns 1.0 when `target` is a substring of `prediction`, else 0.0."""
  return float(target in prediction)
26 |
27 |
async def answer_included_async(target: str, prediction: str):
  """Async variant of the inclusion metric; identical scoring rule."""
  return 1.0 if target in prediction else 0.0
30 |
31 |
@executing.make_executable  # pytype: disable=wrong-arg-types
def answer_included_executable(target: str, prediction: str):
  """Executable-wrapped variant of the inclusion metric; same scoring rule."""
  return 1.0 if target in prediction else 0.0
35 |
36 |
class DistributionMetricsTest(parameterized.TestCase):
  """Tests for `distribution_metrics.AccuracyAtK`."""

  @parameterized.named_parameters(
      ('correct_at_1', 'a', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 1, 1.0),
      ('wrong_at_1', 'b', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 1, 0.0),
      ('correct_at_2', 'b', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 2, 1.0),
      ('wrong_at_2', 'c', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 2, 0.0),
      ('case_sensitive', 'A', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 1, 0.0),
      ('empty_distribution', 'a', [], 1, 0.0),
      ('number_correct', 1.0, [(1, 0.5), (2, 0.3), (3, 0.2)], 1, 1.0),
      ('number_wrong', 2.0, [(1, 0.5), (2, 0.3), (3, 0.2)], 1, 0.0),
  )
  def test_accuracy_at_k_default_base_metric(
      self, target, predicted_distribution, k, expected
  ):
    """Checks accuracy@k with the default base metric over (value, prob) pairs."""
    metric = distribution_metrics.AccuracyAtK(k=k)
    actual = executing.run(
        metric(target=target, prediction=predicted_distribution)  # pytype: disable=wrong-keyword-args
    )
    self.assertEqual(expected, actual)

  @parameterized.named_parameters(
      ('sync', answer_included_sync),
      ('async', answer_included_async),
      ('executable', answer_included_executable),
  )
  def test_accuracy_at_k_custom_base_metric(self, base_metric):
    """Checks that sync, async, and executable base metrics all work."""
    metric = distribution_metrics.AccuracyAtK(
        k=2, base_metric=base_metric
    )

    # Predictions prefixed with '=' so only the substring-based custom metric
    # (not exact match) can score them.
    actual = executing.run(
        metric(target='b', prediction=[('=a', 0.5), ('=b', 0.3), ('=c', 0.2)])  # pytype: disable=wrong-keyword-args
    )
    with self.subTest('correct_at_2'):
      self.assertEqual(1.0, actual, actual)

    actual = executing.run(
        metric(target='c', prediction=[('=a', 0.5), ('=b', 0.3), ('=c', 0.2)])  # pytype: disable=wrong-keyword-args
    )
    with self.subTest('wrong_at_2'):
      self.assertEqual(0.0, actual, actual)


if __name__ == '__main__':
  absltest.main()
83 |
--------------------------------------------------------------------------------
/onetwo/backends/formatters.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implementation of various formatters for different models."""
16 |
17 | from collections.abc import Sequence
18 | from typing import Final, TypeAlias
19 |
20 | import aenum
21 | from onetwo.builtins import formatting
22 | from onetwo.core import content as content_lib
23 |
24 |
# Short local aliases for types from `content` and `formatting` used
# throughout this module.
_PredefinedRole = content_lib.PredefinedRole
_Message: TypeAlias = content_lib.Message
_Chunk: TypeAlias = content_lib.Chunk
_ChunkList: TypeAlias = content_lib.ChunkList
_FormatterName = formatting.FormatterName
_Formatter = formatting.Formatter
31 |
32 |
33 | _START_OF_TURN: Final[str] = ''
34 | _END_OF_TURN: Final[str] = ''
35 |
36 |
class GemmaFormatter(_Formatter):
  """Gemma formatter for instruction tuned models."""

  @property
  def role_map(self) -> dict[str | _PredefinedRole, str]:
    """Returns a mapping from role to string representation.

    This serves two purposes:
    - It specifies which roles are supported by the formatter. Any role that
      is not in this map will be ignored or raise an error if
      `raise_error_if_unsupported_roles=True` in the call to `format`.
    - It allows to specify how the role name is represented in the prompt.
    """
    return {
        _PredefinedRole.USER: 'user',
        _PredefinedRole.MODEL: 'model',
    }

  def is_role_supported(self, role: str | _PredefinedRole) -> bool:
    """Overridden from base class (Formatter)."""
    return role in self.role_map

  def is_already_formatted(self, content: Sequence[_Message]) -> bool:
    """Overridden from base class (Formatter)."""
    if not content:
      return False
    # Any message already containing a turn delimiter means the content was
    # formatted previously.
    for message in content:
      if _START_OF_TURN in str(message.content):
        return True
    return False

  def extra_stop_sequences(self) -> list[str]:
    """Overridden from base class (Formatter)."""
    return [_START_OF_TURN]

  def _format(
      self,
      content: Sequence[_Message],
  ) -> _ChunkList:
    """Overridden from base class (Formatter)."""
    last_index = len(content) - 1
    turns = []
    for index, message in enumerate(content):
      role_name = self.role_map[message.role]
      is_final_model_turn = (
          index == last_index and message.role == _PredefinedRole.MODEL
      )
      if is_final_model_turn:
        # The final model turn is left open (no end-of-turn marker) so the
        # model can continue from it.
        if str(message.content):
          turns.append(f'{_START_OF_TURN}{role_name}\n{message.content}')
        else:
          turns.append(f'{_START_OF_TURN}{role_name}')
      else:
        turns.append(
            f'{_START_OF_TURN}{role_name}\n{message.content}{_END_OF_TURN}'
        )
    return _ChunkList([_Chunk('\n'.join(turns))])
91 |
92 |
# Register the Gemma formatter under the name 'gemma' so it can be looked up
# via `formatting.FormatterName.GEMMA`.
aenum.extend_enum(formatting.FormatterName, 'GEMMA', 'gemma')
formatting.FORMATTER_CLASS_BY_NAME[_FormatterName.GEMMA] = GemmaFormatter
95 |
--------------------------------------------------------------------------------
/onetwo/core/updating_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import copy
18 | from typing import TypeAlias
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from onetwo.core import updating
23 |
24 |
25 | Update = updating.Update
26 | ListUpdate = updating.ListUpdate
27 |
28 | _S: TypeAlias = updating.AddableUpdate[str]
29 |
30 |
class UpdatingTest(parameterized.TestCase):
  """Tests for `Update`, `ListUpdate`, and `AddableUpdate` accumulation."""

  def test_list_update_accumulate(self):
    u1 = Update('a')
    u2 = Update('b')
    l1 = ListUpdate([(u1, 0)])
    l2 = ListUpdate([(u2, 0)])
    l3 = copy.deepcopy(l2)

    with self.subTest('start_from_scratch'):
      self.assertEqual(Update() + l2, l3)

    with self.subTest('should_replace'):
      # Accumulating at an already-occupied index replaces the earlier value.
      self.assertEqual(l1 + l2, l3)

  def test_plus_and_sum(self):
    # Two identical generators of updates: one consumed with `+=`, the other
    # with `sum`; both accumulation styles must produce the same result.
    list_updates1 = (ListUpdate([(i, i)]) for i in range(3))
    list_updates2 = (ListUpdate([(i, i)]) for i in range(3))
    update = Update()
    for u in list_updates1:
      update += u
    final = sum(list_updates2, start=Update())
    self.assertEqual(update, final)
    self.assertEqual(final.to_result(), [0, 1, 2])

  @parameterized.named_parameters(
      ('empty', ListUpdate([]), []),
      ('singleton', ListUpdate([(Update('a'), 0)]), ['a']),
      (
          'non_contiguous',
          ListUpdate([
              (Update('a'), 0),
              (Update('b'), 2),
              (Update('c'), 4),
          ]),
          ['a', None, 'b', None, 'c'],
      ),
      (
          'nested',
          ListUpdate([
              (ListUpdate([(Update('a'), 1)]), 1),
          ]),
          [None, [None, 'a']],
      ),
      (
          'nested_another_example',
          ListUpdate([
              (ListUpdate([(10, 1)]), 0),
              (ListUpdate([(20, 0)]), 2),
          ]),
          [[None, 10], None, [20,]],
      ),
      (
          'nested_deeper',
          ListUpdate([
              (
                  ListUpdate([
                      (ListUpdate([('ab', 1)]), 1)
                  ]),
                  0
              ),
              (
                  ListUpdate([
                      (ListUpdate([('cd', 0)]), 1)
                  ]),
                  2
              ),
          ]),
          [[None, [None, 'ab']], None, [None, ['cd']]],
      ),
  )
  def test_list_update_to_accumulate(self, list_update, result):
    """Checks `to_result` on flat and nested updates (gaps become None)."""
    self.assertListEqual(list_update.to_result(), result)

  def test_nested(self):
    # Wrapping an Update in another Update is transparent for `to_result`.
    u1 = Update('a')
    u2 = Update(u1)
    self.assertEqual(u1.to_result(), u2.to_result())

  def test_addable(self):
    l = _S('a')
    self.assertEqual(l.to_result(), 'a')
    l += 'b'
    self.assertEqual(l.to_result(), 'ab')
    l += _S('c') + 'd'
    self.assertEqual(l.to_result(), 'abcd')


if __name__ == '__main__':
  absltest.main()
121 |
--------------------------------------------------------------------------------
/onetwo/backends/backends_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Defines the base Backend class."""
16 |
17 | import dataclasses
18 |
19 |
20 | @dataclasses.dataclass
21 | class Backend:
22 | """Interface for a class that registers its method.
23 |
24 | Attributes:
25 | name: An optional name for this backend instance.
26 | """
27 |
28 | # A name for this backend instance.
29 | name: str = dataclasses.field(
30 | # We set `init=False` to prevent `name` from becoming an `__init__`
31 | # argument with a default value. If `name` were an `__init__`
32 | # argument, subclasses would not be able to define positional
33 | # arguments, because Python does not allow positional arguments to
34 | # follow arguments with default values.
35 | init=False,
36 | default='',
37 | )
38 |
39 | # Name under which the methods are registered by default.
40 | _default_name: str = dataclasses.field(
41 | init=False,
42 | default='default',
43 | )
44 |
45 | # Not an abstract method since by default it is ok not to register anything.
46 | def register(self, name: str | None = None):
47 | """Add the relevant methods to the registry.
48 |
49 | It is necessary for a child class to override this method in order to
50 | specify which methods it wants to expose/register.
51 | In particular this can be used to call `configure` or `get_variant` on
52 | builtins or to directly register methods of this object under particular
53 | names.
54 | For example the body of this method can look like:
55 | ```
56 | registry = routing.function_registry
57 | # Register self.my_fn with the provided or default name prepended.
58 | registry[f'{name or self._default_name}.my_fn'] = self.my_fn
59 | # Configure some builtin.
60 | llm.generate_text.configure(
61 | self.my_generate_text,
62 | temperature=default_temperature
63 | )
64 | # Configure a variant of the builtin with a different default parameter.
65 | variant = llm.generate_text.get_variant(
66 | self.my_generate_text
67 | temperature=0.0
68 | )
69 | # Register the variant under a special name.
70 | registry['llm.zero_temp'] = variant
71 | ```
72 |
73 | Args:
74 | name: An optional argument to use as prefix of the names in the registry,
75 | if None, the _default_name attribute may be used (it can be set in a
76 | child class declaration).
77 | """
78 |
79 |
def truncate_reply(reply: str, stop_sequences: list[str]) -> str:
  """Returns the truncated reply not including any stop sequence.

  Example:
    If reply is 'abcdefg' and stop_sequences is ['f', 'c'], then the truncated
    reply is 'ab'.

  Args:
    reply: The original reply to be truncated.
    stop_sequences: The substrings to truncate the reply at. The stop
      sequences are not included in the truncated reply.
  """
  if not stop_sequences:
    return reply

  # Stop sequences may overlap (e.g. 'f' and 'def' in 'abcdefg'), so we cut
  # at each sequence independently and keep the shortest resulting prefix.
  # Every candidate is a prefix of the same string, so the alphabetical
  # minimum is also the shortest one.
  candidate_prefixes = {reply.split(seq)[0] for seq in stop_sequences}
  return min(candidate_prefixes)
107 |
--------------------------------------------------------------------------------
/onetwo/builtins/tool_use.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Definitions of built-in functions and utilities for tool use.
16 |
17 | A tool can be an arbitrary function that we want to expose to the LLM so that
18 | it can call it.
19 | """
20 |
21 | from typing import Any
22 |
23 | from onetwo.builtins import builtins_base
24 | from onetwo.core import constants
25 | from onetwo.core import executing
26 | from onetwo.core import routing
27 | from onetwo.core import utils
28 |
29 |
@builtins_base.Builtin
def run_tool(
    tool_name: str, tool_args: tuple[Any, ...], tool_kwargs: dict[str, Any]
) -> Any:
  """Interface of the run_tool built-in function.

  Runs a tool and returns the result.

  Args:
    tool_name: The name of the tool to run. Depending on the implementation,
      this could either be the name of a function in the function registry, or
      could be another tool name that is indirectly mapped to such a function.
    tool_args: Position args to pass to the tool function.
    tool_kwargs: Keyword args to pass to the tool function.

  Returns:
    The return value of the tool function.
  """
  del tool_name, tool_args, tool_kwargs
  # Pure interface declaration: this body is never meant to execute. The
  # actual implementation is installed via `configure` (see
  # `default_run_tool` in this module for the default one).
  raise NotImplementedError(
      'The implementation should be provided at runtime by calling `configure`'
      ' or `get_variant`. This function cannot be called directly.'
  )
53 |
54 |
@executing.make_executable  # pytype: disable=wrong-arg-types
async def default_run_tool(
    tool_name: str, tool_args: tuple[Any, ...], tool_kwargs: dict[str, Any]
) -> Any:
  """Default implementation of run_tool which calls function_registry.

  Args:
    tool_name: Name under which the tool function is registered in
      `routing.function_registry`, or `constants.ERROR_STRING` to request
      that the single argument be echoed back as-is (used to surface
      upstream parsing errors to the LLM).
    tool_args: Positional args to pass to the tool function.
    tool_kwargs: Keyword args to pass to the tool function.

  Returns:
    The (awaited if necessary) return value of the tool function, or an error
    string of the form '{ERROR_STRING}: ...' when a ValueError is raised.

  Raises:
    ValueError: If tool_name is ERROR_STRING but there is not exactly one
      positional argument.
  """
  if tool_name == constants.ERROR_STRING:
    # ERROR_STRING as the tool_name is a special case, where we are expected
    # to simply echo the error message stored in the tool argument.
    if len(tool_args) != 1:
      raise ValueError(
          'When tool_name is ERROR_STRING, we expect there to be exactly one'
          ' argument containing the detailed error message (e.g., an error'
          ' that occurred when parsing the LLM response to determine the'
          f' tool call). Instead found {len(tool_args)} arguments:'
          f' {tool_name=}, {tool_args=}, {tool_kwargs=}'
      )
    return tool_args[0]

  try:
    tool_function = routing.function_registry.get(tool_name)
    if tool_function is None:
      raise ValueError(
          f'Function {tool_name} is not registered in the function_registry'
          f' ({routing.function_registry=}).'
      )
    # `call_and_maybe_await` handles both plain and awaitable tool functions.
    return await utils.call_and_maybe_await(
        tool_function, *tool_args, **tool_kwargs
    )

  except ValueError as e:
    # NOTE(review): this catches ValueErrors raised by the tool function
    # itself as well as the lookup failure above; both are reported to the
    # caller as an error string rather than propagated.
    return f'{constants.ERROR_STRING}: {e}'
86 |
87 |
def reset_defaults():
  """Resets default implementations for all builtins in this file."""
  # Keep all module level `some_builtin.configure(...)` commands in this method.
  run_tool.configure(default_run_tool)

# Install the defaults at import time so `run_tool` is usable out of the box.
reset_defaults()
94 | # TODO: Define an additional `run_tool_text` builtin that takes as
95 | # input:
96 | # (1) a string like `f('a', 'b')` or `f(x, 'b')` or `y = f(x, 'b')` that can be
97 | # parsed into a tool call (possibly with variable references and/or variable
98 | # assignment)
99 | # (2) a context represented as a dictionary of variable values
100 | # (e.g., {'x': 'a'})
101 | # and which returns as output a string containing the tool's return value or
102 | # error message in a form appropriate for presenting to an LLM in a prompt,
103 | # along with a dictionary of context updates in the case of a variable
104 | # assignment (e.g., {'y': 'ab'})). This `run_tool_text` builtin can further be
105 | # wrapped as a composable or as a callback in a Jinja prompt template.
106 |
--------------------------------------------------------------------------------
/onetwo/backends/onetwo_api_manual_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Sequence
16 | import os
17 | import pprint
18 | import time
19 |
20 | from absl import app
21 | from absl import flags
22 | from onetwo.backends import onetwo_api
23 | from onetwo.builtins import llm
24 | from onetwo.core import executing
25 | from onetwo.core import sampling
26 |
# Command-line flags controlling the manual test run in `main` below.
_ENDPOINT = flags.DEFINE_string(
    'endpoint',
    default='http://localhost:9876',
    help='Endpoint to use.',
)
_CACHE_DIR = flags.DEFINE_string(
    'cache_dir',
    default='/tmp/onetwo_api',
    help='Directory where the cache will be stored.',
)
_LOAD_CACHE = flags.DEFINE_bool(
    'load_cache_file',
    default=False,
    help='Whether we should read the cache stored in file.'
)
_PRINT_DEBUG = flags.DEFINE_bool(
    'print_debug',
    default=False,
    help='Debug logging.'
)
47 |
48 |
def main(argv: Sequence[str]) -> None:
  """Manually exercises the OneTwo API backend against a live endpoint.

  Issues a series of smoke-test requests: a single generate query (repeated
  to exercise the cache, then again with different parameters to check cache
  keying), a repeated-sampling request, and a batch of three queries.
  Optionally loads the request cache beforehand and saves it afterwards, as
  controlled by the module flags.

  Args:
    argv: Command-line arguments (unused; required by `app.run`).
  """
  del argv
  fname = os.path.join(_CACHE_DIR.value, 'onetwo_api.json')
  backend = onetwo_api.OneTwoAPI(
      endpoint=_ENDPOINT.value,
      cache_filename=fname,
      batch_size=4,
  )
  backend.register()
  if _LOAD_CACHE.value:
    # Bug fix: `print` is not `logging` — the original
    # `print('... %s', fname)` passed the filename as a second value, so the
    # literal '%s' was printed instead of being substituted.
    print('Loading cache from file %s' % fname)
    load_start = time.time()
    backend.load_cache()
    load_end = time.time()
    print('Spent %.4fsec loading cache.' % (load_end - load_start))

  print('1. A single generate query.')
  prompt_text = """
    Question: Natural logarithm of $e^12$?
    Reasoning: "Natural" logarithm means logarithm to the base of $e$. For
      example, natural logarithm of $10$ means exponent to which $e$ must be
      raised to produce 10.
    Answer: 12.
    Question: Differentiate $\\frac{1}{\\log(x)}$.
  """
  res = executing.run(llm.generate_text(  # pytype: disable=wrong-keyword-args
      prompt=prompt_text,
      stop=['\n\n'],
  ))
  if _PRINT_DEBUG.value:
    print('Returned value(s):')
    pprint.pprint(res)
  print('1.1 Same query to see if it has been cached.')
  res = executing.run(llm.generate_text(  # pytype: disable=wrong-keyword-args
      prompt=prompt_text,
      stop=['\n\n'],
  ))
  if _PRINT_DEBUG.value:
    print('Returned value(s):')
    pprint.pprint(res)
  print('1.2 Same query but different parameters, run requests again.')
  res = executing.run(llm.generate_text(  # pytype: disable=wrong-keyword-args
      prompt=prompt_text,
      temperature=0.,
      stop=['\n\n'],
  ))
  if _PRINT_DEBUG.value:
    print('Returned value(s):')
    pprint.pprint(res)
  print('2. Repeated generate request.')
  exe = executing.par_iter(sampling.repeat(
      executable=llm.generate_text(  # pytype: disable=wrong-keyword-args
          prompt='Today is', temperature=0.5, stop=['.']),
      num_repeats=5,
  ))
  res = executing.run(exe)
  if _PRINT_DEBUG.value:
    print('Returned value(s):')
    pprint.pprint(res)
  print('3. Three batched generate queries.')
  exe = executing.par_iter([
      llm.generate_text(prompt='In summer', stop=['.']),  # pytype: disable=wrong-keyword-args
      llm.generate_text(prompt='In winter', max_tokens=32),  # pytype: disable=wrong-keyword-args
      llm.generate_text(prompt='In autumn'),  # pytype: disable=wrong-keyword-args
  ])
  res = executing.run(exe)
  if _PRINT_DEBUG.value:
    print('Returned value(s):')
    pprint.pprint(res)

  if not _LOAD_CACHE.value:
    start = time.time()
    backend.save_cache(overwrite=True)
    print('Took %.4fsec saving cache to %s.' % (time.time() - start, fname))

  print('PASS')


if __name__ == '__main__':
  app.run(main)
129 |
--------------------------------------------------------------------------------
/onetwo/core/core_test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for OneTwo unit tests."""
16 |
17 | import collections
18 | from collections.abc import Mapping
19 | import io
20 | import json
21 | import pprint
22 | from typing import Any, Literal
23 | import unittest
24 | from onetwo.core import results
25 | import PIL.Image
26 |
27 |
class CounterAssertions(unittest.TestCase):
  """Mixin class for counter assertions."""

  # pylint: disable=invalid-name
  def assertCounterEqual(
      self,
      counter_first: collections.Counter[str],
      counter_second: collections.Counter[str],
  ) -> None:
    """Asserts two counters are equal, ignoring entries with zero counts."""
    # Subtracting an empty Counter keeps only positive counts, so zero-valued
    # entries cannot cause a spurious mismatch.
    normalized_first = counter_first - collections.Counter()
    normalized_second = counter_second - collections.Counter()
    diff = pprint.pformat(normalized_first - normalized_second)
    reverse_diff = pprint.pformat(normalized_second - normalized_first)
    message = f'A - B contains: {diff}\nB - A contains: {reverse_diff}'
    return self.assertEqual(
        dict(normalized_first), dict(normalized_second), message
    )
43 |
44 |
def maybe_read_file(filepath: str) -> str:
  """Returns the contents of the file if it exists, or else empty string."""
  try:
    # Return directly from inside the `with`; the file is still closed.
    with open(filepath) as file_handle:
      return file_handle.read()
  except IOError:
    # Missing or unreadable file is treated as empty, not an error.
    return ''
53 |
54 |
def maybe_read_json(filepath: str) -> Mapping[str, Any] | None:
  """Returns the file contents as JSON, or None if there is a problem."""
  raw_text = maybe_read_file(filepath)
  try:
    parsed = json.loads(raw_text)
  except json.JSONDecodeError:
    # Covers both a missing file (empty string) and malformed JSON.
    return None
  return parsed
62 |
63 |
class MockTimer:
  """Mock timer for use in unit tests.

  Each call returns a time exactly one unit later than the previous call,
  starting from 1.0, giving tests deterministic, strictly increasing stamps.
  """

  def __init__(self):
    # Number of times the timer has been queried so far.
    self._ticks = 0

  def __call__(self) -> float:
    self._ticks += 1
    return float(self._ticks)
73 |
74 |
def reset_fields(
    er: list[results.ExecutionResult] | results.ExecutionResult,
    reset_values: Mapping[str, Any],
):
  """Recursively resets the given fields in the given execution result(s).

  Args:
    er: The ExecutionResult or list thereof to reset fields in.
    reset_values: Mapping of field_name to the value to set when resetting.

  Returns:
    None.
  """
  if isinstance(er, list):
    for sub_er in er:
      reset_fields(sub_er, reset_values)
    return

  for name, value in reset_values.items():
    setattr(er, name, value)

  # Recurse into nested stages so the entire result tree is reset.
  for stage in er.stages:
    reset_fields(stage, reset_values)
98 |
99 |
def reset_times(
    er: list[results.ExecutionResult] | results.ExecutionResult,
) -> None:
  """Resets the start and end times in the given execution result(s)."""
  zeroed_times = {'start_time': 0.0, 'end_time': 0.0}
  reset_fields(er, zeroed_times)
105 |
106 |
def create_test_pil_image(
    fmt: Literal['PNG', 'JPEG', 'GIF'] = 'PNG',
    mode: Literal['RGB', 'RGBA', 'L', 'P'] = 'RGB',
    size: tuple[int, int] = (1, 1),
) -> PIL.Image.Image:
  """Returns a PIL Image in the given format, for testing purposes.

  Args:
    fmt: The target image format.
    mode: The image mode.
    size: The image dimensions (width, height).
  """
  # GIF requires saving the palette.
  save_kwargs = {'format': fmt}
  if fmt == 'GIF':
    save_kwargs['save_all'] = True
  fresh_image = PIL.Image.new(mode, size)
  buffer = io.BytesIO()
  try:
    # Round-trip through an in-memory buffer to attempt to set the format
    # attribute on the reloaded image.
    fresh_image.save(buffer, **save_kwargs)
    buffer.seek(0)
    reloaded_img = PIL.Image.open(buffer)
    if reloaded_img.format is None:
      # PIL might not always preserve format. Inject it for testing purpose.
      # This might happen for simple modes like 'L' depending on PIL version.
      reloaded_img.format = fmt
    return reloaded_img
  except Exception as e:
    raise ValueError(
        f'Failed to create test image with format {fmt}: {e}'
    ) from e
141 |
--------------------------------------------------------------------------------
/onetwo/backends/openai_mock.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A module that mocks the OpenAI library, for testing purposes.
16 |
17 | This document explains how to use the chat completion method:
18 | https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models
19 | """
20 |
21 | from collections.abc import Mapping, Sequence
22 | from typing import Final, NamedTuple, TypeVar
23 |
24 | import pydantic
25 |
26 |
# Canned completion text returned for every mocked request.
_DEFAULT_REPLY: Final[str] = 'Hello'
# Canned per-token log probability attached to every mocked choice.
_DEFAULT_SCORE: Final[float] = -0.1
# Type of the pydantic model requested via `response_format` in parse().
_T = TypeVar('_T', bound=pydantic.BaseModel)
30 |
31 |
class OpenAIMessage(NamedTuple):
  """Mock of a single OpenAI chat message."""

  role: str  # e.g. 'assistant' or 'user'.
  content: str  # The message text.
  parsed: pydantic.BaseModel | None = None  # Structured output from parse().
  refusal: str | None = None  # Refusal/error message when parsing fails.
37 |
38 |
class ChatCompletionTokenLogprob(NamedTuple):
  """Mock of a single generated token and its log probability."""

  token: str
  logprob: float
42 |
43 |
class ChoiceLogProbs(NamedTuple):
  """Mock of the log-probability payload attached to one choice."""

  content: Sequence[ChatCompletionTokenLogprob]
46 |
47 |
class Choice(NamedTuple):
  """Mock of one completion choice in a ChatCompletion."""

  index: int  # Position of this choice in the returned list.
  message: OpenAIMessage  # The generated message.
  logprobs: ChoiceLogProbs  # Token-level log probabilities.
  finish_reason: str  # Why generation stopped, e.g. 'stop'.
53 |
54 |
class ChatCompletion(NamedTuple):
  """Mock of the object returned by chat.completions.create/parse."""

  choices: Sequence[Choice]
57 |
58 |
class OpenAI:
  """A mock of the OpenAI class that provides a chat.completions.create method.

  See
  https://github.com/openai/openai-python/blob/f0bdef04611a24ed150d19c4d180aacab3052704/src/openai/_client.py#L49
  """

  def __init__(self, api_key: str | None = None):
    # The key is accepted only for interface compatibility; the mock never
    # talks to a real service.
    del api_key
    self._reply = _DEFAULT_REPLY

  @property
  def chat(self):
    # The real client is called as client.chat.completions.create(...);
    # returning self lets this single object stand in for the whole chain.
    return self

  @property
  def completions(self):
    return self

  def create(
      self,
      model: str,
      messages: Sequence[Mapping[str, str]],
      **kwargs,
  ) -> ChatCompletion:
    """Returns a canned ChatCompletion with `n` identical choices."""
    del model, messages
    num_choices = kwargs.get('n', 1)

    def _make_choice(choice_index: int) -> Choice:
      return Choice(
          index=choice_index,
          message=OpenAIMessage(role='assistant', content=self._reply),
          logprobs=ChoiceLogProbs(
              content=[
                  ChatCompletionTokenLogprob(token='a', logprob=_DEFAULT_SCORE)
              ]
          ),
          finish_reason='stop',
      )

    return ChatCompletion(
        choices=[_make_choice(i) for i in range(num_choices)]
    )

  def parse(
      self,
      *,
      model: str,
      messages: Sequence[Mapping[str, str]],
      response_format: type[_T],
      **kwargs,
  ) -> ChatCompletion:
    """Mocks client.chat.completions.parse for structured outputs."""
    del model, messages, kwargs
    parsed_object = None
    refusal_message = None

    is_model_class = isinstance(response_format, type) and issubclass(
        response_format, pydantic.BaseModel
    )
    if is_model_class:
      try:
        # Create a validated instance from the mock data
        parsed_object = response_format(
            **{'name': 'Mock Test City', 'population': 500000}
        )
      except Exception as e:  # pylint: disable=broad-except
        refusal_message = f'Mock failed to construct {response_format}: {e}'
    else:
      refusal_message = (
          '`response_format` must be a pydantic.BaseModel subclass for '
          f'structured parsing, but got {response_format}.'
      )

    return ChatCompletion(
        choices=[
            Choice(
                index=0,
                message=OpenAIMessage(
                    role='assistant',
                    content=self._reply,
                    parsed=parsed_object,
                    refusal=refusal_message,
                ),
                logprobs=ChoiceLogProbs(
                    content=[
                        ChatCompletionTokenLogprob(
                            token='a', logprob=_DEFAULT_SCORE
                        )
                    ]
                ),
                finish_reason='stop',
            )
        ]
    )
159 |
--------------------------------------------------------------------------------
/onetwo/backends/model_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """OneTwo Model Server."""
16 |
17 | import logging
18 | import sys
19 | import traceback
20 | from typing import Any
21 |
22 | import fastapi
23 | from onetwo.builtins import llm
24 | from onetwo.core import batching
25 |
26 |
27 | _Body = fastapi.Body
28 |
29 |
def _get_http_exception(exception: Exception) -> Exception:
  """Logs the given exception and wraps it in a fastapi.HTTPException.

  Args:
    exception: The exception raised while handling a request.

  Returns:
    An HTTPException with status 500 whose detail describes the original
    exception, suitable for re-raising from a route handler.
  """
  error_message = 'An exception {} occurred. Arguments: {}.'.format(
      type(exception).__name__, exception.args
  )
  # Fix: the original format string used '\\n', which logged a literal
  # backslash-n instead of putting the traceback on its own line.
  logging.info(
      '%s\nTraceback: %s', error_message, traceback.format_exc()
  )
  # Converts all exceptions to HTTPException.
  return fastapi.HTTPException(status_code=500, detail=error_message)
39 |
40 |
class ModelServer:
  """Model server that wraps llm builtins and exposes them as API calls.

  Exposes POST endpoints /tokenize, /generate_text and /count_tokens, plus a
  GET /health endpoint; each POST handler delegates to the corresponding llm
  builtin configured in this process.

  See run_model_server.py for an example of how to use this class.
  """

  def __init__(self):
    """Initializes a fastapi application and sets the configs."""

    logging.basicConfig(
        format='%(asctime)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
        level=logging.INFO,
        stream=sys.stdout,
    )

    # We disable batching on the server (it may still be enabled on the client).
    # This is specific to the way we run onetwo in an async application. In
    # general, in client applications it is best to call onetwo.run().
    batching._enable_batching.set(False)  # pylint: disable=protected-access

    self._app = fastapi.FastAPI()

    # Route handlers are registered as closures so they capture this app
    # instance; their names are unused outside the decorators.
    @self._app.post('/tokenize')
    async def _tokenize(
        content: str = _Body(..., embed=True)  # Ellipsis: required parameter.
    ) -> list[int]:
      """Wraps llm.tokenize."""
      # We disable batching on the server (it may still be enabled on the
      # client). This is specific to the way we run onetwo in an async
      # application. In general, in client applications it is best to call
      # onetwo.run().
      batching._enable_batching.set(False)  # pylint: disable=protected-access

      try:
        res = await llm.tokenize(content)  # pytype: disable=wrong-arg-count
      except Exception as exception:  # pylint: disable=broad-exception-caught
        raise _get_http_exception(exception) from exception
      return res

    @self._app.post('/generate_text')
    async def _generate_text(
        prompt: str = _Body(...),  # Ellipsis: required parameter.
        temperature: float = _Body(default=None),
        max_tokens: int = _Body(default=None),
        include_details: bool = _Body(default=False),
    ) -> tuple[str, dict[str, Any]]:
      """Wraps llm.generate_text."""
      # We disable batching on the server (it may still be enabled on the
      # client). This is specific to the way we run onetwo in an async
      # application. In general, in client applications it is best to call
      # onetwo.run().
      batching._enable_batching.set(False)  # pylint: disable=protected-access
      try:
        res = await llm.generate_text(  # pytype: disable=wrong-keyword-args
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            include_details=include_details,
        )
      except Exception as exception:  # pylint: disable=broad-exception-caught
        raise _get_http_exception(exception) from exception
      # Normalize the response shape: always return a (text, details) pair.
      if include_details:
        return res
      else:
        return res, {}

    @self._app.post('/count_tokens')
    async def _count_tokens(
        content: str = _Body(..., embed=True)  # Ellipsis: required parameter.
    ) -> int:
      """Wraps llm.count_tokens."""
      # We disable batching on the server (it may still be enabled on the
      # client). This is specific to the way we run onetwo in an async
      # application. In general, in client applications it is best to call
      # onetwo.run().
      batching._enable_batching.set(False)  # pylint: disable=protected-access
      try:
        res = await llm.count_tokens(content)  # pytype: disable=wrong-arg-count
      except Exception as exception:  # pylint: disable=broad-exception-caught
        raise _get_http_exception(exception) from exception
      return res

    self._app.add_api_route(
        path='/health',
        endpoint=self._health,
        methods=['GET'],
    )

  async def __call__(self, scope, receive, send):
    # ASGI entry point: delegate directly to the wrapped fastapi application.
    await self._app(scope, receive, send)

  def _health(self) -> dict[str, Any]:
    """Executes a health check."""
    # An empty JSON object signals that the server is up and responding.
    return {}
136 |
--------------------------------------------------------------------------------
/onetwo/agents/agents_test_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import TypeAlias
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 |
21 | from onetwo.agents import agents_base
22 | from onetwo.agents import agents_test_utils
23 | from onetwo.core import executing
24 |
25 |
# Convenience aliases for the agent state/update types used in these tests.
StringAgentState = agents_base.UpdateListState[str, str]


_SU: TypeAlias = agents_base.ScoredUpdate[str]
_ULS: TypeAlias = agents_base.UpdateListState[str, _SU]
31 |
32 |
class DistributionAgentTest(parameterized.TestCase):
  """Tests for agents_test_utils.DistributionAgentForTest."""

  @parameterized.named_parameters(
      ('single_end', {'hello': 1.0}, 'hello', {'$': 1.0}),
      ('impossible', {'hello': 1.0}, 'x', {'': 1.0}),
      ('single_start', {'hello': 1.0}, '', {'h': 1.0}),
      ('double_start', {'hello': 0.7, 'hallo': 0.3}, 'h', {'e': 0.7, 'a': 0.3}),
      ('double_end', {'hello': 0.7, 'hallo': 0.3}, 'hello', {'$': 1.0}),
      (
          'triple',
          {'hello': 0.1, 'helly': 0.2, 'hallo': 0.2, 'world': 0.5},
          'h',
          {'e': 0.6, 'a': 0.4},
      ),
  )
  def test_next_step_distribution(self, words, prefix, expected):
    # The agent API is async, so build and query it inside a coroutine and
    # drive it with executing.run.
    async def wrapper(words: dict[str, float], prefix: str) -> list[_SU]:
      agent = agents_test_utils.DistributionAgentForTest(words)
      state = await agent.initialize_state(prefix)  # pytype: disable=wrong-arg-count
      return await agent.get_next_step_distribution(state)  # pytype: disable=wrong-arg-count

    res = executing.run(wrapper(words, prefix))
    # Round the scores so float-precision noise cannot fail the comparison.
    res_as_dict = {
        scored_update.update: round(scored_update.score, 2)
        for scored_update in res
    }
    self.assertDictEqual(expected, res_as_dict)

  @parameterized.named_parameters(
      ('start', '', 1.0),
      ('impossible', 'x', 0.0),
      ('incomplete', 'h', 0.5),
      ('hello', 'hello', 0.3),
      ('hallo', 'hallo', 0.2),
      ('hello$', 'hello$', 0.1),
  )
  def test_score_state(self, state, expected):
    # score_state is synchronous, so no executing.run is needed here.
    score = agents_test_utils.DistributionAgentForTest(
        {'hello': 0.1, 'hello_world': 0.2, 'hallo': 0.2, 'world': 0.5}
    ).score_state(state)
    self.assertEqual(
        expected,
        round(score, 2),
    )

  @parameterized.named_parameters(
      ('start', '', False),
      ('impossible', 'x', True),
      ('incomplete', 'h', False),
      ('hello', 'hello', False),
      ('hallo', 'hallo', False),
      ('hello$', 'hello$', True),
  )
  def test_is_finished(self, state, expected):
    # A state is finished when it ends with '$' or matches no word prefix.
    self.assertEqual(
        expected,
        agents_test_utils.DistributionAgentForTest(
            {'hello': 0.1, 'hello_world': 0.2, 'hallo': 0.2, 'world': 0.5}
        ).is_finished(state),
    )

  @parameterized.named_parameters(
      ('empty', '', ['hello$', 'hallo$', 'hello_world$', 'world$']),
      ('h', 'h', ['hello$', 'hallo$', 'hello_world$']),
      ('he', 'he', ['hello$', 'hello_world$']),
      ('hello', 'hello', ['hello$', 'hello_world$']),
      ('hello$', 'hello$', ['hello$']),
      ('x', 'x', ['x']),
  )
  def test_execute(self, prefix, expected):
    agent = agents_test_utils.DistributionAgentForTest(
        {'hello': 0.1, 'hallo': 0.2, 'hello_world': 0.4, 'world': 0.3}
    )
    # Execution samples from the distribution, so we only assert membership
    # in the set of possible completions rather than an exact value.
    res = executing.run(agent(inputs=prefix))  # pytype: disable=wrong-keyword-args
    self.assertIn(res, expected)
108 |
109 |
class StringAgentTest(parameterized.TestCase):
  """Tests for agents_test_utils.StringAgent."""

  def test_sample_next_step(self):
    agent = agents_test_utils.StringAgent(sequence=['a', 'b', 'c', 'd'])
    state = executing.run(agent.initialize_state(inputs='test'))  # pytype: disable=wrong-keyword-args

    result = executing.run(
        agent.sample_next_step(state=state, num_candidates=2)  # pytype: disable=wrong-keyword-args
    )
    with self.subTest('first_call_returns_first_steps_in_sequence'):
      self.assertEqual(result, ['a', 'b'])

    # Sampling consumes the sequence, so a second call continues where the
    # first one left off.
    result = executing.run(
        agent.sample_next_step(state=state, num_candidates=2)  # pytype: disable=wrong-keyword-args
    )
    with self.subTest('second_call_returns_next_steps_in_sequence'):
      self.assertEqual(result, ['c', 'd'])

  @parameterized.named_parameters(
      ('max_length_equals_seq_length', 3, ['a', 'a', 'b'], 'a a b'),
      ('max_length_greater_than_seq_length', 4, ['a', 'a', 'b'], 'a a b'),
      ('max_length_less_than_seq_length', 2, ['a', 'a', 'b'], 'a a'),
  )
  def test_execute(self, max_length, sequence, expected_result):
    # The agent stops at max_length updates or when the sequence runs out,
    # whichever comes first, and joins the consumed updates with spaces.
    agent = agents_test_utils.StringAgent(
        max_length=max_length, sequence=sequence
    )
    result = executing.run(agent(inputs='test'))  # pytype: disable=wrong-keyword-args
    self.assertEqual(expected_result, result)
139 |
140 |
# Standard absltest entry point.
if __name__ == '__main__':
  absltest.main()
143 |
--------------------------------------------------------------------------------
/onetwo/agents/agents_test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Agent implementations for use in tests."""
16 |
17 | import dataclasses
18 | from typing import TypeAlias
19 |
20 | from onetwo.agents import agents_base
21 | from onetwo.agents import distribution
22 | from onetwo.core import executing
23 |
# Shorthand aliases for scored string updates and their list-state container.
_SU: TypeAlias = agents_base.ScoredUpdate[str]
_ULS: TypeAlias = agents_base.UpdateListState[str, _SU]
26 |
27 |
@dataclasses.dataclass
class DistributionAgentForTest(
    distribution.DistributionAgent[str, str, str, _SU, None]
):
  """Given a distribution over strings, form a distribution over next chars.

  Attributes:
    distribution: Maps each word to its probability mass. The state is a
      prefix of one of these words, optionally terminated by the
      end-of-sequence token `$`.
  """

  distribution: dict[str, float] = dataclasses.field(default_factory=dict)

  @executing.make_executable(copy_self=False, non_copied_args=['environment'])
  async def initialize_state(
      self, inputs: str, environment: None = None
  ) -> str:
    """Overridden from base class (Agent)."""
    return inputs

  def extract_output(self, state: str) -> str:
    """Overridden from base class (Agent)."""
    return state

  def is_finished(self, state: str) -> bool:
    """Overridden from base class (Agent)."""
    # We reached a final state if we have the end-of-sequence token `$`.
    if state.endswith('$'):
      return True
    # But we also consider the state to be final if it cannot be reached,
    # i.e. no word in the distribution has the state as a prefix. Using any()
    # short-circuits on the first match instead of scanning the whole dict.
    return not any(word.startswith(state) for word in self.distribution)

  def score_state(self, state: str) -> float:
    """Returns the probability of reaching this state from an empty state."""
    sum_probabilities = 0.0
    for word, prob in self.distribution.items():
      # Appending '$' lets a terminated state match its originating word.
      if (word + '$').startswith(state):
        sum_probabilities += prob
    return sum_probabilities

  @executing.make_executable(copy_self=False)
  async def get_next_step_distribution(
      self, state: str, environment: None = None
  ) -> list[_SU]:
    """Overridden from base class (DistributionAgent)."""
    if self.is_finished(state):
      # If we are in a final state, we return a distribution with score 1 for
      # the empty update (which leaves the state unchanged).
      return [agents_base.ScoredUpdate(update='', score=1.0)]
    next_letter_probs = {}
    for word, score in self.distribution.items():
      if word.startswith(state):
        if len(word) > len(state):
          next_letter = word[len(state)]
        else:
          next_letter = '$'  # End of sequence token.
        if next_letter not in next_letter_probs:
          next_letter_probs[next_letter] = score
        else:
          next_letter_probs[next_letter] += score
    # Normalize the distribution.
    total_prob = sum(next_letter_probs.values())
    if total_prob > 0.0:
      for k in next_letter_probs:
        next_letter_probs[k] /= total_prob
    else:
      # If the total probability is 0, we assume we cannot continue hence the
      # next update is necessarily empty.
      return [agents_base.ScoredUpdate(update='', score=1.0)]
    # Convert the distribution to ScoredUpdates.
    return [
        agents_base.ScoredUpdate(update=k, score=v)
        for k, v in next_letter_probs.items()
    ]
101 |
102 |
103 | _StringAgentState: TypeAlias = agents_base.UpdateListState[str, str]
104 |
105 |
@dataclasses.dataclass
class StringAgent(
    agents_base.SingleSampleAgent[str, str, _StringAgentState, str, None]
):
  """Simple test agent, whose input / updates are strings.

  Its output is a concatenation of the update strings, separated by space.

  Attributes:
    max_length: Maximum length of the agent's state (i.e., of its update list).
      If specified, then will finish when this length is reached. If None, then
      will by default run forever.
    sequence: A sequence of strings to be used by the agent to produce samples.
  """

  max_length: int = 5
  sequence: list[str] = dataclasses.field(default_factory=list)

  @executing.make_executable(copy_self=False, non_copied_args=['environment'])
  async def initialize_state(
      self, inputs: str, environment: None = None
  ) -> _StringAgentState:
    """Overridden from base class (Agent)."""
    return _StringAgentState(inputs=inputs)

  def extract_output(self, state: _StringAgentState) -> str:
    """Overridden from base class (Agent)."""
    return ' '.join(state.updates)

  def is_finished(self, state: _StringAgentState) -> bool:
    """Overridden from base class (Agent)."""
    # Finished once max_length updates have accumulated, or when there is
    # nothing left in the sample sequence.
    return len(state.updates) >= self.max_length or not self.sequence

  @executing.make_executable(copy_self=False)
  async def _sample_single_next_step(
      self, state: _StringAgentState, environment=None
  ) -> str:
    """Overridden from base class (SingleSampleAgent)."""
    # Note: mutates self.sequence — each sampled step is consumed permanently.
    return self.sequence.pop(0)
144 |
--------------------------------------------------------------------------------
/onetwo/builtins/composables.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Composable versions of builtin functions."""
16 |
17 | from collections.abc import Sequence
18 | from typing import Any, TypeVar
19 |
20 | from onetwo.builtins import llm
21 | from onetwo.builtins import prompt_templating
22 | from onetwo.core import composing
23 | from onetwo.core import content as content_lib
24 |
25 |
# Generic type parameter for composable return values.
_T = TypeVar('_T')

# Re-exported composition primitives, for convenient access from this module.
store = composing.store
section_start = composing.section_start
section_end = composing.section_end
31 |
32 |
@composing.make_composable  # pytype: disable=wrong-arg-types
def f(
    context: composing.Context, content: str, role: content_lib.RoleType = None
) -> content_lib.Chunk:
  """Composable version of the string.format function that uses context vars."""
  try:
    formatted = content.format(**context.variables)
  except KeyError as missing_variable:
    # Surface a readable error naming the template and available variables.
    raise ValueError(
        f'Could not format {content} with context'
        f' {context.variables}, some variables were not found.'
    ) from missing_variable
  return content_lib.Chunk(formatted, role=role)
46 |
47 |
@composing.make_composable  # pytype: disable=wrong-arg-types
def c(
    context: composing.Context, content: Any, role: content_lib.RoleType = None
) -> content_lib.ChunkList:
  """Composable chunk created from content."""
  del context
  # A single Chunk is wrapped into a one-element ChunkList.
  if isinstance(content, content_lib.Chunk):
    if role:
      content.role = role
    return content_lib.ChunkList([content])
  # A ChunkList is passed through, optionally re-tagging every chunk's role.
  if isinstance(content, content_lib.ChunkList):
    if role:
      for chunk in content:
        chunk.role = role
    return content
  # Any other content is converted into a fresh single-chunk list.
  return content_lib.ChunkList([content_lib.Chunk(content, role=role)])
65 |
66 |
@composing.make_composable  # pytype: disable=wrong-arg-types
async def j(
    context: composing.Context,
    template: str,
    name: str = 'JinjaTemplate',
    role: content_lib.RoleType = None,
) -> content_lib.Chunk:
  """Composable version of a jinja formatted prompt using context vars."""
  jinja_prompt = prompt_templating.JinjaTemplateWithCallbacks(
      name=name, text=template
  )
  rendered = await jinja_prompt.render(**context.variables)
  # Copy the template's output variables (everything except the rendered
  # prefix itself) back into the composition context.
  for key, value in rendered.items():
    if key != prompt_templating.PROMPT_PREFIX:
      context[key] = value
  return content_lib.Chunk(rendered[prompt_templating.PROMPT_PREFIX], role=role)
84 |
85 |
@composing.make_composable  # pytype: disable=wrong-arg-types
def generate_text(context: composing.Context, **kwargs) -> ...:
  """Composable version of the llm.generate_text function."""
  # The accumulated prefix of the composition is used as the prompt.
  prefix = context.prefix
  return llm.generate_text(prefix, **kwargs)  # pytype: disable=wrong-arg-count
90 |
91 |
@composing.make_composable  # pytype: disable=wrong-arg-types
async def chat(
    context: composing.Context,
    **kwargs,
) -> content_lib.ChunkList:
  """Composable version of the llm.chat function."""
  reply = await llm.chat(context.to_messages(), **kwargs)  # pytype: disable=wrong-arg-count
  # Tag the reply with the MODEL role so it composes like a model turn.
  model_chunk = content_lib.Chunk(reply, role=content_lib.PredefinedRole.MODEL)
  return content_lib.ChunkList([model_chunk])
104 |
105 |
@composing.make_join_composable  # pytype: disable=wrong-arg-types
async def select(
    context: composing.Context,
    options: Sequence[tuple[str, composing.Context]],
) -> tuple[str, composing.Context]:
  """Composable select function.

  Args:
    context: Context of execution.
    options: Sequence of options. Each option is a pair (text, context).

  Returns:
    The pair (text, context) that is selected among the possible options. For
    example, this could be the one with the highest score.
  """
  texts = [option_text for option_text, _ in options]
  value, index, _ = await llm.select(  # pytype: disable=wrong-arg-count
      context.prefix, texts, include_details=True
  )
  # Pair the winning text with the context it originated from.
  _, selected_context = options[index]
  return value, selected_context
127 |
128 |
@composing.make_composable  # pytype: disable=wrong-arg-types
def generate_object(
    context: composing.Context, cls: type[Any], **kwargs
) -> ...:
  """Composable version of the llm.generate_object function."""
  # The accumulated prefix of the composition is used as the prompt.
  prefix = context.prefix
  return llm.generate_object(prefix, cls, **kwargs)  # pytype: disable=wrong-arg-count
135 |
136 |
@composing.make_composable  # pytype: disable=wrong-arg-types
async def instruct(
    context: composing.Context, assistant_prefix: str | None = None, **kwargs
) -> ...:
  """Composable version of the llm.instruct function."""
  # TODO: We are awaiting llm.instruct which means we cannot
  # iterate through the result if we are doing streaming. We should instead
  # wrap this into an ExecutableWithPostprocessing which adds the
  # assistant_prefix, or change the prefix directly and call instruct with
  # the unchanged prefix (requires to deep-copy the prefix).
  result = await llm.instruct(context.prefix, assistant_prefix, **kwargs)  # pytype: disable=wrong-arg-count
  if assistant_prefix is None:
    return result
  # Prepend the assistant prefix so the caller sees the full assistant turn.
  return assistant_prefix + result
152 |
--------------------------------------------------------------------------------
/onetwo/agents/critics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Sequence
16 | import dataclasses
17 | from typing import TypeAlias
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from onetwo.agents import agents_base
22 | from onetwo.agents import critics
23 | from onetwo.agents import distribution
24 | from onetwo.core import executing
25 |
26 |
# Shorthand aliases for scored string updates and their list-state container.
_SU: TypeAlias = agents_base.ScoredUpdate[str]
_ULS: TypeAlias = agents_base.UpdateListState[str, _SU]
29 |
30 |
class ScorerForTest(critics.ScoringFunction[str, float]):
  """Trivial test scoring function: the update value itself is the score."""

  @executing.make_executable(copy_self=False)
  async def __call__(self, state: str, update: float) -> float:
    """Assigns as a score the update value."""
    return update
36 | return update
37 |
38 |
class RankerForTest(critics.RankingFunction[str, float]):
  """Test ranking function: orders states by decreasing update value."""

  @executing.make_executable(copy_self=False)
  async def __call__(
      self, states_and_updates: Sequence[tuple[str, float]]
  ) -> Sequence[int]:
    """Returns the indices of the states sorted by the update value."""

    def update_value(indexed_pair):
      # indexed_pair is (original_index, (state, update)).
      return indexed_pair[1][1]

    ranked = sorted(
        enumerate(states_and_updates), key=update_value, reverse=True
    )
    return [index for index, _ in ranked]
50 |
51 |
class SelectorForTest(critics.SelectingFunction[str, float]):
  """Selecting function that picks the candidate with the largest update."""

  @executing.make_executable(copy_self=False)
  async def __call__(
      self, states_and_updates: Sequence[tuple[str, float]]
  ) -> int:
    """Returns the index of the (first) candidate with the maximal update."""
    best_index, _ = max(
        enumerate(states_and_updates), key=lambda item: item[1][1]
    )
    return best_index
60 |
61 |
@dataclasses.dataclass
class DistributionAgentForTest(
    distribution.DistributionAgent[
        str, str, str, agents_base.ScoredUpdate[str], None
    ]
):
  """Test agent whose next-step distribution is a fixed list of pairs.

  Attributes:
    distribution: Pairs of (update string, score) that are returned verbatim
      (wrapped as ScoredUpdate objects) from `get_next_step_distribution`.
  """

  distribution: list[tuple[str, float]] = dataclasses.field(
      default_factory=list
  )

  @executing.make_executable(copy_self=False, non_copied_args=['environment'])
  async def initialize_state(self, inputs: str) -> str:
    """Overridden from base class (Agent)."""
    return inputs

  def extract_output(self, state: str) -> str:
    """Overridden from base class (Agent)."""
    return state

  def is_finished(self, state: str) -> bool:
    """Overridden from base class (Agent)."""
    return bool(state)

  @executing.make_executable(copy_self=False)
  async def get_next_step_distribution(
      self, state: str, environment: None = None
  ) -> list[agents_base.ScoredUpdate[str]]:
    """Overridden from base class (DistributionAgent)."""
    scored_updates = []
    for update_value, update_score in self.distribution:
      scored_updates.append(
          agents_base.ScoredUpdate(update=update_value, score=update_score)
      )
    return scored_updates
94 |
95 |
class CriticsTest(parameterized.TestCase):
  """Tests for the critic conversion helpers and scoring functions."""

  def test_scorer_to_ranker(self):
    # A ranker derived from a scorer should produce the same ordering as the
    # hand-written ranker that sorts candidates by decreasing update value.
    scorer = ScorerForTest()
    converted_ranker = critics.ranker_from_scorer(scorer)
    ranker = RankerForTest()
    states_and_updates = [
        ('a', 1.0),
        ('b', 2.0),
        ('c', 3.0),
        ('d', 4.0),
        ('e', 5.0),
    ]
    res = executing.run(ranker(states_and_updates)) # pytype: disable=wrong-arg-count
    res2 = executing.run(converted_ranker(states_and_updates)) # pytype: disable=wrong-arg-count
    self.assertEqual(res, res2)
    self.assertEqual(res, [4, 3, 2, 1, 0])

  def test_selector_to_ranker(self):
    # A ranker derived from a selector should produce the same ordering as the
    # hand-written ranker.
    selector = SelectorForTest()
    converted_ranker = critics.ranker_from_selector(selector)
    ranker = RankerForTest()
    states_and_updates = [
        ('a', 1.0),
        ('b', 2.0),
        ('c', 3.0),
        ('d', 4.0),
        ('e', 5.0),
    ]
    res = executing.run(ranker(states_and_updates)) # pytype: disable=wrong-arg-count
    res2 = executing.run(converted_ranker(states_and_updates)) # pytype: disable=wrong-arg-count
    self.assertEqual(res, res2)
    self.assertEqual(res, [4, 3, 2, 1, 0])

  def test_score_from_update(self):
    """Use an agent with distribution to score its states."""
    distrib = [('a', 0.1), ('b', 0.3), ('c', 0.6)]
    dist_agent = DistributionAgentForTest(distribution=distrib)
    dist_map = {k: v for k, v in distrib}
    scoring_function = critics.ScoreFromUpdates()
    state = 'a'
    scores = []
    expected_scores = []

    async def wrapper():
      # Each sampled update should be scored with the probability it carries
      # in `distrib` (looked up via `dist_map`).
      nonlocal scores
      updates = await dist_agent.sample_next_step(state=state, num_candidates=3) # pytype: disable=wrong-keyword-args
      for update in updates:
        scores.append(
            await scoring_function(state, update)
        )
        expected_scores.append(dist_map[update.update])

    executing.run(wrapper())
    self.assertEqual(scores, expected_scores)

  def test_score_from_update_list(self):
    """Use an agent with distribution to score its states."""
    # Expected value 1.1 is consistent with the new update's score (0.1) added
    # to the scores already in the state's update list (0.1 + 0.3 + 0.6).
    scoring_function = critics.ScoreFromUpdateList()
    state = _ULS('', [_SU('a', 0.1), _SU('b', 0.3), _SU('c', 0.6)])
    update = _SU('d', 0.1)

    res = executing.run(scoring_function(state, update))
    self.assertEqual(res, 1.1)
160 |
161 |
# Allows running the tests in this file directly.
if __name__ == '__main__':
  absltest.main()
164 |
--------------------------------------------------------------------------------
/onetwo/agents/tasks/game_of_24_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from onetwo.agents import iterative_thought
17 | from onetwo.agents.tasks import game_of_24
18 | from onetwo.backends import backends_test_utils
19 | from onetwo.core import executing
20 |
21 |
# Default reply for LLMForTest to return when it receives a prompt that it was
# not expecting. The tests below assert on this value to verify that the LLM
# reply is propagated through unchanged.
DEFAULT_REPLY = 'UNKNOWN_PROMPT'
25 |
26 |
class GameOf24Test(absltest.TestCase):
  """Tests the Game of 24 prompt templates against a fake LLM backend."""

  def test_game_of_24_iterative_thought_prompt(self):
    # Some typical Game of 24 prompt inputs.
    description = game_of_24.GAME_OF_24_DESCRIPTION
    few_shots = game_of_24.GAME_OF_24_EXAMPLES
    state = iterative_thought.IterativeThoughtState(
        inputs='8 12 4 5', updates=['thought1']
    )
    prompt = iterative_thought.IterativeThoughtPromptJ2()

    # Now we define a test LLM to use in place of the actual LLM. To avoid the
    # need to hard-code here all of the expected requests and simulated replies,
    # however, we will just depend on a single default reply.
    llm_backend = backends_test_utils.LLMForTest(
        default_reply=DEFAULT_REPLY,
        default_score=0.0,
    )
    llm_backend.register()

    # Now we execute the prompt and verify that the prompt contained the
    # expected content. (Although we don't verify all of the prompt formatting,
    # these assertions should be sufficient to catch many basic bugs where we
    # omitted a for-loop, or failed to include some of the fields due to a typo,
    # etc.)
    next_step, result = executing.run(
        prompt(description=description, few_shots=few_shots, state=state),
        enable_tracing=True,
    )
    # The 'prefix' output of the first tracing stage holds the fully rendered
    # prompt text that was sent to the (fake) LLM.
    prefix = result.stages[0].outputs['prefix']

    with self.subTest('prompt_should_contain_the_task_description'):
      self.assertIn(description, prefix)

    with self.subTest('prompt_should_contain_the_exemplar_inputs'):
      self.assertIn(few_shots[0].inputs, prefix)
      self.assertIn(few_shots[-1].inputs, prefix)

    with self.subTest('prompt_should_contain_the_exemplar_steps_so_far'):
      self.assertIn(few_shots[0].updates[0], prefix)
      self.assertIn(few_shots[-1].updates[-1], prefix)

    with self.subTest('prompt_should_contain_the_actual_inputs'):
      self.assertIn(state.inputs, prefix)

    if state.updates:
      with self.subTest('prompt_should_contain_the_actual_steps_so_far'):
        self.assertIn(state.updates[0], prefix)
        self.assertIn(state.updates[-1], prefix)

    with self.subTest('should_return_the_llm_reply_as_next_step'):
      self.assertEqual(DEFAULT_REPLY, next_step)

  def test_game_of_24_iterative_thought_proposer_prompt(self):
    # Some typical Game of 24 propose prompt inputs.
    description = game_of_24.GAME_OF_24_DESCRIPTION
    few_shots = game_of_24.GAME_OF_24_PROPOSER_EXAMPLES
    state = iterative_thought.IterativeThoughtState(
        inputs='8 12 4 5', updates=['thought1']
    )
    prompt = iterative_thought.IterativeThoughtProposerPromptJ2()

    # Now we define a test LLM to use in place of the actual LLM. To avoid the
    # need to hard-code here all of the expected requests and simulated replies,
    # however, we will just depend on a single default reply.
    llm_backend = backends_test_utils.LLMForTest(
        default_reply='ta\ntb',
        default_score=0.0,
    )
    llm_backend.register()

    # The proposer prompt is expected to split the newline-separated LLM reply
    # into individual next steps.
    expected_next_steps = ['ta', 'tb']

    # Now we execute the prompt and verify that the prompt contained the
    # expected content. (Although we don't verify all of the prompt formatting,
    # these assertions should be sufficient to catch many basic bugs where we
    # omitted a for-loop, or failed to include some of the fields due to a typo,
    # etc.)
    next_steps, result = executing.run(
        prompt(description=description, few_shots=few_shots, state=state),
        enable_tracing=True,
    )
    prefix = result.stages[0].outputs['prefix']

    with self.subTest('prompt_should_contain_the_task_description'):
      self.assertIn(description, prefix)

    with self.subTest('prompt_should_contain_the_exemplar_inputs'):
      self.assertIn(few_shots[0].state.inputs, prefix)
      self.assertIn(few_shots[-1].state.inputs, prefix)

    with self.subTest('prompt_should_contain_the_exemplar_state'):
      if few_shots[0].state.updates:
        self.assertIn(few_shots[0].state.updates[0], prefix)
      if few_shots[-1].state.updates:
        self.assertIn(few_shots[-1].state.updates[-1], prefix)

    with self.subTest('prompt_should_contain_the_exemplar_next_steps'):
      self.assertIn(few_shots[0].next_steps[0], prefix)
      self.assertIn(few_shots[-1].next_steps[-1], prefix)

    with self.subTest('prompt_should_contain_the_actual_inputs'):
      self.assertIn(state.inputs, prefix)

    if state.updates:
      with self.subTest('prompt_should_contain_the_actual_steps_so_far'):
        self.assertIn(state.updates[0], prefix)
        self.assertIn(state.updates[-1], prefix)

    with self.subTest('should_return_the_parsed_llm_reply_as_next_steps'):
      self.assertEqual(expected_next_steps, next_steps)
138 |
# Allows running the tests in this file directly.
if __name__ == '__main__':
  absltest.main()
141 |
--------------------------------------------------------------------------------
/onetwo/stdlib/code_execution/python_execution_test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for OneTwo unit tests involving Python execution."""
16 |
17 | import collections
18 | from collections.abc import Mapping, Sequence
19 | import dataclasses
20 | import datetime
21 | from typing import Any, Callable
22 | import unittest
23 |
24 | from onetwo.core import executing
25 | from onetwo.stdlib.code_execution import python_execution
26 |
27 | # Aliases for brevity.
28 | _SandboxResult = python_execution.SandboxResult
29 |
30 |
class SandboxResultAssertions(unittest.TestCase):
  """Mixin providing custom assertions over SandboxResult objects."""

  # pylint: disable=invalid-name
  def assertSandboxResultEqualIgnoringTiming(
      self,
      expected_result: _SandboxResult,
      actual_result: _SandboxResult,
  ) -> None:
    """Asserts that two SandboxResults are equal, ignoring timing content."""
    # Compare copies whose timing fields have been blanked out, so that only
    # the non-timing content participates in the equality check.
    stripped = [
        dataclasses.replace(result, timing=None)
        for result in (expected_result, actual_result)
    ]
    failure_message = (
        'SandboxResult differed by more than just'
        f' timing.\nExpected:\n{expected_result!r}\nActual:\n{actual_result!r}'
    )
    self.assertEqual(stripped[0], stripped[1], failure_message)
49 |
50 |
@dataclasses.dataclass
class PythonSandboxForTest(python_execution.PythonSandbox):
  """Mock PythonSandbox.

  Attributes:
    hook_objects: Objects containing modifiable state of the hook functions.
      (Required as part of the PythonSandbox interface.)
    reply_by_request: Mapping from request to reply, or to a sequence of replies
      in case we want to return different results on the 1st call vs. the 2nd
      call, etc.
    default_reply: Default reply if not found in reply_by_request.
    requests: All `run` requests that were received, in the order received.
      (Same format as `unexpected_requests` below.)
    unexpected_requests: Requests that were not found in the corresponding
      mappings (i.e., requests for which we ended up falling back to returning
      the `default_reply`).
  """

  hook_objects: Mapping[str, Any] = dataclasses.field(default_factory=dict)

  # Attributes for controlling the replies to be returned.
  reply_by_request: Mapping[str, _SandboxResult | Sequence[_SandboxResult]] = (
      dataclasses.field(default_factory=dict)
  )
  default_reply: _SandboxResult = dataclasses.field(
      default_factory=_SandboxResult
  )

  # Attributes used for tracking the actual requests / replies (for assertions).
  requests: list[str] = dataclasses.field(init=False, default_factory=list)
  unexpected_requests: list[str] = dataclasses.field(
      init=False, default_factory=list
  )

  # Attributes used for tracking the actual requests / replies (internal).
  # Counts how many times `run` has been called with each distinct request,
  # so that we can step through a configured sequence of replies.
  _num_run_calls_by_request: collections.Counter[str] = (
      dataclasses.field(init=False, default_factory=collections.Counter)
  )

  def is_stateful(self) -> bool:
    """See base class (PythonSandbox)."""
    return True

  @executing.make_executable # pytype: disable=wrong-arg-types
  def run(self, code: str) -> _SandboxResult:
    """See base class (PythonSandbox)."""
    self.requests.append(code)

    # By request.
    if code in self.reply_by_request:
      reply = self.reply_by_request[code]
      # Bug fix: the values in `reply_by_request` are `_SandboxResult`s (or
      # sequences thereof), never strings, so the previous
      # `isinstance(reply, str)` check could never detect the single-reply
      # case and a lone `_SandboxResult` would have been treated as a
      # sequence and indexed into.
      if isinstance(reply, _SandboxResult):
        # Single reply specified. Always return it.
        return reply
      else:
        # Sequence of replies specified. Return the next (until we run out).
        reply_index = self._num_run_calls_by_request[code]
        self._num_run_calls_by_request[code] += 1
        if reply_index < len(reply):
          return reply[reply_index]

    # Default.
    self.unexpected_requests.append(code)
    return self.default_reply

  async def set_variables(self, **variables: Any) -> None:
    """See base class (PythonSandbox)."""
    raise NotImplementedError()

  async def get_variables(self, *names: str) -> Mapping[str, Any]:
    """See base class (PythonSandbox)."""
    raise NotImplementedError()

  def get_hook_object(self, key: str) -> Any:
    """See base class (PythonSandbox)."""
    return self.hook_objects.get(key)
127 |
128 |
@dataclasses.dataclass
class PythonSandboxForTestFactory(python_execution.PythonSandboxFactory):
  """Mock PythonSandboxFactory that always hands out one fixed test sandbox.

  Attributes:
    default_sandbox: The sandbox returned by every call to `create_sandbox`.
      Its `hook_objects` member is overwritten on each call (mirroring the
      behavior that `PythonPlanningAgent` expects from a real factory). This
      suffices for tests that only need one sandbox; tests that expect several
      distinct sandboxes would need finer-grained control over what gets
      returned.
  """

  default_sandbox: PythonSandboxForTest = dataclasses.field(
      default_factory=PythonSandboxForTest
  )

  def create_sandbox(
      self,
      *,
      timeout: datetime.timedelta = datetime.timedelta(seconds=10),
      imports: Sequence[str] | str = tuple(),
      hooks: Mapping[str, Callable[..., Any]] | None = None,
      hook_objects: Mapping[str, Any] | None = None,
      allow_restarts: bool = False,
  ) -> PythonSandboxForTest:
    """Returns the shared default sandbox, refreshing its hook objects."""
    # More elaborate behavior (e.g. varying the returned sandbox based on the
    # arguments or the number of prior calls) could be added here if a test
    # required it; for now every call reuses the single `default_sandbox`.
    sandbox = self.default_sandbox
    sandbox.hook_objects = hook_objects if hook_objects else {}
    return sandbox
163 |
--------------------------------------------------------------------------------
/onetwo/builtins/prompt_templating.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library for executing a prompt template."""
16 |
17 | from collections.abc import Callable, Mapping
18 | import copy
19 | import dataclasses
20 | import functools
21 | from typing import Any, TypeVar
22 |
23 | import immutabledict
24 | from onetwo.builtins import callbacks
25 | from onetwo.core import executing
26 | from onetwo.core import templating
27 |
28 |
# Output-variable names re-exported from templating: `_DRY_RUN_PREFIX_VAR` is
# the key under which dry_run prefixes accumulate, and `PROMPT_PREFIX` is the
# key holding the final rendered prompt.
_DRY_RUN_PREFIX_VAR = templating._DRY_RUN_PREFIX_VAR # pylint: disable=protected-access
PROMPT_PREFIX = templating.PROMPT_PREFIX

_T = TypeVar('_T')
# Immutable empty mapping used as a safe default value for mapping-typed
# keyword arguments.
_EMPTY_MAP = immutabledict.immutabledict({})
34 |
35 |
@dataclasses.dataclass
class JinjaTemplateWithCallbacks(templating.JinjaTemplate):
  """A Jinja2 template augmented with additional LLM-specific callbacks."""

  def __post_init__(self):
    """See parent class."""
    super().__post_init__()
    self.register_callback('llm', callbacks.llm_callback, pass_context=True)
    # Also registering the llm callback as 'generate_text' for compatibility
    # with builtins and composables.
    # TODO: deprecate the llm version and make generate_text support
    # dry run.
    self.register_callback(
        'generate_text', callbacks.generate_text, pass_context=True
    )
    self.register_callback(
        'generate_texts', callbacks.generate_texts, pass_context=True
    )
    self.register_callback('choose', callbacks.choose, pass_context=True)
    self.register_callback(
        'generate_object', callbacks.generate_object, pass_context=True
    )

  def _postprocess_iterable_reply(self, iterable_reply: Any) -> str:
    """Extracts the text part from a streamed (reply, details)-style tuple."""
    if isinstance(iterable_reply, tuple) and len(iterable_reply) == 2:
      # This is the case where we got a reply with details. We just use the
      # first part.
      # Note that this is specific to the generate_text reply so it may
      # not work for arbitrary streaming callbacks.
      return iterable_reply[0]
    else:
      return super()._postprocess_iterable_reply(iterable_reply)

  @executing.make_executable # pytype: disable=wrong-arg-types
  async def dry_run(
      self,
      inputs: Mapping[str, Any],
      llm_default_reply: str = 'default',
      generate_object_default_reply_map: Mapping[type[_T], _T] = _EMPTY_MAP,
      mock_callbacks: (
          list[tuple[str, Callable[[Any], Any], bool]] | None
      ) = None,
  ) -> Mapping[str, Any]:
    """Dry runs the prompt and returns prefixes sent to the LLM.

    This method renders the jinja2 prompt and gets the prefixes that
    are sent to the language model without actually executing the LLM requests.
    By default this method mocks the `llm` and `choose` operations.
    The mocked `llm` simply returns the `llm_default_reply` provided by the
    user, while the mocked `choose` chooses the first `top_k` provided options
    and populates scores with all `0.0`.

    The user can also mock other callbacks. This can be useful when there is a
    callback that sends LLM requests within itself, eg. by defining and
    executing its own PromptTemplate instance. In this case the callback
    functions provided by the user can store prefixes in the
    `_DRY_RUN_PREFIX_VAR` output variable.

    Args:
      inputs: Inputs to the prompt template.
      llm_default_reply: String that is used as a reply from `llm` operation.
      generate_object_default_reply_map: map from type to instance for default
        values calling generate_object
      mock_callbacks: List of (callback_name, callback_fn, pass_context) for
        additional callbacks to be mocked.

    Returns:
      A dict with at least one key `PROMPT_PREFIX` that contains the final
      rendered string of the entire prompt. If the template has `llm()` calls
      the dict contains an `llm` key that stores a list of string-valued
      rendered prefixes sent to the `llm()` operation. If template has
      `choose()` calls the dict contains a `choose` key that stores a list of
      string-valued rendered prefixes sent to the `choose()` operation.
      The output may contain other keys if the user provides additional
      mock callbacks that write to the `_DRY_RUN_PREFIX_VAR` output variable.
    """
    # To mock some of the callbacks we modify the `_callbacks` attribute
    # of the PromptTemplateJ2 instance. To avoid subtle issues (for example:
    # parallel async execution of multiple dry_run coroutines of the same
    # PromptTemplateJ2 instance) we create a copy of the instance.
    mock_prompt = copy.deepcopy(self)
    # By default we mock `choose` and `llm` operations.
    mock_prompt.register_callback(
        'llm',
        functools.partial(
            callbacks.mock_llm_callback, default_reply=llm_default_reply
        ),
        pass_context=True,
    )
    mock_prompt.register_callback(
        'choose', callbacks.mock_choose_callback, pass_context=True
    )
    mock_prompt.register_callback(
        'generate_object',
        functools.partial(
            callbacks.mock_generate_object_callback,
            default_type_to_object_map=generate_object_default_reply_map,
        ),
        pass_context=True,
    )
    # We also mock any other callbacks provided by the user.
    if mock_callbacks is not None:
      for callback_name, callback_fn, callback_pass_context in mock_callbacks:
        mock_prompt.register_callback(
            name=callback_name,
            function=callback_fn,
            pass_context=callback_pass_context,
        )
    # Rendering with the mocked callbacks populates `_DRY_RUN_PREFIX_VAR`
    # instead of issuing real LLM requests.
    outputs = await mock_prompt.render(**inputs)
    try:
      result = outputs[_DRY_RUN_PREFIX_VAR]
    except KeyError as exc:
      raise ValueError(
          'No DRY_RUN found in outputs. Please make sure that the template has'
          ' a llm() call.'
      ) from exc
    result[PROMPT_PREFIX] = outputs[PROMPT_PREFIX]
    return result
154 |
--------------------------------------------------------------------------------
/onetwo/agents/distribution_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Sequence
16 | import dataclasses
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from onetwo.agents import agents_base
21 | from onetwo.agents import distribution
22 | from onetwo.core import executing
23 |
24 |
# State type used by the test agents below: the original inputs plus the list
# of string updates accumulated so far.
StringAgentState = agents_base.UpdateListState[str, str]
26 |
27 |
@dataclasses.dataclass
class StringDistributionAgent(
    distribution.DistributionAgent[
        str, str, StringAgentState, agents_base.ScoredUpdate[str], None
    ]
):
  """Simple test agent, whose input / updates are hard-coded strings.

  The update strings are of the form '1', '2', '3', etc.

  Its output is a concatenation of the update strings, separated by space.

  Attributes:
    end_of_sequence: Update value that marks the sequence as finished (see
      `is_finished`). NOTE(review): declared as int, but it is compared
      against string updates, so with the default of 0 the comparison never
      succeeds — confirm whether this was meant to be a str (e.g. '0').
    vocab_size: Number of distinct update strings ('0' ... 'vocab_size - 1')
      over which the uniform fallback distribution is defined.
    distribution: Optional explicit (update, score) pairs. When non-empty,
      these are returned verbatim from `get_next_step_distribution` instead of
      the uniform distribution.
  """

  end_of_sequence: int = 0
  vocab_size: int = 10
  distribution: list[tuple[str, float]] = dataclasses.field(
      default_factory=list
  )

  @executing.make_executable(copy_self=False, non_copied_args=['environment'])
  async def initialize_state(
      self, inputs: str, environment: None = None
  ) -> StringAgentState:
    """Overridden from base class (Agent)."""
    return StringAgentState(inputs=inputs)

  def extract_output(self, state: StringAgentState) -> str:
    """Overridden from base class (Agent)."""
    return ' '.join(state.updates)

  def is_finished(self, state: StringAgentState) -> bool:
    """Overridden from base class (Agent)."""
    if not state.updates:
      return False
    # NOTE(review): `state.updates[-1]` is a str while `end_of_sequence`
    # defaults to an int, so this is always False unless the caller sets
    # `end_of_sequence` to a string — confirm intent.
    return state.updates[-1] == self.end_of_sequence

  @executing.make_executable(copy_self=False)
  async def get_next_step_distribution(
      self, state: StringAgentState, environment: None = None
  ) -> list[agents_base.ScoredUpdate[str]]:
    """Overridden from base class (DistributionAgent)."""
    if self.distribution:
      return [
          agents_base.ScoredUpdate(update=d[0], score=d[1])
          for d in self.distribution
      ]
    else:
      # Return a uniform distribution over the vocabulary.
      return [
          agents_base.ScoredUpdate(update=str(i), score=1.0 / self.vocab_size)
          for i in range(self.vocab_size)
      ]
83 |
84 |
def _scored_updates_to_tuples(
    scored_updates: Sequence[agents_base.ScoredUpdate[str]],
) -> list[tuple[str, float]]:
  """Flattens ScoredUpdate objects into plain (update, score) pairs."""
  pairs = []
  for scored_update in scored_updates:
    pairs.append((scored_update.update, scored_update.score))
  return pairs
89 |
90 |
class DistributionAgentTest(parameterized.TestCase):
  """Tests for the distribution and default sampling of DistributionAgent."""

  def test_get_next_step_distribution(self):
    # With no explicit distribution, the agent falls back to a uniform
    # distribution over the tokens '0' ... '9'.
    agent = StringDistributionAgent(vocab_size=10)

    with self.subTest('distribution_is_uniform'):
      dist = executing.run(
          agent.get_next_step_distribution( # pytype: disable=wrong-keyword-args
              state=agent.initialize_state(inputs='1') # pytype: disable=wrong-keyword-args
          )
      )
      self.assertSequenceEqual(
          [(str(i), 0.1) for i in range(10)], _scored_updates_to_tuples(dist)
      )

    with self.subTest('samples_are_correct'):
      samples = executing.run(
          agent.sample_next_step( # pytype: disable=wrong-keyword-args
              state=agent.initialize_state(inputs='1'), num_candidates=5 # pytype: disable=wrong-keyword-args
          )
      )
      samples = [s.update for s in samples]
      self.assertContainsSubset(set(samples), set(str(i) for i in range(10)))

    with self.subTest('samples_and_scores_are_correct'):
      samples_with_scores = executing.run(
          agent.sample_next_step( # pytype: disable=wrong-keyword-args
              state=agent.initialize_state(inputs='1'), num_candidates=5 # pytype: disable=wrong-keyword-args
          )
      )
      # Every sampled update should carry the uniform probability 0.1.
      self.assertContainsSubset(
          set(_scored_updates_to_tuples(samples_with_scores)),
          set((str(i), 0.1) for i in range(10)),
      )
125 |
126 |
class ReweightedDistributionAgentTest(parameterized.TestCase):
  """Tests temperature / top_p / top_k reweighting of a base distribution."""

  # Each case is (temperature, top_p, top_k, expected reweighted distribution),
  # applied to the base distribution [('1', 0.5), ('2', 0.2), ('3', 0.3)].
  @parameterized.named_parameters(
      ('top_k', None, None, 2, [('1', 0.62), ('2', 0.0), ('3', 0.38)]),
      ('top_p', None, 0.1, None, [('1', 1.0), ('2', 0.0), ('3', 0.0)]),
      ('top_k_top_p', None, 0.1, 2, [('1', 1.0), ('2', 0.0), ('3', 0.0)]),
      ('temp_0', 0.0, None, None, [('1', 1.0), ('2', 0.0), ('3', 0.0)]),
      ('temp_1', 1.0, None, None, [('1', 0.5), ('2', 0.2), ('3', 0.3)]),
      ('temp_5', 5.0, None, None, [('1', 0.37), ('2', 0.3), ('3', 0.33)]),
      ('temp_100', 100.0, None, None, [('1', 0.33), ('2', 0.33), ('3', 0.33)]),
      ('temp_1_top_k', 1.0, None, 2, [('1', 0.63), ('2', 0.0), ('3', 0.38)]),
      ('temp_0_top_p', 0.0, 0.5, None, [('1', 1.0), ('2', 0.0), ('3', 0.0)]),
      (
          'temp_100_top_p',
          100.0,
          0.5,
          None,
          [('1', 0.5), ('2', 0.0), ('3', 0.5)],
      ),
  )
  def test_get_distribution(
      self,
      temperature: float,
      top_p: float,
      top_k: int,
      expected_distribution: list[tuple[str, float]],
  ):
    inner_agent = StringDistributionAgent(
        distribution=[('1', 0.5), ('2', 0.2), ('3', 0.3)]
    )
    outer_agent = distribution.ReweightedDistributionAgent(
        inner_agent=inner_agent,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    state = outer_agent.initialize_state('test') # pytype: disable=wrong-arg-count
    result = executing.run(outer_agent.get_next_step_distribution(state=state)) # pytype: disable=wrong-keyword-args
    # We round off the distribution to make the comparison easier.
    result = [(s.update, round(s.score, 2)) for s in result]
    # Also to avoid issues when comparing floats, we convert everything to
    # strings.
    self.assertEqual(str(expected_distribution), str(result))
170 |
171 |
# Allows running the tests in this file directly.
if __name__ == '__main__':
  absltest.main()
174 |
--------------------------------------------------------------------------------
/onetwo/core/sampling_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pprint
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from onetwo.core import caching
20 | from onetwo.core import executing
21 | from onetwo.core import sampling
22 |
23 |
@executing.make_executable # pytype: disable=wrong-arg-types
async def process(request: str, suffix: str = '') -> tuple[str, str]:
  """Returns the request (plus suffix) along with the current sampling key."""
  sampling_key = caching.context_sampling_key.get()
  combined = request + suffix
  return (combined, sampling_key)
28 |
29 |
class SamplingTest(parameterized.TestCase):
  """Tests for the sampling helpers (repeat, repeat_and_execute, samplers)."""

  def test_repeat(self):
    executable = process('test')
    executables = sampling.repeat(executable, 3)

    res = executing.run(executing.par_iter(executables))

    # The first repetition keeps the default empty sampling key; the following
    # ones get keys '1' and '2'.
    with self.subTest('different_sampling_keys'):
      self.assertEqual(res, [('test', ''), ('test', '1'), ('test', '2')])

  def test_repeat_multiple_times(self):
    executable = process('test')
    executables = sampling.repeat(executable, 3)
    res = list(executing.run(executing.par_iter(executables)))
    # `start_index` continues the sampling-key numbering across separate runs.
    executables = sampling.repeat(executable, 3, start_index=3)
    res += executing.run(executing.par_iter(executables))
    executables = sampling.repeat(executable, 3, start_index=6)
    res += executing.run(executing.par_iter(executables))
    res_keys = [key for _, key in res]

    # Keys are '' then '1'..'8' across the three consecutive batches.
    with self.subTest('different_sampling_keys'):
      self.assertEqual(res_keys, [''] + [str(i) for i in range(1, 9)])

  def test_nested_repeat(self):
    """We test the nesting of repeats, together with the dynamic keys."""
    # Whether we create the repeated executable before executing them (i.e.
    # statically) or while executing them (i.e. dynamically), the effect on
    # sampling keys should be the same.

    # Static creation
    static_executable = executing.serial(
        sampling.repeat_and_execute(process('test1'), 3),
        sampling.repeat_and_execute(process('test2'), 3),
    )

    # Dynamic creation
    @executing.make_executable  # pytype: disable=wrong-arg-types
    async def dynamic_executable():
      return await executing.serial(
          sampling.repeat_and_execute(process('test1'), 3),
          sampling.repeat_and_execute(process('test2'), 3),
      )

    executable1 = sampling.repeat_and_execute(static_executable, 2)
    executable2 = sampling.repeat_and_execute(dynamic_executable(), 2)
    res1 = executing.run(executable1)
    res2 = executing.run(executable2)

    # Nested repeats compose keys with '#' (outer sample '1' yields inner keys
    # '1#0', '1#1', ...), identically for static and dynamic creation.
    expected_results = [
        [
            [('test1', ''), ('test1', '1'), ('test1', '2')],
            [('test2', ''), ('test2', '1'), ('test2', '2')],
        ],
        [
            [('test1', '1#0'), ('test1', '1#1'), ('test1', '1#2')],
            [('test2', '1#0'), ('test2', '1#1'), ('test2', '1#2')],
        ],
    ]

    with self.subTest('static'):
      self.assertEqual(
          res1,
          expected_results,
          msg=pprint.pformat(res1),
      )

    with self.subTest('dynamic'):
      self.assertEqual(
          res2,
          expected_results,
          msg=pprint.pformat(res2),
      )

  def test_update(self):
    sample_size = 2

    # `update_result_fn` receives each sample's result plus its sample_id and
    # can rewrite the per-sample result.
    def update_result(result, sample_id):
      return result[0], result[1], sample_id, sample_size

    executable = sampling.repeat_and_execute(
        process('test'), sample_size, update_result_fn=update_result
    )
    res = executing.run(executable)
    with self.subTest('should_update_results'):
      for i in range(sample_size):
        self.assertEqual(res[i], ('test', str(i) if i else '', i, sample_size))

  def test_streaming(self):
    executable = executing.par_iter(sampling.repeat(process('test'), 3))

    # Accumulate the streamed Update objects into a final result via sum().
    with executing.safe_stream(executable) as iterator:
      results = sum(iterator, start=executing.Update()).to_result()

    self.assertEqual(
        results,
        [('test', ''), ('test', '1'), ('test', '2')],
    )

  def test_repeat_sampler(self):
    sampler = sampling.Repeated(process)
    # Note that we pass in both a positional and a keyword argument to verify
    # that the sampler is correctly passing them through.
    res = executing.run(sampler('test', suffix='-a', num_samples=3))  # pytype: disable=wrong-arg-count
    expected = [('test-a', ''), ('test-a', '1'), ('test-a', '2')]
    self.assertEqual(expected, res, res)

  def test_round_robin_sampler(self):
    @executing.make_executable  # pytype: disable=wrong-arg-types
    async def strategy1(request, **kwargs):
      return await process(f'{request}-1', **kwargs)

    @executing.make_executable  # pytype: disable=wrong-arg-types
    async def strategy2(request, **kwargs):
      return await process(f'{request}-2', **kwargs)

    # RoundRobin alternates between the two wrapped samplers.
    sampler = sampling.RoundRobin([
        sampling.Repeated(strategy1),
        sampling.Repeated(strategy2),
    ])

    # Note that we pass in both a positional and a keyword argument to verify
    # that the sampler is correctly passing them through.
    res = executing.run(sampler('test', suffix='-a', num_samples=5))  # pytype: disable=wrong-arg-count
    expected = [
        ('test-1-a', ''),
        ('test-2-a', '1'),
        ('test-1-a', '2'),
        ('test-2-a', '3'),
        ('test-1-a', '4'),
    ]
    with self.subTest('default_start_index_starts_at_0_with_first_sampler'):
      self.assertEqual(expected, res, res)

    # This time we omit the suffix kwarg for simplicity.
    res = executing.run(sampler(request='test', num_samples=5, start_index=1))  # pytype: disable=wrong-keyword-args
    expected = [
        ('test-2', '1'),
        ('test-1', '2'),
        ('test-2', '3'),
        ('test-1', '4'),
        ('test-2', '5'),
    ]
    with self.subTest('start_index_1_starts_at_1_with_second_sampler'):
      self.assertEqual(expected, res, res)
175 |
176 |
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()
179 |
--------------------------------------------------------------------------------
/onetwo/agents/tasks/game_of_24.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data config for applying IterativeThought to the Game of 24 problem."""
16 |
17 | import textwrap
18 |
19 | from onetwo.agents import iterative_thought
20 |
21 |
# Natural-language task description for the Game of 24; textwrap.dedent strips
# the common leading indentation of the triple-quoted literal below.
GAME_OF_24_DESCRIPTION = textwrap.dedent("""\
    Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
    Each number can be only used once, but results can be used in subsequent
    steps.
""")
27 |
28 |
# Exemplars for use with IterativeThoughtAgent.
# Each exemplar is a complete trajectory: `inputs` holds the four starting
# numbers and `updates` holds one reasoning step per entry. The step strings
# themselves state whether the trajectory ended in success (reaching 24) or
# failure.
GAME_OF_24_EXAMPLES = [
    # Successful trajectory for '10 5 2 11'.
    iterative_thought.IterativeThoughtState(
        inputs='10 5 2 11',
        updates=[
            '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)',
            '10 5 22, we try 10 / 5 = 2 (remaining: 22 2)',
            '22 2, we try 22 + 2 = 24 (remaining: 24, success: we got 24)',
        ],
    ),
    # Failed trajectory for '10 5 2 11' (ends at 27).
    iterative_thought.IterativeThoughtState(
        inputs='10 5 2 11',
        updates=[
            '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)',
            '10 5 22, we try 22 - 5 = 17 (remaining: 10 17)',
            '10 17, we try 10 + 17 = 27 '
            + '(remaining: 27, failure: we did not get 24)',
        ],
    ),
    # Failed trajectory for '10 5 2 11' (ends at 15).
    iterative_thought.IterativeThoughtState(
        inputs='10 5 2 11',
        updates=[
            '10 5 2 11, we try 11 - 10 = 1 (remaining: 5 2 1)',
            '5 2 1, we try 2 + 1 = 3 (remaining: 5 3)',
            '5 3, we try 5 * 3 = 15 '
            + '(remaining: 15, failure: we did not get 24)',
        ],
    ),
    # Successful trajectory for '2 6 5 3'.
    iterative_thought.IterativeThoughtState(
        inputs='2 6 5 3',
        updates=[
            '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)',
            '2 6 2, we try 2 * 6 = 12 (remaining: 2 12)',
            '2 12, we try 2 * 12 = 24 (remaining: 24, success: we got 24)',
        ],
    ),
    # Failed trajectory for '2 6 5 3' (ends at 17).
    iterative_thought.IterativeThoughtState(
        inputs='2 6 5 3',
        updates=[
            '2 6 5 3, we try 3 * 2 = 6 (remaining: 5 6 6)',
            '5 6 6, we try 6 + 6 = 12 (remaining: 5 12)',
            '5 12, we try 12 + 5 = 17 '
            + '(remaining: 17, failure: we did not get 24)',
        ],
    ),
    # Failed trajectory for '2 6 5 3' (ends at 6).
    iterative_thought.IterativeThoughtState(
        inputs='2 6 5 3',
        updates=[
            '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)',
            '2 6 2, we try 6 / 2 = 3 (remaining: 2 3)',
            '2 3, we try 2 * 3 = 6 '
            + '(remaining: 6, failure: we did not get 24)',
        ],
    ),
    # Failed trajectory for '2 6 5 3' (ends at 20).
    iterative_thought.IterativeThoughtState(
        inputs='2 6 5 3',
        updates=[
            '2 6 5 3, we try 2 * 5 = 10 (remaining: 6 3 10)',
            '6 3 10, we try 6 * 10 = 60 (remaining: 3 60)',
            '3 60, we try 60 / 3 = 20 '
            + '(remaining: 20, failure: we did not get 24)',
        ],
    ),
]
93 |
94 |
# Exemplars for use with IterativeThoughtProposerAgent.
# Each exemplar pairs a (possibly partial) trajectory state with the list of
# candidate `next_steps` a proposer is expected to suggest from that state.
GAME_OF_24_PROPOSER_EXAMPLES = [
    # From the initial state of '10 5 2 11': propose first moves.
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='10 5 2 11',
            updates=[],
        ),
        next_steps=[
            '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)',
            '10 5 2 11, we try 11 - 10 = 1 (remaining: 5 2 1)',
        ],
    ),
    # After one step (2 * 11 = 22): propose continuations.
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='10 5 2 11',
            updates=[
                '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)',
            ],
        ),
        next_steps=[
            '10 5 22, we try 10 / 5 = 2 (remaining: 22 2)',
            '10 5 22, we try 22 - 5 = 17 (remaining: 10 17)',
        ],
    ),
    # Two steps in: the proposed step completes the game successfully.
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='10 5 2 11',
            updates=[
                '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)',
                '10 5 22, we try 10 / 5 = 2 (remaining: 22 2)',
            ],
        ),
        next_steps=[
            '22 2, we try 22 + 2 = 24 (remaining: 24, success: we got 24)',
        ],
    ),
    # Two steps in: the proposed step ends in failure (27, not 24).
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='10 5 2 11',
            updates=[
                '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)',
                '10 5 22, we try 22 - 5 = 17 (remaining: 10 17)',
            ],
        ),
        next_steps=[
            '10 17, we try 10 + 17 = 27 '
            + '(remaining: 27, failure: we did not get 24)',
        ],
    ),
    # From the initial state of '2 6 5 3': propose first moves.
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='2 6 5 3',
            updates=[],
        ),
        next_steps=[
            '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)',
            '2 6 5 3, we try 3 * 2 = 6 (remaining: 5 6 6)',
            '2 6 5 3, we try 2 * 5 = 10 (remaining: 6 3 10)',
        ],
    ),
    # After one step (5 - 3 = 2): propose continuations.
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='2 6 5 3',
            updates=[
                '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)',
            ],
        ),
        next_steps=[
            '2 6 2, we try 6 / 2 = 3 (remaining: 2 3)',
            '2 6 2, we try 2 * 6 = 12 (remaining: 2 12)',
        ],
    ),
    # Two steps in: the proposed step completes the game successfully.
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='2 6 5 3',
            updates=[
                '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)',
                '2 6 2, we try 2 * 6 = 12 (remaining: 2 12)',
            ],
        ),
        next_steps=[
            '2 12, we try 2 * 12 = 24 (remaining: 24, success: we got 24)',
        ],
    ),
    # Two steps in: the proposed step ends in failure (6, not 24).
    iterative_thought.IterativeThoughtProposerExemplar(
        iterative_thought.IterativeThoughtState(
            inputs='2 6 5 3',
            updates=[
                '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)',
                '2 6 2, we try 6 / 2 = 3 (remaining: 2 3)',
            ],
        ),
        next_steps=[
            '2 3, we try 2 * 3 = 6 (remaining: 6, failure: we did not get 24)',
        ],
    ),
]
192 |
--------------------------------------------------------------------------------
/onetwo/core/updating.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library for objects representing updates of results.
16 |
17 | When executing a long computation, we may want to monitor its progress. The
18 | Update class and its subclasses allow one to receive periodic updates on the
19 | progress and compose these updates into a result representing everything done so
20 | far.
21 |
22 | Assuming we have an iterator over updates of an appropriate type (subclass of
23 | `Update` class with relevant implementation of `__add__` and `to_result`), we
24 | can compose a result in the following way:
25 |
26 | ```
27 | updates = Update()
28 | for update in process_yielding_updates:
  updates = updates + update  # `update` of type MyUpdate.
30 | print('Current result: ', updates.to_result())
31 |
32 | print('Final result', updates.to_result())
33 | ```
34 |
35 | Note that the + operator is overloaded, so one can also use this syntax:
36 |
37 | ```
38 | updates = Update()
39 | for update in process_yielding_updates:
40 | updates += update
41 | print('Current result: ', updates.to_result())
42 | ```
43 | Or one can use the `sum` operator as well:
44 |
45 | ```
46 | final_result = sum(process_yielding_updates, start=Update()).to_result()
47 | ```
48 | """
49 |
50 | from __future__ import annotations
51 |
52 | import dataclasses
53 | from typing import Generic, Protocol, TypeVar, cast
54 |
55 |
# Generic payload/result type used by the Update classes below.
_T = TypeVar('_T')


class _Addable(Protocol):
  """Protocol for objects that can be added together with a + operator."""

  def __add__(self: _T, other: _T) -> _T: ...


# Payload type for AddableUpdate: anything that supports the + operator.
_Tadd = TypeVar('_Tadd', bound=_Addable)
66 |
67 |
@dataclasses.dataclass
class Update(Generic[_T]):
  """Updates that can be accumulated and used to get the result.

  Basic implementation where payload and result are of the same type.
  Accumulation follows "last one wins" semantics: adding an update replaces
  everything accumulated before it.

  Attributes:
    payload: Content of the update (None until a first update is received).
  """

  payload: _T | None = None

  def __add__(self, other: Update[_T] | _T) -> Update[_T]:
    """Incorporate this update. By default we overwrite all previous updates."""
    if isinstance(other, Update):
      return other
    else:
      # Raw values get wrapped so the accumulator stays an Update.
      return Update(other)

  def to_result(self) -> _T | None:
    """Produce a final result from the accumulated updates."""
    if isinstance(self.payload, Update):
      # Payloads may themselves be Updates; unwrap recursively.
      return cast(Update[_T], self.payload).to_result()
    else:
      return self.payload

  def to_simplified_result(self) -> _T | None:
    """Produces a simplified result from the accumulated updates."""
    if isinstance(self.payload, Update):
      return cast(Update[_T], self.payload).to_simplified_result()
    else:
      # May be None if no payload was ever set.
      return self.payload
100 |
101 |
@dataclasses.dataclass
class AddableUpdate(Generic[_Tadd], Update[_Tadd]):
  """Update class for adding objects together with a + operator.

  This could be used for lists, but unlike the ListUpdate the order is not
  necessarily preserved: if several processes yield such updates in parallel
  there is no guarantee that the final list will retain the order of the
  processes for example.

  Attributes:
    payload: Accumulated content, or None if nothing has been accumulated yet.
  """

  payload: _Tadd | None = None

  def __add__(
      self, other: AddableUpdate[_Tadd] | _Tadd
  ) -> AddableUpdate[_Tadd]:
    """See base class."""
    # Unwrap the incoming value (it may be raw or wrapped in an AddableUpdate).
    delta = other.payload if isinstance(other, AddableUpdate) else other
    if self.payload is None:
      # Nothing accumulated yet: adopt the incoming payload instead of
      # attempting `None + delta`, which would raise a TypeError for a
      # freshly-constructed AddableUpdate().
      self.payload = delta
    else:
      self.payload += delta
    return self
123 |
124 |
@dataclasses.dataclass
class ListUpdate(Generic[_T], Update[_T]):
  """Update class for maintaining a list.

  End result is of type list[T]. Every intermediate update (payload) provides
  a list of values and indices where these values need to go in the final
  result.

  For instance, ListUpdate([('ab', 10), (Update('b'), 23)]) tells us that final
  result is of type list['str'] and that elements in the final result with
  indices 10 and 23 should be set to 'ab' and 'b' respectively.

  We can also maintain nested lists. For example:
  ListUpdate([
      (ListUpdate([32, 1]), 0),
      (ListUpdate(10, 2), 1),
  ])
  tells us that final result is of type list[list[int]] and after this update
  (assuming it is the only update we accumulated) the final result will be:
  [[None, 32], [None, None, 10]].

  Attributes:
    payload: List of tuples [value_or_update, index] that contain a value
      (possibly wrapped in an Update) together with index where this value
      should go in the final result.
  """
  payload: list[
      tuple[Update[_T] | _T, int]
  ] = dataclasses.field(default_factory=list)

  def __add__(self, other: ListUpdate[_T]) -> ListUpdate[_T]:
    """See base class."""
    for update_or_value, index in other.payload:
      # Single scan for an existing entry targeting the same index (the
      # previous implementation rebuilt the list of indices twice per entry).
      position = next(
          (i for i, (_, idx) in enumerate(self.payload) if idx == index),
          None,
      )
      if position is None:
        # If we never saw the index before, just append it to the payload.
        self.payload.append((update_or_value, index))
      else:
        found_update_or_value = self.payload[position][0]
        if isinstance(found_update_or_value, ListUpdate):
          # Nested list case: merge recursively.
          self.payload[position] = (
              found_update_or_value + update_or_value, index
          )
        else:
          # If the content is not a nested list update we just use the value
          # to replace the current element.
          self.payload[position] = (update_or_value, index)
    return self

  def _build_result(self, simplified: bool) -> list[_T | None]:
    """Builds the positional result list; never-set indices remain None."""
    largest_index = max(idx for _, idx in self.payload)
    result: list[_T | None] = [None] * (largest_index + 1)
    for update_or_value, index in self.payload:
      if isinstance(update_or_value, Update):
        result[index] = (
            update_or_value.to_simplified_result()
            if simplified
            else update_or_value.to_result()
        )
      else:
        result[index] = update_or_value
    return result

  def to_result(self) -> list[_T]:
    """See base class."""
    if not self.payload:
      return []
    return self._build_result(simplified=False)

  def to_simplified_result(self) -> list[_T]:
    """See base class."""
    if not self.payload:
      return []
    # The simplified variant drops the holes (indices that were never set).
    return [r for r in self._build_result(simplified=True) if r is not None]
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OneTwo
2 |
3 | TL;DR: OneTwo is a Python library designed to simplify interactions with large
4 | (language and multimodal) foundation models, primarily aimed at researchers
5 | in *prompting* and *prompting strategies*.
6 |
7 | Foundation Models are increasingly being used in complex scenarios with multiple
8 | back-and-forth interactions between the model and some traditional code
9 | (possibly generated by the model itself). This leads to the emergence of
10 | programs that combine ML models with traditional code. The goal of the OneTwo
11 | library is to enable the creation and execution of such programs.
12 | It is designed for researchers (and developers) who want to explore how to get
13 | the best out of foundational models in a situation where it is not necessarily
14 | possible to change their weights (i.e. perform fine-tuning).
15 |
16 | Some properties of OneTwo that are particularly impactful for researcher
17 | productivity include the following:
18 |
19 | - **Model-agnostic:** Provides a uniform API to access different models that
20 | can easily be swapped and compared.
21 | - **Flexible:** Supports implementation of arbitrarily complex computation
22 | graphs involving combinations of sequential and parallel operations,
23 | including interleaving of calls to foundation models and to other tools.
24 | - **Efficient:** Automatically optimizes request batching and other details of
25 | model server interactions under-the-hood for maximizing throughput, while
26 | allowing prompting strategies to be implemented straightforwardly, as if
27 | they were dealing with just single requests.
28 | - **Reproducible:** Automatically caches requests/replies for easy stop-and-go
29 | or replay of experiments.
30 |
31 | ### Features
32 | * **Uniform API**: OneTwo defines a set of primitives or so-called built-in
33 | functions representing the common ways to interact with foundation models.
34 | Since different open-source models or public APIs may have different
35 | capabilities and expose calls with different parameters, OneTwo attempts to
36 | provide a uniform way to access all of those, implementing some best-effort
37 | conversion to guarantee that a OneTwo program will run on any model.
38 | * **Convenient syntax**: Different syntaxes are proposed to enable writing
39 | prompts or sequences of prompts conveniently and in a modular fashion.
40 | One can use formatted strings or jinja templates or can assemble the prompt
41 | as a concatenation of (possibly custom-defined) function calls.
42 | * **Experimentation**: OneTwo offers functionality to easily run
43 | experiments on datasets and collect metrics, with automatic caching of the
44 | results so that minor modifications of the workflow can be replayed with
45 | minimal impact, without having to perform the same requests to the models,
46 | while at the same time being aware of the sampling behaviour (e.g. when
47 | different samples are needed for a given request).
48 | * **Tool use**: OneTwo supports different ways of calling tools from a model,
49 | including running model-generated Python code involving tool calls in a
50 | sandbox.
51 | * **Agents**: OneTwo provides abstractions for defining agents, i.e. functions
52 | that perform some complex action in a step-by-step manner, while updating some
53 | internal state. These agents can be combined naturally and used within
54 | generic optimization algorithms.
55 | * **Execution model**: The execution model takes care of the details of
56 | orchestrating the calls to the models and tools. It offers the following
57 | functionality:
58 | * **Asynchronous execution**: Thanks to the use of asynchronous execution, the
59 | actual calls to the models and tools can be performed in parallel. However,
60 | the user can define a complex workflow in an intuitive and functional manner
61 | without having to think of the details of the execution.
62 | * **Batching**: Similarly, whenever the backends support multiple
63 | simultaneous requests, or batched requests, the OneTwo library will leverage
64 | this and group the requests for maximizing throughput.
65 | * **Smart Caching**: In addition to making it easy to replay all or part of an
66 | experiment, the caching provided by the library handles the case of
67 | multiple random samples obtained from one model, keeping track of how many
68 | samples have been obtained for each request, so that no result is wasted
69 | and actual requests are minimized while the user is tuning their workflow.
70 |
71 | ## Quick start
72 |
73 | ### Installation
74 |
75 | You may want to install the package in a virtual environment, in which case you
76 | will need to start a virtual environment with the following command:
77 |
78 | ```shell
79 | python3 -m venv PATH_TO_DIRECTORY_FOR_VIRTUAL_ENV
80 | # Activate it.
81 | . PATH_TO_DIRECTORY_FOR_VIRTUAL_ENV/bin/activate
82 | ```
83 |
84 | Once you no longer need it, this virtual environment can be deleted with the
85 | following command:
86 |
87 | ```shell
88 | deactivate
89 | ```
90 |
91 | Install the package:
92 |
93 | ```shell
94 | pip install git+https://github.com/google-deepmind/onetwo
95 | ```
96 |
97 | To start using it, import it with:
98 |
99 | ```python
100 | from onetwo import ot
101 | ```
102 |
103 | ### Running unit tests
104 |
105 | In order to run the tests, first clone the repository:
106 |
107 | ```shell
108 | git clone https://github.com/google-deepmind/onetwo
109 | ```
110 |
111 | Then from the cloned directory you can invoke `pytest`:
112 |
113 | ```shell
114 | pytest onetwo/core
115 | ```
116 |
However, running `pytest onetwo` directly will not work: pytest collects all the
test names without keeping their directory of origin, which can lead to name
clashes, so you have to loop through the subdirectories instead.
120 |
121 |
122 | ## Tutorial and Documentation
123 |
124 | This
125 | [Colab](https://colab.research.google.com/github/google-deepmind/onetwo/blob/main/colabs/tutorial.ipynb)
126 | is a good starting point and demonstrates most of the features available in
127 | onetwo.
128 |
129 | Some background on the basic concepts of the library can be found here:
130 | [Basics](docs/basics.md).
131 |
132 | Some of the frequently asked questions are discussed here: [FAQ](docs/faq.md).
133 |
134 | ## Citing OneTwo
135 |
136 | To cite this repository:
137 |
138 | ```bibtex
139 | @software{onetwo2024github,
140 | author = {Olivier Bousquet and Nathan Scales and Nathanael Sch{\"a}rli and Ilya Tolstikhin},
141 | title = {{O}ne{T}wo: {I}nteracting with {L}arge {M}odels},
142 | url = {https://github.com/google-deepmind/onetwo},
143 | version = {0.3.0},
144 | year = {2024},
145 | }
146 | ```
147 |
148 | In the above BibTeX entry, names are in alphabetical order, the version number
is intended to be the one returned by `ot.__version__` (i.e., the latest version
mentioned in [version.py](version.py) and in the [CHANGELOG](CHANGELOG.md)),
and the year corresponds to the project's open-source release.
152 |
153 | ## License
154 |
155 | Copyright 2024 DeepMind Technologies Limited
156 |
This code is licensed under the Apache License, Version 2.0 (the "License");
158 | you may not use this file except in compliance with the License. You may obtain
159 | a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
160 |
161 | Unless required by applicable law or agreed to in writing, software distributed
162 | under the License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR
163 | CONDITIONS OF ANY KIND, either express or implied. See the License for the
164 | specific language governing permissions and limitations under the License.
165 |
166 | ## Disclaimer
167 |
168 | This is not an official Google product.
--------------------------------------------------------------------------------
/onetwo/core/routing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Registry implementation.
16 |
17 | The execution code emits various kinds of requests and they should be
18 | processed by different types of backends.
19 | A backend may process several kinds of requests (e.g. LLM backends may process
20 | both `generate_text` and `score_text` requests).
21 | """
22 |
23 | from __future__ import annotations
24 | from collections.abc import Callable, MutableMapping, AsyncIterator
25 | import contextvars
26 | import copy
27 | import dataclasses
28 | from typing import Any, TypeAlias, TypeVar, final, Generic
29 |
30 | from onetwo.core import executing
31 | from onetwo.core import updating
32 |
# This context variable contains the global function registry that can be used
# anywhere in the code.
# It should be accessed via the function_registry wrapper.
_function_registry_var = contextvars.ContextVar[dict](
    'registry', default=dict()
)


# A registry entry is any callable (function, bound method, configured
# builtin, ...).
_RegistryEntry: TypeAlias = Callable[..., Any]
# Snapshot of the full routing state: (function registry, config registry).
_RegistryData: TypeAlias = tuple[dict[str, Any], dict[str, Any]]
_T = TypeVar('_T')
44 |
45 |
class RegistryReference:
  """Parent class to indicate that a Registry entry is a reference object.

  See routing._copy_also_references() for additional details.

  This is used to indicate when to actually copy the entry in the registry.
  Indeed, when creating a copy of the registry, we only do a shallow copy since
  the registered entries might be complex objects (e.g. a bound method that is
  bound to a complex object) that we don't want to copy.
  But we may also register some shallow objects that hold references to a
  method for example.
  So we use this parent class to indicate to _copy_also_references that it can
  copy it.
  """
60 |
61 |
class _Registry(MutableMapping[str, _RegistryEntry]):
  """Registry to store the mapping between names and functions."""

  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def __call__(self, destination: str, *args, **kwargs) -> Any:
    """Calls the function registered under `destination` with the given args."""
    # We call the registry function.
    result = _function_registry_var.get()[destination](*args, **kwargs)
    # If the result is an Executable, it will be executed, due to the
    # make_executable decorator which executes results by default.
    return result

  def __getitem__(self, key: str):
    """Returns the entry registered under `key`.

    Raises:
      KeyError: If `key` is not registered, listing the known keys.
    """
    # Read the underlying dict once and use EAFP, instead of testing
    # membership and then looking the key up again via separate ContextVar
    # reads.
    registry = _function_registry_var.get()
    try:
      return registry[key]
    except KeyError:
      raise KeyError(
          f'Key "{key}" not registered in registry:'
          f' {registry.keys()}'
      ) from None

  def __setitem__(self, key: str, item: _RegistryEntry):
    _function_registry_var.get()[key] = item

  def __delitem__(self, key: str):
    del _function_registry_var.get()[key]

  def __iter__(self):
    return iter(_function_registry_var.get())

  def __len__(self):
    return len(_function_registry_var.get())

  def __repr__(self):
    return repr(_function_registry_var.get())
95 |
96 |
def _copy_also_references(registry: dict[str, Any]) -> dict[str, Any]:
  """Returns a copy of the registry.

  This performs a shallow copy, but the entries that are references, i.e.
  of type RegistryReference, will be copied.
  This allows in particular to get the expected behaviour when copying
  a registry that contains builtin functions that have been configured,
  as the configuration is stored in a reference object and needs to be copied.

  Args:
    registry: The registry to copy.

  Returns:
    A copy of the registry.
  """
  # Shallow-copy plain entries; copy.copy() the RegistryReference ones so
  # that per-registry configuration does not leak between copies.
  return {
      key: copy.copy(entry) if isinstance(entry, RegistryReference) else entry
      for key, entry in registry.items()
  }
120 |
121 |
# Module-wide singleton through which all registry access should go.
function_registry = _Registry()
# Context variable holding per-context configuration data alongside the
# function registry.
config_registry = contextvars.ContextVar[dict](
    'config_registry', default=dict()
)
126 |
127 |
def copy_registry() -> _RegistryData:
  """Returns a snapshot of the (function registry, config registry) pair."""
  functions = _copy_also_references(_function_registry_var.get())
  configs = copy.copy(config_registry.get())
  return functions, configs
132 |
133 |
def set_registry(registry: _RegistryData) -> None:
  """Installs the given (function, config) registry pair in this context."""
  functions, configs = registry
  _function_registry_var.set(functions)
  config_registry.set(configs)
137 |
138 |
@dataclasses.dataclass
class _RegistryDataWrapper(
    Generic[_T], executing.Executable[_T]
):
  """Wraps an executable with the registry data.

  Attributes:
    wrapped: Executable to be wrapped.
    registry: Registry to use when executing this Executable.
  """

  wrapped: executing.Executable[_T]
  registry: _RegistryData

  @final
  async def _aiterate(
      self, iteration_depth: int = 1
  ) -> AsyncIterator[updating.Update[_T]]:
    """Yields the intermediate values and calls the final_value_callback."""
    with RegistryContext(self.registry):
      it = self.wrapped.with_depth(iteration_depth).__aiter__()
      while True:
        # The registry context is re-entered around each __anext__ call so
        # that every step of the wrapped executable runs under
        # `self.registry`, while `yield update` happens outside the inner
        # context. NOTE(review): presumably RegistryContext restores the
        # previous registry on exit so consumer code between yields is
        # unaffected -- its definition is not fully visible here; confirm.
        with RegistryContext(self.registry):
          try:
            update = await it.__anext__()
          except StopAsyncIteration:
            break
        yield update

  @final
  async def _aexec(self) -> _T:
    """Iterate this value until done (including calling final_value_callback).

    Returns:
      The final value given by the AsyncIterator _inner().
    """
    # Await the wrapped executable entirely under the attached registry.
    with RegistryContext(self.registry):
      result = await self.wrapped
    return result
178 |
179 |
def with_current_registry(
    executable: executing.Executable,
) -> _RegistryDataWrapper:
  """Wraps an executable and attaches a snapshot of the current registry."""
  return _RegistryDataWrapper(executable, copy_registry())
186 |
187 |
def with_registry(
    executable: executing.Executable, registry: _RegistryData
) -> _RegistryDataWrapper:
  """Wraps an executable and attaches the given registry to it."""
  return _RegistryDataWrapper(wrapped=executable, registry=registry)
193 |
194 |
@dataclasses.dataclass
class RegistryContext:
  """Context manager that swaps in a registry for the duration of a block.

  Attributes:
    registry: An optional registry to activate inside the context. When None,
      a "local" copy of the current function/config registries is made, so
      modifications within the context do not leak outside it.
  """
  registry: _RegistryData | None = None

  def __enter__(self):
    if self.registry is not None:
      functions, configs = self.registry
    else:
      # No registry given: work on copies of the current registries so that
      # changes made inside the context stay local.
      functions = _copy_also_references(_function_registry_var.get())
      configs = copy.copy(config_registry.get())
    # Install the chosen registries, remembering the reset tokens.
    self._token = _function_registry_var.set(functions)
    self._token_config = config_registry.set(configs)

  def __exit__(self, exc_type, exc_val, exc_tb):
    # Restore the registries to their values from before __enter__.
    _function_registry_var.reset(self._token)
    config_registry.reset(self._token_config)
222 |
--------------------------------------------------------------------------------
/onetwo/core/executing_with_context_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 | import dataclasses
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from onetwo.core import executing
21 | from onetwo.core import executing_with_context
22 | from onetwo.core import updating
23 | from typing_extensions import override
24 |
25 |
def _collect_stream(
    executable: executing.Executable,
    depth: int = 1,
) -> list[str]:
  """Streams `executable`, returning the accumulated result after each update."""
  collected = []
  with executing.safe_stream(executable, iteration_depth=depth) as stream:
    accumulated = updating.Update[str]()
    for update in stream:
      accumulated += update
      current = accumulated.to_result()
      assert isinstance(current, (str, list))
      collected.append(current)
  return collected
39 |
40 |
@dataclasses.dataclass
class ContextForTest:
  """Minimal mutable context accumulating a single string."""

  # Accumulated content; starts empty.
  content: str = ''
44 |
45 |
@dataclasses.dataclass
class ExecutableWithContextForTest(
    executing_with_context.ExecutableWithContext[ContextForTest, str]
):
  """Single-node executable test double over a string-valued context."""

  # The string this executable contributes.
  content: str = dataclasses.field(default_factory=str)

  @override
  def initialize_context(self, *args, **kwargs) -> ContextForTest:
    return ContextForTest(*args)

  @classmethod
  @override
  def wrap(cls, other: str) -> ExecutableWithContextForTest:
    # Lifts a plain string into an executable node.
    return ExecutableWithContextForTest(content=other)

  @classmethod
  @override
  def get_result(cls, context: ContextForTest) -> str:
    return context.content

  @override
  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def execute(
      self,
      context: ContextForTest,
  ) -> str:
    # Add the content of this node to the context.
    context.content += self.content
    # Return the content between quotes.
    return f'"{self.content}"'
76 |
77 |
@dataclasses.dataclass
class SerialContextForTest:
  """Mutable context accumulating string pieces in a list."""

  # Accumulated pieces; joined by get_result to form the final string.
  content: list[str] = dataclasses.field(default_factory=list)
81 |
82 |
@dataclasses.dataclass
class SerialExecutableWithContextForTest(
    executing_with_context.SerialExecutableWithContext[
        SerialContextForTest, list[str]
    ]
):
  """Serial executable test double accumulating strings into the context."""

  # Decorator order fixed to `@classmethod` above `@override`, matching every
  # other classmethod in this file; with `@override` applied outside
  # `@classmethod`, the marker is set on the classmethod object rather than
  # the underlying function, which may silently fail at runtime.
  @classmethod
  @override
  def empty_result(cls) -> list[str]:
    # The neutral element used before any step has produced output.
    return []

  @override
  def initialize_context(self, *args, **kwargs) -> SerialContextForTest:
    return SerialContextForTest(list(*args))

  @classmethod
  @override
  def wrap(cls, other: str) -> SingleStepSerialExecutableWithContextForTest:
    # Lifts a plain string into a single-step executable node.
    return SingleStepSerialExecutableWithContextForTest(content=other)

  @classmethod
  @override
  def get_result(cls, context: SerialContextForTest) -> str:
    # Joins the accumulated pieces into the final string.
    return ''.join(context.content)
108 |
109 |
@dataclasses.dataclass
class SingleStepSerialExecutableWithContextForTest(
    SerialExecutableWithContextForTest
):
  """Single-step serial executable contributing one string to the context."""

  # The string this step contributes.
  content: str = dataclasses.field(default_factory=str)

  @override
  @executing.make_executable  # pytype: disable=wrong-arg-types
  def execute(self, context: SerialContextForTest) -> list[str]:
    # Add the content of this node to the context.
    # NOTE(review): context.content is a list, so `+=` with a str extends it
    # with individual characters; get_result joins them, so the final string
    # is still correct — confirm this is intentional.
    context.content += self.content
    # Return the content between quotes.
    return [f'"{self.content}"']

  @override
  @executing.make_executable  # pytype: disable=wrong-arg-types
  def iterate(
      self, context: SerialContextForTest, iteration_depth: int = 1
  ) -> list[str]:
    del iteration_depth
    # Add the content of this node to the context.
    context.content += self.content
    # Return the content between quotes.
    return [f'"{self.content}"']
134 |
135 |
class ExecutingWithContextTest(parameterized.TestCase):
  """Tests for ExecutableWithContext and SerialExecutableWithContext."""

  def test_execute_simple(self):
    """Runs and streams a single-node executable, with and without context."""
    e = ExecutableWithContextForTest(content='hello')
    with self.subTest('run_directly'):
      res = executing.run(e)
      self.assertEqual(res, 'hello')

    with self.subTest('result_has_quotes'):
      # Calling execute() directly returns the quoted per-node result.
      res = executing.run(e.execute(ContextForTest()))
      self.assertEqual(res, '"hello"')

    with self.subTest('run_with_arguments'):
      res = executing.run(e('prefix '))
      self.assertEqual(res, 'prefix hello')

    with self.subTest('run_with_resetting'):
      res = executing.run(e())
      self.assertEqual(res, 'hello')

    with self.subTest('stream'):
      res = _collect_stream(e())
      self.assertListEqual(res, ['hello'])

    with self.subTest('stream_iterate'):
      res = _collect_stream(e.iterate(ContextForTest()))
      self.assertListEqual(res, ['"hello"'])

  def test_execute_serial(self):
    """Builds a serial composition with `+=` and checks run/stream behavior."""
    e = SerialExecutableWithContextForTest()
    e += 'hello '
    e += 'world'

    with self.subTest('run_directly'):
      res = executing.run(e)
      self.assertEqual(res, 'hello world')

    with self.subTest('result_is_a_list_with_quotes'):
      res = executing.run(e.execute(SerialContextForTest()))
      self.assertListEqual(res, ['"hello "', '"world"'])

    with self.subTest('stored_result_is_correct_after_run'):
      _ = executing.run(e)
      result_field = e._result
      assert isinstance(result_field, list)
      self.assertListEqual(result_field, ['"hello "', '"world"'])

    with self.subTest('stored_result_is_correct_after_stream'):
      _ = _collect_stream(e())
      self.assertListEqual(e._result, ['"hello "', '"world"'])

    with self.subTest('run_with_arguments'):
      res = executing.run(e(['prefix ']))
      self.assertEqual(res, 'prefix hello world')

    with self.subTest('result_with_arguments'):
      res = executing.run(e.execute(SerialContextForTest(['prefix '])))
      self.assertEqual(res, ['"hello "', '"world"'])

    with self.subTest('run_with_resetting'):
      res = executing.run(e())
      self.assertEqual(res, 'hello world')

    with self.subTest('stream_iterate_returns_lists'):
      res = _collect_stream(e.iterate(SerialContextForTest()))
      self.assertListEqual(res, [['"hello "'], ['"hello "', '"world"']])

  def test_right_addition(self):
    """Checks that `str + executable` builds the same serial composition."""
    e = 'hello '
    e += SerialExecutableWithContextForTest()
    e += 'world'

    with self.subTest('run_directly'):
      res = executing.run(e)
      self.assertEqual(res, 'hello world')

    with self.subTest('result_is_a_list_with_quotes'):
      res = executing.run(e.execute(SerialContextForTest()))
      self.assertListEqual(res, ['"hello "', '"world"'])

    with self.subTest('run_with_arguments'):
      res = executing.run(e(['prefix ']))
      self.assertEqual(res, 'prefix hello world')

    with self.subTest('run_with_resetting'):
      res = executing.run(e())
      self.assertEqual(res, 'hello world')

    with self.subTest('stream_iterate'):
      res = _collect_stream(e.iterate(SerialContextForTest()))
      self.assertListEqual(res, [['"hello "'], ['"hello "', '"world"']])
227 |
228 |
# Run the tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
231 |
--------------------------------------------------------------------------------
/onetwo/builtins/llm_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Technical details and utilities for llm builtins."""
16 |
17 | from __future__ import annotations
18 |
19 |
20 | from collections.abc import Mapping, Sequence
21 | import enum
22 | from typing import cast, Any, TypeAlias
23 |
24 | from onetwo.core import content as content_lib
25 |
26 | _ChunkList: TypeAlias = content_lib.ChunkList
27 |
28 |
class TokenHealingOption(enum.Enum):
  """Options for mitigating token-boundary artifacts during generation.

  Token healing addresses a common failure mode of prompts that end in the
  middle of what the tokenizer would normally encode as a single token.
  Consider a prompt ending with a trailing space, e.g. "Theory of ". Such a
  string is typically encoded with a standalone whitespace token at the end.
  Words, however, are usually tokenized together with their leading
  whitespace: there is likely a " relativity" token, and in most sentences
  where the (case-sensitive!) word "relativity" occurred it was encoded that
  way. The model is therefore "familiar" with the with-space token but not
  with a no-space "relativity" token. Since a separate whitespace token is
  already at the end of the sequence, the model cannot generate the
  with-space token (two whitespaces in a row are very uncommon) and instead
  produces something else that often looks odd. The same phenomenon can occur
  with other prompt-final tokens as well. For more details refer to
  https://github.com/guidance-ai/guidance.
  """

  # No healing: prompt and reply are used as-is.
  NONE = 'NONE'
  # Replace the last prompt token and send the prompt to completion. Ideally,
  # constrained decoding guarantees that the first generated token's string
  # contains the removed token's string as a prefix, so the completion matches
  # the user's expectations while avoiding token-boundary artifacts. When
  # constrained decoding is unavailable, completions are sampled in the hope
  # the constraint holds, rolling back to vanilla completion on failure.
  TOKEN_HEALING = 'TOKEN_HEALING'
  # Remove trailing whitespace from the prompt before completion, and (when
  # the original prompt had trailing whitespace) strip leading whitespace from
  # the generated text. A less general, ad-hoc variant of token healing.
  SPACE_HEALING = 'SPACE_HEALING'
66 |
67 |
68 | def space_heal_reply(
69 | reply: str | tuple[str, Mapping[str, Any]]
70 | ) -> str | tuple[str, Mapping[str, Any]]:
71 | """Possibly remove spaces from the beginning of the reply.
72 |
73 | Args:
74 | reply: Reply that we want to space heal. Reply is either a string or a tuple
75 | of a string and a dictionary of additional information.
76 |
77 | Returns:
78 | Reply with leading whitespaces removed. In case of a tuple, the additional
79 | information is preserved.
80 | """
81 | if isinstance(reply, str):
82 | reply: str = reply.lstrip(' ')
83 | elif isinstance(reply, tuple):
84 | reply, details = reply
85 | reply = reply.lstrip(' ')
86 | reply = (reply, details)
87 | else:
88 | raise ValueError(f'Unexpected type of the reply:{type(reply)}')
89 | return reply
90 |
91 |
def maybe_heal_prompt(
    *,
    original_prompt: str | _ChunkList,
    healing_option: TokenHealingOption,
) -> _ChunkList:
  """Applies the selected healing option to a prompt before generation.

  Args:
    original_prompt: Prompt that we may want to heal. A plain string is first
      wrapped into a ChunkList.
    healing_option: Healing option that we want to use.

  Returns:
    The prompt as a ChunkList, with trailing whitespace removed when space
    healing is selected; otherwise unchanged. Token healing is not supported
    yet.

  Raises:
    NotImplementedError: If token healing is requested.
  """
  prompt = (
      _ChunkList([original_prompt])
      if isinstance(original_prompt, str)
      else original_prompt
  )
  if healing_option == TokenHealingOption.TOKEN_HEALING:
    # TODO: Support token healing.
    raise NotImplementedError('Token healing is not supported yet.')
  if healing_option == TokenHealingOption.SPACE_HEALING:
    # Remove trailing whitespaces.
    return prompt.rstrip(' ')
  return prompt
118 |
119 |
def maybe_heal_reply(
    *,
    reply_text: str,
    original_prompt: str | _ChunkList,
    healing_option: TokenHealingOption,
) -> str:
  """Applies the selected healing option to a generated reply.

  Args:
    reply_text: Reply text that we may want to heal.
    original_prompt: The prompt used to generate the reply, before any healing
      was applied to it.
    healing_option: Healing option that was used to generate the reply.

  Returns:
    Under space healing, when the original prompt ended with a whitespace, the
    reply with leading whitespace removed; otherwise the reply unchanged.
    Token healing is not supported yet.

  Raises:
    NotImplementedError: If token healing is requested.
  """
  if healing_option == TokenHealingOption.TOKEN_HEALING:
    # TODO: Support token healing.
    raise NotImplementedError('Token healing is not supported yet.')
  needs_space_healing = (
      healing_option == TokenHealingOption.SPACE_HEALING
      and original_prompt.endswith(' ')
  )
  if needs_space_healing:
    # Strip the leading spaces that space healing may have introduced.
    return cast(str, space_heal_reply(reply_text))
  return reply_text
149 |
150 |
def maybe_heal_prompt_and_targets(
    *,
    original_prompt: str | _ChunkList,
    original_targets: Sequence[str],
    healing_option: TokenHealingOption,
) -> tuple[_ChunkList, Sequence[str]]:
  """Applies the selected healing option to a prompt and scoring targets.

  Args:
    original_prompt: Prompt that we may want to heal. A plain string is first
      wrapped into a ChunkList.
    original_targets: Targets that we may want to heal.
    healing_option: Healing option that we want to use.

  Returns:
    Under space healing, when the prompt ends with a whitespace, the prompt
    with trailing whitespace removed and the targets each prefixed with a
    leading space where missing. Otherwise the prompt and targets unchanged
    (targets as a list copy). Token healing is not supported yet.

  Raises:
    NotImplementedError: If token healing is requested.
  """
  prompt = (
      _ChunkList([original_prompt])
      if isinstance(original_prompt, str)
      else original_prompt
  )
  targets = list(original_targets)
  if healing_option == TokenHealingOption.TOKEN_HEALING:
    # TODO: Support token healing.
    raise NotImplementedError('Token healing is not supported yet.')
  if (
      healing_option == TokenHealingOption.SPACE_HEALING
      and prompt.endswith(' ')
  ):
    # Remove trailing whitespaces from the prompt...
    prompt = prompt.rstrip(' ')
    # ...and make each target carry the leading space removed from the prompt.
    targets = [t if t.startswith(' ') else ' ' + t for t in targets]
  return prompt, targets
186 |
--------------------------------------------------------------------------------
/docs/basics.md:
--------------------------------------------------------------------------------
1 | # OneTwo Basics
2 |
3 | One of the key principles behind the OneTwo library is to enable the creation
4 | of complex flows involving several calls to foundation models and possibly other
5 | tools.
6 | For ease of experimentation, it is important to easily change the backends or
7 | their configuration and run the same flow on two backends/configurations, e.g.
8 | when doing comparisons.
9 |
10 | The bottleneck is often the multiple RPC requests that need to happen. This
11 | makes fast iterations or experimenting on many examples slow and tedious. In
12 | order to reduce this bottleneck, there are two strategies that are implemented
13 | in the OneTwo library:
14 |
15 | 1. **Caching**: The result of the calls to the models are cached, which enables
16 | one to very quickly replay a flow or an experiment which may have partially
17 | executed (e.g. failed in the middle of execution). For example, if you have a
18 | complex flow and want to add just one extra step, rerunning the whole thing
19 | amounts to reading everything from cache and only executing for real that one
20 | last step.
21 | 1. **Asynchronous Execution**: While some of the model calls might need to be
22 | chained serially, there are many situations when you may want to execute some
23 | calls in parallel (e.g. talking to different backends, running an experiment on
24 | many examples, or having a step in your flow where several independent tasks are
25 | performed). A natural way to do that is to use asynchronous programming, or
26 | multi-threading.
27 |
28 | ## Builtins
29 |
30 | In order to use a uniform language, we define a number of "built-in" functions
31 | representing the basic operations one may want to perform using a model.
32 |
33 | - `llm.generate_text()` - Generate raw text.
34 | - `llm.generate_object()` - Generate and parse text into a Python object.
35 | - `llm.select()` - Choose among alternatives.
36 | - `llm.instruct()` - Generate answer to instructions.
37 | - `llm.chat()` - Generate text in a multi-turn dialogue.
38 |
39 | We also decouple these functions from their implementations.
40 | So you can use them to define your **prompting strategy**, without specifying
41 | which **model** or which **model parameters** you want to use, and only specify
42 | those later.
43 |
44 | ## Executables
45 |
46 | Many of the basic functions provided by the library actually return what we call
47 | *Executables*. For example:
48 |
49 | ```python
50 | from onetwo.builtins import llm
51 |
52 | e = llm.generate_text(
53 |     'Q: What are three not so well known cities in France?\nA:',
54 | stop=['Q:'],
55 | max_tokens=20,
56 | )
57 | # > Troyes, Dijon, Annecy.
58 | ```
59 |
60 | Now this `e` variable has type `Executable[str]` and needs to be *executed* to
61 | produce the final result. This happens by calling `ot.run()`:
62 |
63 | ```python
64 | from onetwo import ot
65 |
66 | result = ot.run(e)
67 | ```
68 | The benefit of this two-step process is that one can define possibly complex
69 | execution flows in a natural pythonic way, and decouple the definition of the
70 | flow from the actual backends that are used to execute it.
71 |
72 | ## Function Registry
73 |
74 | Specifying which backend is used to actually perform a built-in function like
75 | `llm.generate_text` is done when calling the `register()` method on a backend.
76 | This method registers the various function calls that the backend supports into
77 | a global function registry.
78 |
79 | You can temporarily override this registry if you want the calls to
80 | `llm.generate_text` to be routed elsewhere.
81 |
82 | For example,
83 |
84 | ```python
85 | backend = ...
86 | backend.register()
87 |
88 | def fake_generate_text(prompt: str | ChunkList):
89 | return prompt
90 |
91 | with ot.RegistryContext():
92 | llm.generate_text.configure(fake_generate_text)
93 | print(ot.run(llm.generate_text('test')))
94 | ```
95 |
96 | As another example, assume you have two different backends, then it is possible
97 | to create two distinct registries and pick one of those at execution time:
98 |
99 | ```python
100 | backend1 = ...
101 | backend2 = ...
102 |
103 | with ot.RegistryContext():
104 | backend1.register()
105 | registry1 = ot.copy_registry()
106 |
107 | with ot.RegistryContext():
108 | backend2.register()
109 | registry2 = ot.copy_registry()
110 | ```
111 |
112 | ```python
113 | ot.run(ot.with_registry(e, registry1))
114 | ot.run(ot.with_registry(e, registry2))
115 | ```
116 |
117 | ## Asynchronous execution
118 |
119 | While it may take a bit of time to get used to `asyncio` if you never used it,
120 | we tried to make it as simple as possible.
121 |
122 | So if you need to perform two sequential calls to an LLM, you can of course run
123 | one after the other:
124 |
125 | ```python
126 | result = ot.run(llm.generate_text('Q: What is the southernmost city in France? A:'))
127 | result2 = ot.run(llm.generate_text(f'Q: Who is the mayor of {result}? A:'))
128 | ```
129 |
130 | But a better way is to create a single Executable by combining them into a
131 | function decorated with `@ot.make_executable`:
132 |
133 | ```python
134 | @ot.make_executable
135 | async def f(*any_arguments):
136 | del any_arguments # This example does not use arguments.
137 | result = await llm.generate_text('Q: What is the southernmost city in France? A:')
138 | result2 = await llm.generate_text(f'Q: Who is the mayor of {result}? A:')
139 | return result2
140 |
141 | result = ot.run(f())
142 | ```
143 |
144 | Indeed, the `ot.run()` function will actually block on execution and will
145 | only return when the LLM has produced the output, while when an async function
146 | is created with multiple await calls in its body, the execution of this function
147 | can be interleaved with the execution of other async functions. This will be
148 | beneficial when creating complex workflows with multiple calls as they can be
149 | scheduled automatically in an optimal way. For example, if we were to repeatedly
150 | call the `f` function on different inputs the inner generate_text calls could be
151 | interleaved (see next section).
152 |
153 | Functions decorated with `@ot.make_executable` return `Executable` objects
154 | when called. I.e., `f()` is of type `Executable` and can be executed with
155 | `ot.run()`.
156 |
157 | ## Combining executables
158 |
159 | We provide in particular a way to combine executables in parallel:
160 |
161 | ```python
162 | e1 = llm.generate_text('Q: What is the southernmost city in France? A:')
163 | e2 = llm.generate_text('Q: What is the southernmost city in Spain? A:')
164 | e = ot.parallel(e1, e2)
165 | results = ot.run(e)
166 | ```
167 |
168 | The creation of the `Executable` `e` as a parallel composition of two
169 | executables indicates that one does not depend on the output of the other and
170 | the calls to the LLM can thus be performed in parallel, assuming that the
171 | backend supports it. Typically a backend will be a proxy to a remote server that
172 | may support multiple simultaneous calls from different threads, or that may
173 | support sending requests in batches. In this case, the execution will
174 | automatically take advantage of this functionality to speed things up and not
175 | having to wait for the first `generate_text` call to return before performing
176 | the second one.
177 |
178 | ## Composing multi-step prompts
179 |
180 | While using `async`, `await`, and `ot.parallel` lets you create arbitrary
181 | complex flows using standard Python, there are cases where one might want a
182 | simpler way to specify basic or typical combinations. We thus also provide ways
183 | of composing prompts that can execute in multiple steps (i.e. involving multiple
184 | calls to a model).
185 |
186 | We support two different syntaxes for that:
187 |
188 | - Prompt templates in jinja2 templating language;
189 | - Composition via the `+` operator.
190 |
191 | ```python
192 | from onetwo.builtins import composables as c
193 |
194 | template = c.j("""\
195 | What is the southernmost city in France? {{ generate_text() }}
196 | Who is its mayor? {{ generate_text() }}
197 | """)
198 | result = ot.run(template)
199 | ```
200 |
201 | ```python
202 | e = 'What is the southernmost city in France?' + c.generate_text() + \
203 | 'Who is its mayor?' + c.generate_text()
204 | result = ot.run(e)
205 | ```
--------------------------------------------------------------------------------
/onetwo/core/executing_impl_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from onetwo.agents import agents_test_utils
20 | from onetwo.core import executing
21 | from onetwo.core import executing_impl
22 | from onetwo.core import tracing
23 |
24 |
def ordinary_function(x, y):
  """Plain synchronous fixture: neither awaitable nor executable."""
  result = x + y
  return result
27 |
28 |
async def async_function(x, y):
  """Plain coroutine fixture: awaitable but not executable."""
  result = x + y
  return result
31 |
32 |
# Fixture decorated with make_executable (synchronous body).
@executing.make_executable  # pytype: disable=wrong-arg-types
def executable_function(x, y):
  return x + y
36 |
37 |
# Fixture decorated with make_executable (coroutine body).
@executing.make_executable  # pytype: disable=wrong-arg-types
async def async_executable_function(x, y):
  return x + y
41 |
42 |
# Fixture decorated only with tracing.trace (synchronous body).
@tracing.trace  # pytype: disable=wrong-arg-types
def traced_ordinary_function(x, y):
  return x + y
46 |
47 |
# Fixture decorated only with tracing.trace (coroutine body).
@tracing.trace  # pytype: disable=wrong-arg-types
async def traced_async_function(x, y):
  return x + y
51 |
52 |
# Fixture with make_executable applied OVER tracing.trace.
@executing.make_executable
@tracing.trace  # pytype: disable=wrong-arg-types
def executable_traced_function(x, y):
  return x + y
57 |
58 |
# Fixture with tracing.trace applied OVER make_executable (reverse order of
# executable_traced_function above).
@tracing.trace
@executing.make_executable  # pytype: disable=wrong-arg-types
def traced_executable_function(x, y):
  return x + y
63 |
64 |
# Undecorated fixture that nonetheless returns an Executable by construction;
# detection helpers cannot see this without calling the function.
def executable_function_that_does_not_use_make_executable_decorator(x, y):
  return executing.serial(executable_function(x, y), executable_function(x, y))
67 |
68 |
class C:
  """Holder for method fixtures mirroring the module-level functions."""

  def ordinary_method(self, x, y):
    # Plain bound method.
    return x + y

  async def async_method(self, x, y):
    # Coroutine method, not decorated with make_executable.
    return x + y

  @executing.make_executable  # pytype: disable=wrong-arg-types
  def executable_method(self, x, y):
    # Method decorated with make_executable (synchronous body).
    return x + y

  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def async_executable_method(self, x, y):
    # Method decorated with make_executable (coroutine body).
    return x + y
83 |
84 |
85 | class ExecutingImplTest(parameterized.TestCase):
86 |
  def test_set_decorated_with_make_executable(self):
    """Checks that the make_executable marker can be set and then detected."""
    def f(x):
      return x

    with self.subTest('false_before_setting'):
      self.assertFalse(executing_impl.is_decorated_with_make_executable(f))

    executing_impl.set_decorated_with_make_executable(f)

    with self.subTest('true_after_setting'):
      self.assertTrue(executing_impl.is_decorated_with_make_executable(f))
98 |
  @parameterized.named_parameters(
      ('ordinary_function', ordinary_function, False),
      ('async_function', async_function, False),
      ('executable_function', executable_function, True),
      ('async_executable_function', async_executable_function, True),
      ('traced_ordinary_function', traced_ordinary_function, False),
      ('traced_async_function', traced_async_function, False),
      ('executable_traced_function', executable_traced_function, True),
      ('traced_executable_function', traced_executable_function, True),
      ('agent', agents_test_utils.StringAgent(), True),
      ('ordinary_method', C().ordinary_method, False),
      ('async_method', C().async_method, False),
      ('executable_method', C().executable_method, True),
      ('async_executable_method', C().async_executable_method, True),
      (
          'executable_function_that_does_not_use_make_executable_decorator',
          executable_function_that_does_not_use_make_executable_decorator,
          False,
      ),
  )
  def test_is_decorated_with_make_executable(self, function, expected_result):
    """Detection should be True exactly for make_executable-decorated cases."""
    self.assertEqual(
        expected_result,
        executing_impl.is_decorated_with_make_executable(function),
    )
124 |
  @parameterized.named_parameters(
      ('ordinary_function', ordinary_function, False),
      ('async_function', async_function, True),
      ('executable_function', executable_function, True),
      ('async_executable_function', async_executable_function, True),
      ('traced_ordinary_function', traced_ordinary_function, False),
      ('traced_async_function', traced_async_function, True),
      ('executable_traced_function', executable_traced_function, True),
      ('traced_executable_function', traced_executable_function, True),
      ('agent', agents_test_utils.StringAgent(), True),
      ('ordinary_method', C().ordinary_method, False),
      ('async_method', C().async_method, True),
      ('executable_method', C().executable_method, True),
      ('async_executable_method', C().async_executable_method, True),
      # TODO: The following case should ideally return True but does
      # not currently, as it is difficult to predict that the function returns
      # an Executable in this case without actually calling it.
      (
          'executable_function_that_does_not_use_make_executable_decorator',
          executable_function_that_does_not_use_make_executable_decorator,
          False,
      ),
  )
  def test_returns_awaitable(self, function, expected_result):
    """Coroutines and executables are awaitable; plain callables are not."""
    self.assertEqual(
        expected_result,
        executing_impl.returns_awaitable(function),
    )
153 |
154 | @parameterized.named_parameters(
155 | ('ordinary_function', ordinary_function),
156 | ('async_function', async_function),
157 | ('executable_function', executable_function),
158 | ('async_executable_function', async_executable_function),
159 | ('traced_ordinary_function', traced_ordinary_function),
160 | ('traced_async_function', traced_async_function),
161 | ('executable_traced_function', executable_traced_function),
162 | ('traced_executable_function', traced_executable_function),
163 | ('ordinary_method', C().ordinary_method),
164 | ('async_method', C().async_method),
165 | ('executable_method', C().executable_method),
166 | ('async_executable_method', C().async_executable_method),
167 | )
168 | def test_call_and_maybe_await(self, function):
169 | result = executing.run(
170 | executing_impl.call_and_maybe_await(function, 'x', 'y')
171 | )
172 | with self.subTest('positional_args'):
173 | self.assertEqual('xy', result)
174 |
175 | result = executing.run(
176 | executing_impl.call_and_maybe_await(function, x='x', y='y')
177 | )
178 | with self.subTest('keyword_args'):
179 | self.assertEqual('xy', result)
180 |
181 | partial_function = functools.partial(function, x='x')
182 | result = executing.run(
183 | executing_impl.call_and_maybe_await(partial_function, y='y')
184 | )
185 | with self.subTest('partial_function'):
186 | self.assertEqual('xy', result)
187 |
188 | def test_call_and_maybe_await_agent(self):
189 | function = agents_test_utils.StringAgent(sequence=['x', 'y'])
190 | result = executing.run(
191 | executing_impl.call_and_maybe_await(function, 'some_input')
192 | )
193 | self.assertEqual('x y', result)
194 |
195 | def test_call_and_maybe_await_executable_function_that_does_not_use_make_executable_decorator(
196 | self,
197 | ):
198 | # Note that this case is handled correctly, even though `returns_awaitable`
199 | # does not return the correct value for this function.
200 | function = executable_function_that_does_not_use_make_executable_decorator
201 | result = executing.run(
202 | executing_impl.call_and_maybe_await(function, 'x', 'y')
203 | )
204 | self.assertEqual(['xy', 'xy'], result)
205 |
206 |
# Standard absltest entry point.
if __name__ == '__main__':
  absltest.main()
209 |
--------------------------------------------------------------------------------
/onetwo/stdlib/code_execution/python_execution_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper methods, to be used by implementations."""
16 |
17 | from __future__ import annotations
18 |
19 | import ast
20 | import datetime
21 | import importlib
22 | import json
23 | import logging
24 | import textwrap
25 | from typing import Any, Callable, TypeAlias
26 |
27 | from onetwo.stdlib.code_execution import python_execution
28 |
# Aliases (private shorthand for the public types in `python_execution`).
_ExecutionStatus: TypeAlias = python_execution.ExecutionStatus
_SandboxStatus: TypeAlias = python_execution.SandboxStatus
_SandboxResult: TypeAlias = python_execution.SandboxResult
_SandboxResultTiming: TypeAlias = python_execution.SandboxResultTiming

# Constants for hook communication: keys of the JSON messages exchanged
# between the sandboxed process and the main process.
DATA_KEY: str = 'data'  # Key under which a hook's return value is delivered.
EXCEPTION_KEY: str = 'exception'  # Key signaling that the hook raised.
EX_ARGS_KEY: str = 'args'  # Positional args for the exception constructor.
EX_CLASS_KEY: str = 'exception_class'  # Builtin exception class name.
EX_HOOK_NAME_KEY: str = 'hook_name'  # Name of the hook that raised.
41 |
42 |
def current_timing(
    *, start: datetime.datetime, base: datetime.datetime
) -> _SandboxResultTiming:
  """Builds a timing snapshot relative to the given reference times.

  Args:
    start: Time the execution request was sent to the sandbox.
    base: Time of last interaction with the sandbox (i.e., the later of the
      start time, and the time of the last callback to a hook function).

  Returns:
    A `_SandboxResultTiming` measuring "now" against both references.
  """
  reference = datetime.datetime.now()
  elapsed_total = reference - start
  elapsed_recent = reference - base
  return _SandboxResultTiming(
      since_start=elapsed_total,
      since_last_interaction=elapsed_recent,
  )
57 |
58 |
def adjust_code_to_set_final_expression_value(
    code: str,
    variable_name: str = 'result',
    default_value: Any = None,
) -> str:
  r"""Returns adjusted code that stores the final expression value in a variable.

  Ignores trailing comment lines when determining the "final expression value".

  If the provided code is not valid Python code, then will return the original
  code unchanged. (Ideally, the caller should perform validation of the code
  prior to calling this function.)

  Examples (assuming variable_name = 'result'):
  * 'x = 2\nx + 3' => 'x = 2\nresult = x + 3'.
  * 'x = 2\ndel x' => 'x = 2\ndel x\nresult = None'.
  * 'x = 2\n# Comment' => 'result = x = 2\n# Comment'.

  Args:
    code: String containing the original unadjusted Python code.
    variable_name: The variable name into which to store the final expression
      value.
    default_value: Default value to assign to `variable_name` in the case where
      the actual final expression value was undefined. NOTE: the value is
      interpolated into the code via `str()`, so it should have a str form
      that is itself a valid Python expression (e.g. None or a number).
  """
  dedented_code = textwrap.dedent(code.rstrip())
  lines = dedented_code.split('\n')

  # Special case when `code` is empty or consists of only whitespace.
  if len(lines) == 1 and not lines[0].strip():
    lines = []

  try:
    parse_tree = ast.parse(dedented_code)
  except Exception:  # pylint: disable=broad-exception-caught
    # Invalid Python: return the input untouched rather than corrupting it.
    logging.warning(
        'Failed to parse Python code:\n```\n%s\n```',
        dedented_code,
        exc_info=True,
    )
    return code

  result_idx = None
  if parse_tree and parse_tree.body:
    last_statement = parse_tree.body[-1]
    # The only Python statements that are compatible with variable assignment
    # (as far as we are aware...) are expressions and assignments, e.g.:
    # * Expression: `2 + 3` ==> `result = 2 + 3`
    # * Assignment: `y = x + 2` ==> `result = y = x + 2`
    # Other Python statements that are incompatible with variable assignment
    # include compound statements (`if`, `while`, `for`, etc.) and various
    # simple non-expression / non-assignment statements (e.g., `assert`, `del`,
    # `import`, `raise`, etc.).
    # For background, see: https://docs.python.org/3/reference/index.html
    if isinstance(last_statement, (ast.Assign, ast.Expr)):
      # ast.AST.lineno is 1-indexed. We subtract 1 to make it 0-based.
      result_idx = last_statement.lineno - 1

  if result_idx is None:
    # In this case, since the last statement does not return a value that can be
    # assigned to the result variable, we just set the result equal to None.
    lines.append(f'{variable_name} = {default_value}')
  else:
    # If the last statement does return a value, then we can simply set the
    # result variable equal to that.
    lines[result_idx] = f'{variable_name} = ' + lines[result_idx]

  return '\n'.join(lines)
129 |
130 |
def parse_sandbox_result_json(
    result_str: str,
    start_time: datetime.datetime,
    base_time: datetime.datetime,
) -> _SandboxResult | None:
  """Attempts to decode a sandbox's JSON output into a SandboxResult.

  Args:
    result_str: The string output from the sandbox, expected to be JSON.
    start_time: The time execution started.
    base_time: The time of the last interaction.

  Returns:
    A populated _SandboxResult when `result_str` is valid JSON of the expected
    shape (a dict containing at least 'stdout'); None otherwise.
  """
  timing = current_timing(start=start_time, base=base_time)

  try:
    # Note that this uses the JSON decoder with potential
    # implementation-specific customizations.
    decoded = json.loads(result_str)
  except json.JSONDecodeError:
    logging.warning('Failed to parse result string:\n```\n%s\n```', result_str)
    return None

  if not isinstance(decoded, dict) or 'stdout' not in decoded:
    # The JSON was valid, but did not have the expected SandboxResult shape.
    return None

  # Received a full SandboxResult in dict form, as expected.
  return _SandboxResult(
      final_expression_value=decoded.get('final_expression_value'),
      stdout=decoded.get('stdout', ''),
      sandbox_status=_SandboxStatus(
          decoded.get('sandbox_status', 'AFTER_RUNNING_CODE')
      ),
      execution_status=_ExecutionStatus(
          decoded.get('execution_status', 'SUCCESS')
      ),
      status_message=decoded.get('status_message', ''),
      failure_details=python_execution.parse_failure_details(decoded),
      timing=timing,
  )
175 |
176 |
def create_sandbox_hook_callable(
    hook_name: str, conn: Any
) -> Callable[..., Any]:
  """Builds the in-sandbox proxy function for one named hook.

  This function is intended to be run *inside* the sandboxed process.

  Args:
    hook_name: The name of the hook.
    conn: The connection object to the main process.

  Returns:
    A wrapper that forwards calls over `conn` to the main process and either
    returns the hook's result or re-raises its exception.
  """

  def hook_wrapper(*args, **kwargs):
    # Serialize the call and block until the main process replies.
    conn.send(json.dumps({'hook': hook_name, 'args': args, 'kwargs': kwargs}))
    reply = json.loads(conn.recv())

    ex_info = reply.get(EXCEPTION_KEY)
    if ex_info:
      # Re-raise the exception from the main process, reconstructing it from
      # its builtin class name and constructor args.
      class_name = ex_info.get(EX_CLASS_KEY, 'RuntimeError')
      try:
        exception_class = getattr(
            importlib.import_module('builtins'), class_name
        )
      except AttributeError:
        # Unknown class name: fall back to a generic error type.
        exception_class = RuntimeError
      error = exception_class(*ex_info.get(EX_ARGS_KEY, ()))
      setattr(error, '_hook_name', ex_info.get(EX_HOOK_NAME_KEY))
      raise error
    return reply.get(DATA_KEY)

  return hook_wrapper
212 |
--------------------------------------------------------------------------------
/onetwo/core/executing_impl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Core content from `executing.py` factored out to avoid circular dependencies.
16 |
17 | The intention is to keep the content in this file to the minimum needed to avoid
18 | circular dependencies between `executing.py` and other low-level libraries like
19 | `utils.py` or `batching.py`. In particular, we should minimize the dependencies
20 | from this file to other OneTwo libraries.
21 |
22 | The existence of this file should be treated as an implementation detail from
23 | the perspective of general users of OneTwo. So rather than importing this file
24 | directly, most code (both inside and outside of the OneTwo core) should instead
25 | import `executing.py` and/or `utils.py`, which contain public aliases to the
26 | classes and functions defined here.
27 | """
28 |
29 | from __future__ import annotations
30 |
31 | import abc
32 | from collections.abc import AsyncIterator, Awaitable, Callable, Generator
33 | import dataclasses
34 | import functools
35 | import inspect
36 | from typing import Any, final, Generic, ParamSpec, TypeVar
37 |
38 | from onetwo.core import updating
39 |
40 |
# Basic type variables that need to be specified when using this library.
_Result = TypeVar('_Result')  # Value produced when an Executable is awaited.
_Args = ParamSpec('_Args')  # Parameters of a wrapped callable.

_Update = updating.Update  # Private alias to keep signatures below short.
46 |
47 |
class Executable(
    Generic[_Result],
    Awaitable[_Result],
    AsyncIterator[_Update[_Result]],
    metaclass=abc.ABCMeta,
):
  """Interface for a process that can be executed step by step.

  Executable supports both `await` and `async for` statements. If the underlying
  implementation supports only `await`, we wrap it into the async iterator that
  yields only that one item. Similarly, if the underlying implementation is an
  async iterator and only supports `async for`, we implement `await` by manually
  iterating through all of the values and returning the final one.

  When awaited, executable returns a value of type Result. However, when using
  `async for`, executable produces instances of class (or subclasses of)
  `updating.Update`. Class Update helps maintaining (`accumulate`) intermediate
  resuts and obtaining the final result (`to_result`) based on accumulated
  information.

  Subclasses implement only `_aexec` and `_aiterate`; the dunder protocol
  methods below are final and delegate to those two.
  """

  @abc.abstractmethod
  async def _aexec(self) -> _Result:
    """Implementation as an Awaitable."""

  @abc.abstractmethod
  async def _aiterate(
      self, iteration_depth: int = 1
  ) -> AsyncIterator[_Update[_Result]]:
    """Implementation as an AsyncIterator with configurable depth."""
    yield _Update()  # For correct typing

  @final
  def with_depth(self, iteration_depth: int) -> AsyncIterator[_Update[_Result]]:
    """Returns an iterator over updates, expanded to the given nesting depth."""
    return self._aiterate(iteration_depth)

  @final
  def __await__(self) -> Generator[Any, Any, _Result]:
    """Method that is called when using `await executable`."""
    result = yield from self._aexec().__await__()
    return result

  @final
  def __aiter__(self) -> AsyncIterator[_Update[_Result]]:
    """Method that is called when using `async for executable`."""
    return self._aiterate().__aiter__()

  @final
  async def __anext__(self) -> _Update[_Result]:
    """Method that is called when using `async for executable`."""
    # NOTE(review): this creates a fresh iterator via `_aiterate()` on every
    # call, so repeated direct `__anext__` calls would each yield the first
    # update. `async for` goes through `__aiter__` and is unaffected — confirm
    # whether direct `__anext__` use is intended.
    return await self._aiterate().__anext__()
99 |
100 |
@dataclasses.dataclass
class ExecutableWithPostprocessing(
    Generic[_Result], Executable[_Result]
):
  """An executable with a callback executed at the end of the processing.

  One can define two callbacks, one for when the executable is executed with
  `await` and one when the executable is iterated through with `async for`.
  The latter one is optional.

  Attributes:
    wrapped: Executable to augment with postprocessing.
    postprocessing_callback: Callback to call at the end of the processing (in
      case we call the executable with `await`).
    update_callback: Callback to call after each update from the
      wrapped Executable (in case we call the executable with `async for`). If
      this is None, we will use the postprocessing_callback on
      `update.to_result()` (hence converting the updates into results, calling
      the callback and converting this back into an Update to be yielded).
  """

  wrapped: Executable[_Result]
  postprocessing_callback: Callable[[_Result], _Result]
  update_callback: Callable[[_Update[_Result]], _Update[_Result]] | None = None

  @final
  async def _aiterate(
      self, iteration_depth: int = 1
  ) -> AsyncIterator[_Update[_Result]]:
    """Yields the intermediate values and calls the final_value_callback."""
    # Accumulate updates so that, when no update_callback is given, we can
    # postprocess the best-so-far *result* at every step.
    updates = _Update()
    async for update in self.wrapped.with_depth(iteration_depth):
      updates += update
      if self.update_callback is not None:
        yield self.update_callback(update)
      else:
        # Convert accumulated updates to a result, postprocess it, and wrap it
        # back into an Update for the caller.
        yield _Update(self.postprocessing_callback(updates.to_result()))

  @final
  async def _aexec(self) -> _Result:
    """Iterate this value until done (including calling final_value_callback).

    Returns:
      The final value given by the AsyncIterator _inner().
    """
    result = await self.wrapped
    return self.postprocessing_callback(result)
148 |
149 |
def set_decorated_with_make_executable(f: Callable[..., Any]) -> None:
  """Flags the callable as wrapped by @executing.make_executable.

  The flag is read back by `is_decorated_with_make_executable`.

  Args:
    f: An arbitrary callable.
  """
  setattr(f, 'decorated_with_make_executable', True)
157 |
158 |
def is_decorated_with_make_executable(f: Callable[..., Any]) -> bool:
  """Checks whether the callable was wrapped by @executing.make_executable.

  Args:
    f: An arbitrary callable.

  Returns:
    True if the (possibly wrapped) callable carries the marker attribute set
    by `set_decorated_with_make_executable`.
  """
  # Unwrap any chain of functools.partial objects down to the real callable.
  target = f
  while isinstance(target, functools.partial):
    target = target.func

  # Callable objects (e.g., sub-classes of Agent) carry the marker on their
  # `__call__` attribute rather than on the object itself.
  is_plain_function = inspect.isfunction(target) or inspect.ismethod(target)
  if hasattr(target, '__call__') and not is_plain_function:
    target = target.__call__

  return getattr(target, 'decorated_with_make_executable', False)
178 |
179 |
def returns_awaitable(f: Callable[..., Any]) -> bool:
  """Returns whether the callable returns something that is awaitable.

  Note that while using this function is more reliable than calling
  `inspect.iscoroutinefunction` directly, there are still some cases that it
  does not catch, such as when `f` is manually implemented to return an
  Executable, without using the `@executing.make_executable` decorator. To
  catch these cases, it is recommended to use `call_and_maybe_await` where
  possible, rather than depending on `returns_awaitable`.

  Args:
    f: An arbitrary callable.
  """
  target = f
  # For callable objects (but not partials, which iscoroutinefunction can
  # unwrap itself), inspect the `__call__` attribute instead.
  is_callable_object = (
      hasattr(target, '__call__')
      and not inspect.isfunction(target)
      and not inspect.ismethod(target)
      and not isinstance(target, functools.partial)
  )
  if is_callable_object:
    target = target.__call__

  if inspect.iscoroutinefunction(target):
    return True
  return is_decorated_with_make_executable(target)
202 |
203 |
async def call_and_maybe_await(
    f: Callable[_Args, _Result | Awaitable[_Result] | Executable[_Result]],
    *args: _Args.args,
    **kwargs: _Args.kwargs,
) -> _Result:
  """Calls the callable and awaits the result if appropriate.

  Args:
    f: Callable to invoke; may be synchronous, a coroutine function, or a
      callable returning an Executable.
    *args: Positional arguments to pass to `f`.
    **kwargs: Keyword arguments to pass to `f`.

  Returns:
    The (awaited, where applicable) result of calling `f`.
  """
  # Fixed docstring typo: it previously opened with four quotes (`""""Calls`),
  # leaking a stray `"` into the docstring text.
  if returns_awaitable(f):
    result = await f(*args, **kwargs)
  else:
    result = f(*args, **kwargs)
  # The below condition is needed in case `f` was manually implemented to
  # return an Executable, without using the `@executing.make_executable`
  # decorator.
  if isinstance(result, Executable):
    result = await result
  return result
--------------------------------------------------------------------------------
/onetwo/agents/critics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Interfaces and utilities for scoring, ranking and selecting updates.
16 |
17 | These can be used in combination with agents to implement optimization
18 | algorithms.
19 | """
20 |
21 | import abc
22 | from collections.abc import Sequence
23 | from typing import Generic, TypeAlias, TypeVar
24 |
25 | from onetwo.agents import agents_base
26 | from onetwo.core import executing
27 |
28 |
# Type used to represent a state. States combine with updates via the `+`
# operator to form new states.
_S = TypeVar('_S')

# Type used to represent an incremental update of a state.
_U = TypeVar('_U')
34 |
35 |
class ScoringFunction(Generic[_S, _U], metaclass=abc.ABCMeta):
  """Interface for a scoring functions."""

  @executing.make_executable(copy_self=False)
  @abc.abstractmethod
  async def __call__(self, state: _S, update: _U) -> float:
    """Returns an absolute score of the update given the current state.

    The score represents the quality of the pair (state, update), i.e. how good
    the state after update, which means how good is the result of doing
    `state + update` (since updates can be added to states via the `+` operator
    to form new states).
    Ideally it should be comparable across different (state, update) pairs, but
    there could be cases where the starting state is identical across different
    updates (e.g. we are comparing different updates to the same state) in which
    case it is okay to return a score that does not take the state into account,
    but this is not recommended.

    Args:
      state: The current state.
      update: An incremental update of the state.

    Returns:
      A float representing the score of the state after update.
      The score is not normalized and can take any (positive or negative) value.
      Higher means better.
    """
    # Placeholder body; abstract method — subclasses must override.
    return 0.0
64 |
65 |
class RankingFunction(Generic[_S, _U], metaclass=abc.ABCMeta):
  """Interface for a ranking function."""

  @executing.make_executable(copy_self=False)
  @abc.abstractmethod
  async def __call__(
      self, states_and_updates: Sequence[tuple[_S, _U]]
  ) -> list[int]:
    """Ranks a list of states and updates and returns indices.

    The order is from best to worst.
    When considering a pair (state, update), the value of the pair corresponds
    to how good the state after update is, i.e. the value of `state + update`.
    The order should thus make sense even when comparing pairs with different
    starting states.

    Args:
      states_and_updates: A list of states and updates to be ranked.

    Returns:
      A list of indices corresponding to the ranking of the states and updates.
      The indices refer to the order in the input list.
    """
    # Placeholder body; abstract method — subclasses must override.
    return []
90 |
91 |
class SelectingFunction(Generic[_S, _U], metaclass=abc.ABCMeta):
  """Interface for a selecting function."""

  @executing.make_executable(copy_self=False)
  @abc.abstractmethod
  async def __call__(
      self, states_and_updates: Sequence[tuple[_S, _U]]
  ) -> int:
    """Returns the index of the best (state, update) pair.

    When considering a pair (state, update), the value of the pair corresponds
    to how good the state after update is, i.e. the value of `state + update`.
    The notion of "best" should thus make sense even when comparing pairs with
    different starting states.

    Args:
      states_and_updates: A list of states and updates to select from.

    Returns:
      The index of the selected pair in the input list.
    """
    # Placeholder body; abstract method — subclasses must override.
    return 0
114 |
115 |
def ranker_from_scorer(
    scorer: ScoringFunction[_S, _U]
) -> RankingFunction[_S, _U]:
  """Converts a ScoringFunction into a RankingFunction.

  The resulting RankingFunction scores every pair in parallel and returns the
  indices ordered by decreasing score (stable for ties).

  Args:
    scorer: A ScoringFunction.

  Returns:
    A RankingFunction.
  """
  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def ranker(states_and_updates: Sequence[tuple[_S, _U]]) -> list[int]:
    pending = [
        scorer(state, update)  # pytype: disable=wrong-arg-count
        for state, update in states_and_updates
    ]
    scores = await executing.par_iter(pending)
    # Indices sorted by score, highest first. `sorted` is stable, so ties
    # keep their original relative order.
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

  return ranker
139 |
140 |
141 | async def _select_k_best(
142 | states_and_updates: Sequence[tuple[_S, _U]],
143 | critic: SelectingFunction[_S, _U],
144 | k: int,
145 | ) -> list[int]:
146 | """Selects the k best states and updates."""
147 | current_list = list(states_and_updates)
148 | selected = []
149 | while len(selected) < k and current_list:
150 | if len(current_list) == 1:
151 | # If there is only one element left, we can stop the loop.
152 | selected.append(0)
153 | break
154 | else:
155 | best_index = await critic(current_list) # pytype: disable=wrong-arg-count
156 | current_list.pop(best_index)
157 | selected.append(best_index)
158 | return selected
159 |
160 |
def ranker_from_selector(
    selector: SelectingFunction[_S, _U]
) -> RankingFunction[_S, _U]:
  """Converts a SelectingFunction into a RankingFunction.

  The resulting RankingFunction is obtained by repeatedly asking the
  SelectingFunction for the best remaining state/update pair and removing it
  from the candidate list.

  Args:
    selector: A SelectingFunction.

  Returns:
    A RankingFunction.
  """

  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def ranker(states_and_updates: Sequence[tuple[_S, _U]]) -> list[int]:
    # Ranking all elements is equivalent to selecting the k best with k equal
    # to the number of candidates.
    num_candidates = len(states_and_updates)
    return await _select_k_best(states_and_updates, selector, num_candidates)

  return ranker
183 |
184 |
class ScoreFromUpdates(
    Generic[_S, _U], ScoringFunction[_S, agents_base.ScoredUpdate[_U]]
):
  """Scoring function that extracts the score directly from the update."""

  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def __call__(
      self, state: _S, update: agents_base.ScoredUpdate[_U]
  ) -> float:
    """Returns the update's own score; the state is ignored."""
    return update.score
195 |
196 |
# Type used to represent inputs.
_I = TypeVar('_I')

# Shorthand for an agent state consisting of a list of scored updates.
_ListOfScoredUpdates: TypeAlias = agents_base.UpdateListState[
    _I, agents_base.ScoredUpdate[_U]
]
# Shorthand for a single scored update.
_ScoredUpdate: TypeAlias = agents_base.ScoredUpdate[_U]
204 |
205 |
class ScoreFromUpdateList(
    Generic[_I, _U], ScoringFunction[_ListOfScoredUpdates, _ScoredUpdate]
):
  """Scoring function that extracts the score from an update list."""

  @executing.make_executable  # pytype: disable=wrong-arg-types
  async def __call__(
      self, state: _ListOfScoredUpdates, update: _ScoredUpdate
  ) -> float:
    """Sums the scores of the updates in the list and the new update.

    This scoring function can be used when for example we have an agent
    that associates a score to each of the updates it performs. In this case
    the score of a state can be the sum of the scores of all the updates.
    For example if we have a sequential distribution as the agent, it may
    assign a probability to each update, which we can convert (by taking the
    log) into a score that can be added across the whole sequence.

    Args:
      state: The current state, i.e. a list of scored updates.
      update: The latest scored update.

    Returns:
      The sum of the scores of all the updates in the list and the new update.
    """
    # Applying the update first so that its score is included in the total.
    combined_state = state + update
    return sum((scored.score for scored in combined_state.updates), 0.0)
236 |
237 |
--------------------------------------------------------------------------------
/onetwo/backends/onetwo_api.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Google OneTwo API connector.
16 |
17 | A Backend that connects to a OneTwo model server and exposes its
18 | functionalities.
19 | """
20 |
21 | import collections
22 | from collections.abc import Mapping, Sequence
23 | import dataclasses
24 | import json
25 | from typing import Any
26 | from onetwo.backends import backends_base
27 | from onetwo.builtins import llm
28 | from onetwo.core import batching
29 | from onetwo.core import caching
30 | from onetwo.core import content as content_lib
31 | from onetwo.core import tracing
32 | from onetwo.core import utils
33 | import requests
34 |
35 |
36 |
37 |
38 |
39 |
@batching.add_batching  # Methods of this class are batched.
@dataclasses.dataclass
class OneTwoAPI(
    caching.FileCacheEnabled,  # Methods of this class are cached.
    backends_base.Backend,
):
  """Google OneTwo API.

  Thin HTTP client for a OneTwo model server. Each builtin (generate_text,
  tokenize, count_tokens) maps to one POST endpoint on the server; replies are
  JSON-encoded. Requests are cached on disk and fanned out from a threadpool
  (the server has no native batching support).

  Attributes:
    disable_caching: Whether caching is enabled for this object (inherited from
      CacheEnabled).
    cache_filename: Name of the file (full path) where the cache is stored
      (inherited from FileCacheEnabled)
    endpoint: The address to connect to (typically some http endpoint).
    batch_size: Number of requests (generate_text or chat or generate_embedding)
      that is grouped together when sending them to OneTwo API. OneTwo API does
      not explicitly support batching (i.e. multiple requests can't be passed
      via arguments). Instead we send multiple requests from separate threads.
    enable_streaming: Whether to enable streaming replies from generate_text.
    max_qps: Maximum queries per second for the backend (if None, no rate
      limiting is applied).
    temperature: Temperature parameter (float) for LLM generation (can be set as
      a default and can be overridden per request).
    max_tokens: Maximum number of tokens to generate (can be set as a default
      and can be overridden per request).
    stop: Stop sequences (as a list of strings) for LLM text generation (can be
      set as a default and can be overridden per request).
    top_p: Top-p parameter (float) for LLM text generation (can be set as a
      default and can be overridden per request).
    top_k: Top-k parameter (int) for LLM text generation (can be set as a
      default and can be overridden per request).
  """
  endpoint: str = dataclasses.field(init=True, default_factory=str)
  batch_size: int = 1
  enable_streaming: bool = False
  max_qps: float | None = None

  # Generation parameters
  temperature: float | None = None
  max_tokens: int | None = None
  stop: Sequence[str] | None = None
  top_p: float | None = None
  top_k: int | None = None

  # Per-method call counters (e.g. for tests and debugging). Excluded from
  # __init__ since it is internal state.
  _counters: collections.Counter[str] = dataclasses.field(
      init=False, default_factory=collections.Counter
  )

  def register(self, name: str | None = None) -> None:
    """See parent class."""
    del name
    # Reset all the defaults in case some other backend was already registered.
    # Indeed, we rely on certain builtins configured with OneTwo defaults.
    llm.reset_defaults()
    llm.generate_text.configure(
        self.generate_text,
        temperature=self.temperature,
        max_tokens=self.max_tokens,
        stop=self.stop,
        top_p=self.top_p,
        top_k=self.top_k,
    )
    llm.count_tokens.configure(self.count_tokens)
    llm.tokenize.configure(self.tokenize)

  def __post_init__(self) -> None:
    """Creates the cache and verifies that the endpoint is healthy.

    Raises:
      ValueError: If the endpoint cannot be reached, or if it is reachable but
        reports an unhealthy status.
    """
    # Create cache.
    self._cache_handler = caching.SimpleFunctionCache(
        cache_filename=self.cache_filename,
    )
    # Check the health status of the endpoint. Only the network call itself is
    # inside the try-block: previously the "endpoint unhealthy" ValueError was
    # raised inside the try and then re-wrapped by the broad except clause as a
    # misleading "connection failed" error. We also catch only
    # RequestException so that programming errors are not masked.
    # NOTE(review): consider passing an explicit `timeout=` here so that an
    # unresponsive endpoint cannot hang construction forever — TODO confirm.
    try:
      response = requests.get(self.endpoint + '/health')
    except requests.exceptions.RequestException as err:
      raise ValueError(f'OneTwoAPI connection failed: {err}') from err
    if response.status_code != requests.codes.ok:
      raise ValueError(f'OneTwoAPI endpoint unhealthy: {response.text}')

  @tracing.trace(name='OneTwoAPI.generate_text')
  @caching.cache_method(  # Cache this method.
      name='generate_text',
      is_sampled=True,  # Two calls with same args may return different replies.
      cache_key_maker=lambda: caching.CacheKeyMaker(hashed=['prompt']),
  )
  @batching.batch_method_with_threadpool(
      batch_size=utils.FromInstance('batch_size'),
      wrapper=batching.add_logging,
  )
  def generate_text(
      self,
      prompt: str | content_lib.ChunkList,
      *,
      temperature: float | None = None,
      max_tokens: int | None = None,
      stop: Sequence[str] | None = None,
      top_k: int | None = None,
      top_p: float | None = None,
      include_details: bool = False,
      **kwargs,  # Optional server-specific arguments.
  ) -> str | tuple[str, Mapping[str, Any]]:
    """See builtins.llm.generate_text."""
    self._counters['generate_text'] += 1

    # The server expects a plain string prompt.
    if isinstance(prompt, content_lib.ChunkList):
      prompt = str(prompt)

    args = {
        'prompt': prompt,
        'temperature': temperature,
        'max_tokens': max_tokens,
        'stop': stop,
        'top_k': top_k,
        'top_p': top_p,
        'include_details': include_details,
    }
    args.update(kwargs)
    # TODO: Trace this external API call.
    # Use `json=` for consistency with the /tokenize and /count_tokens calls;
    # requests serializes the payload and sets the
    # 'Content-Type: application/json' header automatically.
    response = requests.post(
        self.endpoint + '/generate_text',
        json=args,
    )
    if response.status_code != requests.codes.ok:
      raise ValueError(f'OneTwoAPI /generate_text failed: {response.text}')
    # The reply is a JSON-encoded [text, details] pair; return just the text
    # unless the caller asked for details.
    response = json.loads(response.text)
    return response if include_details else response[0]

  @caching.cache_method(  # Cache this method.
      name='tokenize',
      is_sampled=False,  # Deterministic: same content -> same tokens.
      # Hash the (potentially long) `content` argument. This previously said
      # 'prompt' (copy-paste from generate_text), which does not match any
      # parameter of this method, so the hashing never applied.
      cache_key_maker=lambda: caching.CacheKeyMaker(hashed=['content']),
  )
  @batching.batch_method_with_threadpool(
      batch_size=utils.FromInstance('batch_size'),
      wrapper=batching.add_logging,
  )
  def tokenize(
      self,
      content: str | content_lib.ChunkList,
  ) -> list[int]:
    """See builtins.llm.tokenize."""
    self._counters['tokenize'] += 1

    # The server expects a plain string.
    if isinstance(content, content_lib.ChunkList):
      content = str(content)

    # TODO: Trace this external API call.
    response = requests.post(
        self.endpoint + '/tokenize',
        json={
            'content': content,
        },
    )
    if response.status_code != requests.codes.ok:
      raise ValueError(f'OneTwoAPI /tokenize failed: {response.text}')
    response = json.loads(response.text)
    return response['result']

  @caching.cache_method(  # Cache this method.
      name='count_tokens',
      is_sampled=False,  # Deterministic: same content -> same count.
      # Hash the (potentially long) `content` argument. This previously said
      # 'prompt' (copy-paste from generate_text), which does not match any
      # parameter of this method, so the hashing never applied.
      cache_key_maker=lambda: caching.CacheKeyMaker(hashed=['content']),
  )
  @batching.batch_method_with_threadpool(
      batch_size=utils.FromInstance('batch_size'),
      wrapper=batching.add_logging,
  )
  def count_tokens(
      self,
      content: str | content_lib.ChunkList,
  ) -> int:
    """See builtins.llm.count_tokens."""
    self._counters['count_tokens'] += 1

    # The server expects a plain string.
    if isinstance(content, content_lib.ChunkList):
      content = str(content)

    # TODO: Trace this external API call.
    response = requests.post(
        url=self.endpoint + '/count_tokens',
        json={'content': content},
    )
    if response.status_code != requests.codes.ok:
      raise ValueError(
          f'OneTwoAPI /count_tokens failed: {response.text}'
      )
    response = json.loads(response.text)
    return response['result']
228 |
--------------------------------------------------------------------------------