├── colabs └── images │ ├── paligemma_fox.png │ ├── paligemma_puffin.png │ ├── README.md │ └── image_provenance.py ├── onetwo ├── backends │ ├── testdata │ │ └── bird.jpg │ ├── onetwo_api_test_script.sh │ ├── run_model_server.py │ ├── formatters_test.py │ ├── formatters.py │ ├── backends_base.py │ ├── onetwo_api_manual_test.py │ ├── openai_mock.py │ ├── model_server.py │ └── onetwo_api.py ├── version.py ├── __init__.py ├── core │ ├── constants.py │ ├── updating_test.py │ ├── core_test_utils.py │ ├── sampling_test.py │ ├── updating.py │ ├── routing.py │ ├── executing_with_context_test.py │ ├── executing_impl_test.py │ └── executing_impl.py ├── ot.py ├── builtins │ ├── callbacks_test.py │ ├── tool_use_test.py │ ├── tool_use.py │ ├── composables.py │ ├── prompt_templating.py │ └── llm_utils.py ├── stdlib │ ├── ensembling │ │ ├── distribution_metrics.py │ │ └── distribution_metrics_test.py │ └── code_execution │ │ ├── python_execution_test.py │ │ ├── python_execution_test_utils.py │ │ └── python_execution_utils.py └── agents │ ├── agents_test_utils_test.py │ ├── agents_test_utils.py │ ├── critics_test.py │ ├── tasks │ ├── game_of_24_test.py │ └── game_of_24.py │ ├── distribution_test.py │ └── critics.py ├── .gitignore ├── CONTRIBUTING.md ├── pyproject.toml ├── docs ├── faq.md └── basics.md └── README.md /colabs/images/paligemma_fox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/onetwo/HEAD/colabs/images/paligemma_fox.png -------------------------------------------------------------------------------- /colabs/images/paligemma_puffin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/onetwo/HEAD/colabs/images/paligemma_puffin.png -------------------------------------------------------------------------------- /onetwo/backends/testdata/bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/onetwo/HEAD/onetwo/backends/testdata/bird.jpg -------------------------------------------------------------------------------- /colabs/images/README.md: -------------------------------------------------------------------------------- 1 | # Image Assets 2 | 3 | This directory contains various image assets used in the Colab tutorial. 4 | 5 | The complete source and licensing information for every image file is documented 6 | programmatically in `image_provenance.py` within this directory. 7 | -------------------------------------------------------------------------------- /onetwo/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Current OneTwo version at head on GitHub.""" 16 | 17 | # A new GitHub release will be pushed everytime `__version__` is increased. 18 | # When changing this, also update the CHANGELOG.md. 19 | __version__ = '0.3.0' 20 | -------------------------------------------------------------------------------- /onetwo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """OneTwo API.""" 16 | 17 | # Do NOT add anything here !! 18 | # Indeed, top-level `__init__.py` makes it hard to import a specific sub-module 19 | # without triggering a full import of the codebase. 20 | # Instead, the public API is exposed in `ot.py`. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | *.sw[op] 3 | 4 | # C extensions 5 | *.so 6 | 7 | # Packages 8 | *.egg 9 | *.egg-info 10 | dist 11 | build 12 | eggs 13 | .eggs 14 | parts 15 | bin 16 | var 17 | sdist 18 | develop-eggs 19 | .installed.cfg 20 | lib 21 | lib64 22 | __pycache__ 23 | 24 | # Installer logs 25 | pip-log.txt 26 | 27 | # Unit test / coverage reports 28 | .coverage 29 | .nox 30 | .cache 31 | .pytest_cache 32 | 33 | 34 | # Mac 35 | .DS_Store 36 | 37 | # JetBrains 38 | .idea 39 | 40 | # VS Code 41 | .vscode 42 | 43 | # emacs 44 | *~ 45 | 46 | # Built documentation 47 | docs/_build 48 | bigquery/docs/generated 49 | docs.metadata 50 | 51 | # Virtual environment 52 | env/ 53 | venv/ 54 | 55 | # Test logs 56 | coverage.xml 57 | *sponge_log.xml 58 | 59 | # System test environment variables. 60 | system_tests/local_test_setup 61 | 62 | # Make sure a generated file isn't accidentally committed. 63 | pylintrc 64 | pylintrc.test 65 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | At this time we do not plan to accept non-trivial contributions. 4 | We encourage forking the repository and continued development (as permitted by 5 | the license). 6 | 7 | ## Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a Contributor License 10 | Agreement. You (or your employer) retain the copyright to your contribution, 11 | this simply gives us permission to use and redistribute your contributions as 12 | part of the project. Head over to to see 13 | your current agreements on file or to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows [Google's Open Source Community 29 | Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /onetwo/core/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Some constants used throughout the onetwo codebase.""" 16 | 17 | # Used when errors occur. 18 | ERROR_STRING = '#ERROR#' 19 | 20 | # Field in the caching key where the name of the cached function is stored. 21 | CACHING_FUNCTION_NAME_KEY = '_destination' 22 | 23 | # Prompt prefix. 24 | PROMPT_PREFIX = 'prefix' 25 | 26 | # The field of the Jinja context that contains variables. 27 | CONTEXT_VARS = '__vars__' 28 | 29 | # The context variable containing the execution result for a Jinja template. 30 | RESULT_VAR = '_result' 31 | 32 | # The fields in the Jinja context variables with results of `choose` command. 33 | CHOICES_VAR = 'choices' 34 | SCORES_VAR = 'scores' 35 | -------------------------------------------------------------------------------- /onetwo/backends/onetwo_api_test_script.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | 17 | # Run from the onetwo root dir that contains README.md and other files. 18 | ONETWO_ROOT_DIR="." 19 | START_SERVER_CMD="python3 ${ONETWO_ROOT_DIR}/onetwo/backends/run_model_server.py" 20 | START_CLIENT_CMD="python3 ${ONETWO_ROOT_DIR}/onetwo/backends/onetwo_api_manual_test.py" 21 | CACHE_DIR="${ONETWO_ROOT_DIR}/tmp" 22 | # Requires portpicker to be installed. 23 | PORT=`python3 -m portpicker $$` 24 | 25 | function clean_fail() { 26 | kill $!; 27 | exit 1 28 | } 29 | 30 | echo "Using port ${PORT}" 31 | 32 | set -o xtrace 33 | 34 | # Start model_server. 35 | ${START_SERVER_CMD} --port="${PORT}" & 36 | sleep 10 37 | 38 | # Start client and run test: from scratch. 39 | ${START_CLIENT_CMD} \ 40 | --endpoint="http://localhost:${PORT}" \ 41 | --cache_dir=${CACHE_DIR} & sleep 3 || clean_fail 42 | 43 | # Start client and run test: cached replies. 44 | ${START_CLIENT_CMD} \ 45 | --endpoint="http://localhost:${PORT}" --cache_dir=${CACHE_DIR}\ 46 | --load_cache_file=True || clean_fail 47 | 48 | # Stop model_server. 49 | kill $! 50 | -------------------------------------------------------------------------------- /onetwo/ot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Entry point to the OneTwo library. 16 | 17 | ```python 18 | from onetwo import ot 19 | ``` 20 | """ 21 | 22 | from onetwo import version 23 | from onetwo.core import composing 24 | from onetwo.core import executing 25 | from onetwo.core import results 26 | from onetwo.core import routing 27 | from onetwo.core import sampling 28 | from onetwo.evaluation import evaluation 29 | 30 | __version__: str = version.__version__ 31 | 32 | compare_with_critic = evaluation.compare_with_critic 33 | copy_registry = routing.copy_registry 34 | evaluate = evaluation.evaluate 35 | Executable = executing.Executable 36 | function_call = routing.function_registry 37 | function_registry = routing.function_registry 38 | HTMLRenderer = results.HTMLRenderer 39 | make_composable = composing.make_composable 40 | make_executable = executing.make_executable 41 | naive_comparison_critic = evaluation.naive_comparison_critic 42 | naive_evaluation_critic = evaluation.naive_evaluation_critic 43 | naive_fuzzy_evaluation_critic = evaluation.naive_fuzzy_evaluation_critic 44 | par_iter = executing.par_iter 45 | parallel = executing.parallel 46 | RegistryContext = routing.RegistryContext 47 | repeat = sampling.repeat 48 | run = executing.run 49 | safe_stream = executing.safe_stream 50 | set_registry = routing.set_registry 51 | stream_updates = executing.stream_updates 52 | stream_with_callback = executing.stream_with_callback 53 | with_current_registry = routing.with_current_registry 54 | with_registry = routing.with_registry 55 | 56 | -------------------------------------------------------------------------------- /colabs/images/image_provenance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Stores provenance information for images used in OneTwo Colab. 16 | 17 | This module contains metadata about images, including their external URLs and 18 | licensing information, primarily used for image provenance tracking. 19 | """ 20 | 21 | import dataclasses 22 | from typing import Dict 23 | 24 | 25 | @dataclasses.dataclass 26 | class ImageSource: 27 | """Stores provenance and licensing information for an image asset.""" 28 | # The external URL where the image was originally found. 29 | external_url: str = '' 30 | # The local path where the image is stored in the repository. 31 | local_path: str = '' 32 | # The license under which the image is distributed. 33 | license: str = '' 34 | 35 | 36 | # A constant mapping image names (keys) to their ImageSource metadata (values). 37 | image_sources: Dict[str, ImageSource] = { 38 | 'paligemma_fox': ImageSource( 39 | external_url='https://big-vision-paligemma.hf.space/file=/tmp/gradio/4aa2d3fd01a6308961397f68e043b2015bc91493/image.png', 40 | local_path='paligemma_fox.png', 41 | license=( 42 | 'CC0 by [XiaohuaZhai@](https://sites.google.com/corp/view/xzhai)' 43 | ), 44 | ), 45 | 'paligemma_puffin': ImageSource( 46 | external_url='https://big-vision-paligemma.hf.space/file=/tmp/gradio/78f93b49088f8d72ee546d656387403d647b413f/image.png', 47 | local_path='paligemma_puffin.png', 48 | license=( 49 | 'CC0 by [XiaohuaZhai@](https://sites.google.com/corp/view/xzhai)' 50 | ), 51 | ), 52 | } 53 | -------------------------------------------------------------------------------- /onetwo/backends/run_model_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Binary that starts a OneTwo Model Server (serves as an example/test).""" 16 | 17 | import argparse 18 | import importlib 19 | import json 20 | from typing import Final 21 | 22 | import uvicorn 23 | 24 | 25 | _DEFAULT_PORT = 8000 26 | _DEFAULT_BACKEND_MODULE: Final[str] = 'onetwo.backends.test_utils' 27 | _DEFAULT_BACKEND_CLASS: Final[str] = 'LLMForTest' 28 | _DEFAULT_BACKEND_ARGS: Final[str] = '{"default_reply": "Test reply"}' 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser('OneTwo Model Server.') 33 | parser.add_argument( 34 | '--port', type=int, default=_DEFAULT_PORT, help='Port to listen on.' 35 | ) 36 | parser.add_argument( 37 | '--backend_module', 38 | type=str, 39 | default=_DEFAULT_BACKEND_MODULE, 40 | help='Backend module to load.', 41 | ) 42 | parser.add_argument( 43 | '--backend_class', 44 | type=str, 45 | default=_DEFAULT_BACKEND_CLASS, 46 | help='Backend class to instantiate.', 47 | ) 48 | parser.add_argument( 49 | '--backend_args', 50 | type=str, 51 | default=_DEFAULT_BACKEND_ARGS, 52 | help='Arguments for the backend class constructor (in JSON format).', 53 | ) 54 | args = parser.parse_args() 55 | backend_module = importlib.import_module(args.backend_module) 56 | backend_class = getattr(backend_module, args.backend_class) 57 | backend_args = json.loads(args.backend_args) 58 | backend = backend_class(**backend_args) 59 | backend.register() 60 | 61 | uvicorn.run( 62 | 'onetwo.backends.model_server:ModelServer', 63 | host='0.0.0.0', 64 | port=args.port, 65 | factory=True, 66 | ) 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "onetwo" 3 | description = "LLM Orchestration Library by Google DeepMind." 4 | readme = "README.md" 5 | requires-python = ">=3.10" 6 | license = {file = "LICENSE"} 7 | authors = [{name = "OneTwo Authors", email="no-reply@google.com"}] 8 | classifiers = [ 9 | "Programming Language :: Python :: 3", 10 | "Programming Language :: Python :: 3 :: Only", 11 | "License :: OSI Approved :: Apache Software License", 12 | "Intended Audience :: Science/Research", 13 | ] 14 | keywords = [] 15 | 16 | # pip dependencies of the project. 17 | dependencies = [ 18 | "absl-py", 19 | "aenum", 20 | "dataclasses-json", 21 | "fastapi", 22 | "freezegun", 23 | # Note: there is a PYPI `gemma` library, which is not what we want. 24 | "gemma@git+https://github.com/google-deepmind/gemma.git", 25 | "google-cloud-aiplatform", 26 | "google-generativeai", 27 | "html5lib", 28 | "immutabledict", 29 | "jinja2", 30 | "numpy", 31 | "openai", 32 | "pillow", 33 | "pytest", 34 | "portpicker", # For `backends/onetwo_api_test_script.sh`. 35 | "pyyaml", # For `import yaml`. 36 | "termcolor", 37 | "tqdm", 38 | "typing_extensions", 39 | "uvicorn", 40 | ] 41 | 42 | # This is set automatically by flit using `onetwo.__version__`. 43 | dynamic = ["version"] 44 | 45 | [project.urls] 46 | homepage = "https://github.com/google-deepmind/onetwo" 47 | repository = "https://github.com/google-deepmind/onetwo" 48 | 49 | [project.optional-dependencies] 50 | # Installed through `pip install '.[dev]'`. 51 | dev = [ 52 | "pytest-xdist", 53 | "pylint>=2.6.0", 54 | "pyink", 55 | ] 56 | 57 | # Installed through `pip install '.[docs]'`. 58 | docs = [ 59 | # Install `apitree` with all extensions (sphinx, theme,...) 60 | "sphinx-apitree[ext]", 61 | ] 62 | 63 | [tool.pytest.ini_options] 64 | addopts = [ 65 | "--import-mode=importlib", 66 | ] 67 | 68 | [tool.pyink] 69 | # Formatting configuration to follow Google style-guide. 70 | line-length = 80 71 | preview = true 72 | pyink-indentation = 2 73 | pyink-use-majority-quotes = true 74 | 75 | [build-system] 76 | # Build system specify which backend is used to build/install the project (flit, 77 | # poetry, setuptools, ...). All backends are supported by `pip install`. 78 | requires = ["flit_core >=3.8,<4"] 79 | build-backend = "flit_core.buildapi" 80 | 81 | [tool.flit.sdist] 82 | # Flit specific options (files to exclude from the PyPI package). 83 | # If using another build backend (setuptools, poetry), you can remove this 84 | # section. 85 | exclude = [ 86 | # Do not release tests files on PyPI. 87 | "**/*_test.py", 88 | ] 89 | -------------------------------------------------------------------------------- /onetwo/backends/formatters_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit tests for formatting.""" 16 | import asyncio 17 | from collections.abc import Sequence 18 | from typing import TypeAlias 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | # Necessary for the FormatterName enum to be populated. 23 | from onetwo.backends import formatters # pylint: disable=unused-import 24 | from onetwo.builtins import formatting 25 | from onetwo.core import content as content_lib 26 | 27 | 28 | _Message: TypeAlias = content_lib.Message 29 | _PredefinedRole: TypeAlias = content_lib.PredefinedRole 30 | 31 | 32 | class FormattersTest(parameterized.TestCase): 33 | 34 | @parameterized.named_parameters( 35 | ( 36 | 'gemma_user_only', 37 | formatting.FormatterName.GEMMA, 38 | [_Message(role=_PredefinedRole.USER, content='Hello')], 39 | 'user\nHello', 40 | ), 41 | ( 42 | 'gemma_user_and_model', 43 | formatting.FormatterName.GEMMA, 44 | [ 45 | _Message(role=_PredefinedRole.USER, content='Hello'), 46 | _Message(role=_PredefinedRole.MODEL, content='What'), 47 | ], 48 | 'user\nHello\nmodel\nWhat', 49 | ), 50 | ( 51 | 'gemma_user_and_empty_model', 52 | formatting.FormatterName.GEMMA, 53 | [ 54 | _Message(role=_PredefinedRole.USER, content='Hello'), 55 | _Message(role=_PredefinedRole.MODEL, content=''), 56 | ], 57 | 'user\nHello\nmodel', 58 | ), 59 | ) 60 | def test_format( 61 | self, 62 | formatter: formatting.FormatterName, 63 | messages: Sequence[_Message], 64 | expected: str, 65 | ): 66 | async def wrapper(): 67 | formatter_class = formatting.FORMATTER_CLASS_BY_NAME[formatter] 68 | formatter_instance = formatter_class() # pytype: disable=not-instantiable 69 | result = formatter_instance.format(messages) 70 | return result 71 | 72 | result = asyncio.run(wrapper()) 73 | self.assertEqual(str(result), expected) 74 | 75 | 76 | if __name__ == '__main__': 77 | absltest.main() 78 | -------------------------------------------------------------------------------- /onetwo/builtins/callbacks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Sequence 16 | from typing import TypeVar 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from onetwo.builtins import callbacks 21 | from onetwo.builtins import llm 22 | from onetwo.core import content as content_lib 23 | from onetwo.core import executing 24 | from onetwo.core import templating 25 | 26 | 27 | _T = TypeVar('_T') 28 | 29 | 30 | Chunk = content_lib.Chunk 31 | ChunkList = content_lib.ChunkList 32 | 33 | 34 | # TODO: For now all the tests are in prompt_templating_test.py. 35 | # We should move them here. 36 | class CallbacksTest(parameterized.TestCase): 37 | 38 | def setUp(self): 39 | super().setUp() 40 | 41 | # This class tests various `llm` builtins. In case `import llm` is not 42 | # executed (this may happen when running `pytest` with multiple tests that 43 | # import `llm` module) various builtins from `llm` may be already configured 44 | # elsewhere in unexpected ways. We manually reset all the default builtin 45 | # implementations to make sure they are set properly. 46 | llm.reset_defaults() 47 | 48 | def generate( 49 | prompt: str | ChunkList, 50 | *, 51 | temperature: float | None = None, 52 | max_tokens: int | None = None, 53 | stop: Sequence[str] | None = None, 54 | top_k: int | None = None, 55 | top_p: float | None = None, 56 | ) -> str: 57 | del prompt, temperature, max_tokens, stop, top_k, top_p 58 | return ' done' 59 | 60 | def score( 61 | prompt: str | ChunkList, targets: Sequence[str] 62 | ) -> Sequence[float]: 63 | del prompt 64 | # We score by the length of the target. 65 | return [float(len(target)) for target in targets] 66 | 67 | llm.generate_text.configure(generate) 68 | llm.score_text.configure(score) 69 | 70 | def test_generate_text(self): 71 | tpl = templating.JinjaTemplate(text='{{ generate_text() }}') 72 | tpl.register_callback( 73 | 'generate_text', callbacks.generate_text, pass_context=True 74 | ) 75 | res = executing.run(tpl.render()) 76 | self.assertEqual(res['prefix'], ' done') 77 | 78 | 79 | if __name__ == '__main__': 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /onetwo/stdlib/ensembling/distribution_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Metric functions taking a distribution of model predictions as input. 16 | 17 | These can be used, for example, for evaluating the outputs of a SelfConsistency 18 | strategy. 19 | """ 20 | 21 | from collections.abc import Sequence 22 | import dataclasses 23 | from typing import Generic, TypeVar 24 | from onetwo.core import executing 25 | from onetwo.core import tracing 26 | from onetwo.core import utils 27 | from onetwo.evaluation import agent_evaluation 28 | 29 | # Type representing a strategy's output. 30 | _O = TypeVar('_O') 31 | 32 | 33 | @dataclasses.dataclass 34 | class AccuracyAtK(Generic[_O]): 35 | """Returns the maximum accuracy among the first k predictions. 36 | 37 | Attributes: 38 | k: Number of candidate answers to consider for accuracy. If None, then will 39 | consider all predictions (even those with probability 0). If non-None, 40 | then will only consider the first k. 41 | base_metric: Metric to use for calculating the accuracy of an individual 42 | prediction vs. the target. Defaults to exact-match accuracy. Bigger is 43 | assumed to mean better. 44 | """ 45 | 46 | k: int | None = 1 47 | base_metric: agent_evaluation.MetricFunction = lambda t, p: 1.0 * (t == p) 48 | 49 | @executing.make_executable(copy_self=False) 50 | @tracing.trace(name=utils.FROM_INSTANCE_CLASS_NAME) 51 | async def __call__( 52 | self, target: _O, prediction: Sequence[tuple[_O, float]] 53 | ) -> float: 54 | """Returns accuracy at k (value from 0.0 to 1.0). 55 | 56 | Args: 57 | target: The target (single answer). 58 | prediction: The predicted distribution over possible answers. 59 | """ 60 | k = self.k if self.k is not None else len(prediction) 61 | if k < 0: 62 | raise ValueError(f'k must be non-negative, got {k}') 63 | 64 | # For simplicity, we are currently calling `self.base_metric` sequentially 65 | # for each of the predictions. In the case where `self.base_metric` is an 66 | # `async` function, though, this could potentially be optimized in the 67 | # future by wrapping the calls to the base metric with `executing.parallel`. 68 | base_metric_values = [ 69 | await utils.call_and_maybe_await(self.base_metric, target, x) 70 | for x, _ in prediction[:k] 71 | ] 72 | return max(base_metric_values, default=0.0) 73 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # OneTwo FAQ 2 | 3 | ## Is there a way to see actual formatted prompts that are sent to the backends? 4 | 5 | While some of the builtin functions (e.g., `llm.generate_text`) apply little to 6 | no modification to the prompt before it is sent to the model, others (e.g., 7 | `llm.instruct` or `llm.chat`) may apply a lot of formatting. For example, 8 | one natural way of implementing `llm.instruct` for a model that is only 9 | pre-trained (PT) but not instruction-tuned (IT) is to first format the 10 | user-provided task, e.g. `'Write me a short poem'`, into a longer prompt similar 11 | to `'Here is the task: Write me a short poem. Here is the answer:'` and then 12 | send it to pure completion with `llm.generate_text`. Indeed, this is precisely 13 | how the default implementation of `llm.instruct` 14 | (`onetwo.builtins.default_instruct`) works. 15 | 16 | In cases like this, where a non-trivial formatting of prompts takes place, the 17 | user may naturally want to see the actual fully formatted prompt that is sent to 18 | a model. A simple way to do it, which often comes handy when debugging, is to 19 | configure (mock) `llm.generate_text` with a fake implementation that simply 20 | returns the prompt (for convenience we provide such an implementation in 21 | `onetwo.builtins.echo_generate_text`): 22 | 23 | ```python 24 | import ot 25 | from onetwo.builtins import llm 26 | backend = ... 27 | # Assume this backend only configures `llm.generate_text`, i.e. `llm.instruct` 28 | # is configured to use the default OneTwo implementation. 29 | backend.register() 30 | print(ot.run(llm.generate_text('Once upon a'))) 31 | print(ot.run(llm.instruct('Name three cities in France.'))) 32 | 33 | def fake_generate_text(prompt: str | content_lib.ChunkList, **kwargs): 34 | return prompt 35 | 36 | # Alternatively, use `onetwo.builtins.echo_generate_text`. 37 | llm.generate_text.configure(fake_generate_text) 38 | # Now `llm.generate_text` simply returns its input. 39 | 40 | assert ot.run(llm.generate_text('Once upon a')) == 'Once upon a' 41 | # `llm.instruct` formats the prompt and sends it to `llm.generate_text`, 42 | # which returns the formatted prompt. 43 | print(ot.run(llm.instruct('Name three cities in France.'))) 44 | # We should get something like 'Task: Name three cities in France.\n Answer:'. 45 | 46 | backend.register() 47 | # Now `llm.generate_text` points again to the backend implementation. 48 | ``` 49 | 50 | This approach assumes that you know exactly where the first call to the 51 | *external model API* happens (e.g., `llm.generate_text` in the example above) so 52 | that you can mock it. 53 | 54 | In the future we plan to introduce a more principled and unified way of doing 55 | this. 56 | Likely we will base it on the `onetwo.core.tracing.trace` decorator that is 57 | already available in OneTwo (refer to the "Agents and Tool Use" section of our 58 | [Colab](https://colab.research.google.com/github/google-deepmind/onetwo/blob/main/colabs/tutorial.ipynb) 59 | for more details on tracing with OneTwo). 60 | -------------------------------------------------------------------------------- /onetwo/stdlib/code_execution/python_execution_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from onetwo.stdlib.code_execution import python_execution 18 | 19 | _ExecutionStatus = python_execution.ExecutionStatus 20 | _SandboxResult = python_execution.SandboxResult 21 | 22 | 23 | class SandboxResultTest(parameterized.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ( 27 | 'success_no_output', 28 | python_execution.SandboxResult(), 29 | 'None', 30 | ), 31 | ( 32 | 'success_stdout_only', 33 | python_execution.SandboxResult(stdout='Hi'), 34 | 'Hi', 35 | ), 36 | ( 37 | 'success_final_expression_value_only', 38 | python_execution.SandboxResult(final_expression_value=2), 39 | '2', 40 | ), 41 | ( 42 | 'success_stdout_and_final_expression_value', 43 | python_execution.SandboxResult(final_expression_value=2, stdout='Hi'), 44 | 'Hi\n2', 45 | ), 46 | ( 47 | 'error_no_output', 48 | python_execution.SandboxResult( 49 | execution_status=_ExecutionStatus.EXECUTION_ERROR, 50 | status_message='Error message.', 51 | ), 52 | 'EXECUTION_ERROR: Error message.', 53 | ), 54 | ( 55 | 'error_stdout_only', 56 | python_execution.SandboxResult( 57 | stdout='Hi', 58 | execution_status=_ExecutionStatus.EXECUTION_ERROR, 59 | status_message='Error message.', 60 | ), 61 | 'Hi\nEXECUTION_ERROR: Error message.', 62 | ), 63 | ( 64 | 'error_final_expression_value_only', 65 | python_execution.SandboxResult( 66 | final_expression_value=2, 67 | execution_status=_ExecutionStatus.EXECUTION_ERROR, 68 | status_message='Error message.', 69 | ), 70 | '2\nEXECUTION_ERROR: Error message.', 71 | ), 72 | ( 73 | 'error_stdout_and_final_expression_value', 74 | python_execution.SandboxResult( 75 | final_expression_value=2, 76 | stdout='Hi', 77 | execution_status=_ExecutionStatus.EXECUTION_ERROR, 78 | status_message='Error message.', 79 | ), 80 | 'Hi\n2\nEXECUTION_ERROR: Error message.', 81 | ), 82 | ) 83 | def test_to_string(self, result, expected_string): 84 | self.assertEqual(expected_string, str(result)) 85 | 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /onetwo/builtins/tool_use_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from onetwo.builtins import tool_use 20 | from onetwo.core import executing 21 | from onetwo.core import routing 22 | 23 | 24 | def _add_sync(arg1: Any, arg2: Any) -> Any: 25 | return arg1 + arg2 26 | 27 | 28 | async def _add_async(arg1: Any, arg2: Any) -> Any: 29 | return arg1 + arg2 30 | 31 | 32 | @executing.make_executable # pytype: disable=wrong-arg-types 33 | def _add_executable(arg1: Any, arg2: Any) -> Any: 34 | return arg1 + arg2 35 | 36 | 37 | class ToolUseTest(parameterized.TestCase): 38 | 39 | def setUp(self): 40 | super().setUp() 41 | 42 | # This class tests routing.function_registry. In case `import routing` is 43 | # not executed (this may happen when running `pytest` with multiple tests 44 | # that import `llm` module) the `function_registry` may be already filled 45 | # with various functions elsewhere in unexpected ways. We manually remove 46 | # all the keys to make sure it is empty. 47 | routing.function_registry.clear() 48 | # Unfortunately, we also removed all the builtins configured when importing 49 | # tool_use. Let's re-set them. 50 | # TODO:` or such 51 | # for better control of reproducibility. 52 | tool_use.reset_defaults() 53 | 54 | @parameterized.named_parameters( 55 | ('sync', _add_sync), 56 | ('async', _add_async), 57 | ('executable', _add_executable), 58 | ) 59 | def test_default_run_tool_function_types(self, add_function): 60 | routing.function_registry['add'] = add_function 61 | result = executing.run( 62 | tool_use.run_tool( # pytype: disable=wrong-keyword-args 63 | tool_name='add', tool_args=['a', 'b'], tool_kwargs={} 64 | ) 65 | ) 66 | self.assertEqual('ab', result) 67 | 68 | @parameterized.named_parameters( 69 | ('positional_args_numeric', [1, 2], {}, 3), 70 | ('keyword_args_numeric', [], {'arg1': 1, 'arg2': 2}, 3), 71 | ('positional_args_string', ['1', '2'], {}, '12'), 72 | ) 73 | def test_default_run_tool_args_types(self, args, kwargs, expected_result): 74 | def add(arg1: Any, arg2: Any) -> Any: 75 | return arg1 + arg2 76 | 77 | routing.function_registry['add'] = add 78 | result = executing.run( 79 | tool_use.run_tool( # pytype: disable=wrong-keyword-args 80 | tool_name='add', tool_args=args, tool_kwargs=kwargs 81 | ) 82 | ) 83 | self.assertEqual(expected_result, result) 84 | 85 | if __name__ == '__main__': 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /onetwo/stdlib/ensembling/distribution_metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from onetwo.core import executing 18 | from onetwo.stdlib.ensembling import distribution_metrics 19 | 20 | 21 | # Simple custom base metric that gives full credit if the correct answer is 22 | # included anywhere in the prediction (e.g., if target is `12` and the 23 | # prediction is 'The answer is 12.') 24 | def answer_included_sync(target: str, prediction: str): 25 | return 1.0 if target in prediction else 0.0 26 | 27 | 28 | async def answer_included_async(target: str, prediction: str): 29 | return answer_included_sync(target, prediction) 30 | 31 | 32 | @executing.make_executable # pytype: disable=wrong-arg-types 33 | def answer_included_executable(target: str, prediction: str): 34 | return answer_included_sync(target, prediction) 35 | 36 | 37 | class DistributionMetricsTest(parameterized.TestCase): 38 | 39 | @parameterized.named_parameters( 40 | ('correct_at_1', 'a', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 1, 1.0), 41 | ('wrong_at_1', 'b', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 1, 0.0), 42 | ('correct_at_2', 'b', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 2, 1.0), 43 | ('wrong_at_2', 'c', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 2, 0.0), 44 | ('case_sensitive', 'A', [('a', 0.5), ('b', 0.3), ('c', 0.2)], 1, 0.0), 45 | ('empty_distribution', 'a', [], 1, 0.0), 46 | ('number_correct', 1.0, [(1, 0.5), (2, 0.3), (3, 0.2)], 1, 1.0), 47 | ('number_wrong', 2.0, [(1, 0.5), (2, 0.3), (3, 0.2)], 1, 0.0), 48 | ) 49 | def test_accuracy_at_k_default_base_metric( 50 | self, target, predicted_distribution, k, expected 51 | ): 52 | metric = distribution_metrics.AccuracyAtK(k=k) 53 | actual = executing.run( 54 | metric(target=target, prediction=predicted_distribution) # pytype: disable=wrong-keyword-args 55 | ) 56 | self.assertEqual(expected, actual) 57 | 58 | @parameterized.named_parameters( 59 | ('sync', answer_included_sync), 60 | ('async', answer_included_async), 61 | ('executable', answer_included_executable), 62 | ) 63 | def test_accuracy_at_k_custom_base_metric(self, base_metric): 64 | metric = distribution_metrics.AccuracyAtK( 65 | k=2, base_metric=base_metric 66 | ) 67 | 68 | actual = executing.run( 69 | metric(target='b', prediction=[('=a', 0.5), ('=b', 0.3), ('=c', 0.2)]) # pytype: disable=wrong-keyword-args 70 | ) 71 | with self.subTest('correct_at_2'): 72 | self.assertEqual(1.0, actual, actual) 73 | 74 | actual = executing.run( 75 | metric(target='c', prediction=[('=a', 0.5), ('=b', 0.3), ('=c', 0.2)]) # pytype: disable=wrong-keyword-args 76 | ) 77 | with self.subTest('wrong_at_2'): 78 | self.assertEqual(0.0, actual, actual) 79 | 80 | 81 | if __name__ == '__main__': 82 | absltest.main() 83 | -------------------------------------------------------------------------------- /onetwo/backends/formatters.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of various formatters for different models.""" 16 | 17 | from collections.abc import Sequence 18 | from typing import Final, TypeAlias 19 | 20 | import aenum 21 | from onetwo.builtins import formatting 22 | from onetwo.core import content as content_lib 23 | 24 | 25 | _PredefinedRole = content_lib.PredefinedRole 26 | _Message: TypeAlias = content_lib.Message 27 | _Chunk: TypeAlias = content_lib.Chunk 28 | _ChunkList: TypeAlias = content_lib.ChunkList 29 | _FormatterName = formatting.FormatterName 30 | _Formatter = formatting.Formatter 31 | 32 | 33 | _START_OF_TURN: Final[str] = '' 34 | _END_OF_TURN: Final[str] = '' 35 | 36 | 37 | class GemmaFormatter(_Formatter): 38 | """Gemma formatter for instruction tuned models.""" 39 | 40 | @property 41 | def role_map(self) -> dict[str | _PredefinedRole, str]: 42 | """Returns a mapping from role to string representation. 43 | 44 | This serves two purposes: 45 | - It specifies which roles are supported by the formatter. Any role that is 46 | not in this map will be ignored or raise an error if 47 | `raise_error_if_unsupported_roles=True` in the call to `format`. 48 | - It allows to specify how the role name is represented in the prompt. 49 | """ 50 | return { 51 | _PredefinedRole.USER: 'user', 52 | _PredefinedRole.MODEL: 'model' 53 | } 54 | 55 | def is_role_supported(self, role: str| _PredefinedRole) -> bool: 56 | """Overridden from base class (Formatter).""" 57 | return role in self.role_map 58 | 59 | def is_already_formatted(self, content: Sequence[_Message]) -> bool: 60 | """Overridden from base class (Formatter).""" 61 | if not content: 62 | return False 63 | return any( 64 | _START_OF_TURN in str(message.content) 65 | for message in content 66 | ) 67 | 68 | def extra_stop_sequences(self) -> list[str]: 69 | """Overridden from base class (Formatter).""" 70 | return [_START_OF_TURN] 71 | 72 | def _format( 73 | self, 74 | content: Sequence[_Message], 75 | ) -> _ChunkList: 76 | """Overridden from base class (Formatter).""" 77 | res = [] 78 | for i, message in enumerate(content): 79 | role_name = self.role_map[message.role] 80 | if i == len(content) - 1 and message.role == _PredefinedRole.MODEL: 81 | if (str(message.content)): 82 | res.append(f'{_START_OF_TURN}{role_name}\n{message.content}') 83 | else: 84 | res.append(f'{_START_OF_TURN}{role_name}') 85 | else: 86 | res.append( 87 | f'{_START_OF_TURN}{role_name}\n{message.content}{_END_OF_TURN}' 88 | ) 89 | res = '\n'.join(res) 90 | return _ChunkList([_Chunk(res)]) 91 | 92 | 93 | aenum.extend_enum(formatting.FormatterName, 'GEMMA', 'gemma') 94 | formatting.FORMATTER_CLASS_BY_NAME[_FormatterName.GEMMA] = GemmaFormatter 95 | -------------------------------------------------------------------------------- /onetwo/core/updating_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import copy 18 | from typing import TypeAlias 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from onetwo.core import updating 23 | 24 | 25 | Update = updating.Update 26 | ListUpdate = updating.ListUpdate 27 | 28 | _S: TypeAlias = updating.AddableUpdate[str] 29 | 30 | 31 | class UpdatingTest(parameterized.TestCase): 32 | 33 | def test_list_update_accumulate(self): 34 | u1 = Update('a') 35 | u2 = Update('b') 36 | l1 = ListUpdate([(u1, 0)]) 37 | l2 = ListUpdate([(u2, 0)]) 38 | l3 = copy.deepcopy(l2) 39 | 40 | with self.subTest('start_from_scratch'): 41 | self.assertEqual(Update() + l2, l3) 42 | 43 | with self.subTest('should_replace'): 44 | self.assertEqual(l1 + l2, l3) 45 | 46 | def test_plus_and_sum(self): 47 | list_updates1 = (ListUpdate([(i, i)]) for i in range(3)) 48 | list_updates2 = (ListUpdate([(i, i)]) for i in range(3)) 49 | update = Update() 50 | for u in list_updates1: 51 | update += u 52 | final = sum(list_updates2, start=Update()) 53 | self.assertEqual(update, final) 54 | self.assertEqual(final.to_result(), [0, 1, 2]) 55 | 56 | @parameterized.named_parameters( 57 | ('empty', ListUpdate([]), []), 58 | ('singleton', ListUpdate([(Update('a'), 0)]), ['a']), 59 | ( 60 | 'non_contiguous', 61 | ListUpdate([ 62 | (Update('a'), 0), 63 | (Update('b'), 2), 64 | (Update('c'), 4), 65 | ]), 66 | ['a', None, 'b', None, 'c'], 67 | ), 68 | ( 69 | 'nested', 70 | ListUpdate([ 71 | (ListUpdate([(Update('a'), 1)]), 1), 72 | ]), 73 | [None, [None, 'a']], 74 | ), 75 | ( 76 | 'nested_another_example', 77 | ListUpdate([ 78 | (ListUpdate([(10, 1)]), 0), 79 | (ListUpdate([(20, 0)]), 2), 80 | ]), 81 | [[None, 10], None, [20,]], 82 | ), 83 | ( 84 | 'nested_deeper', 85 | ListUpdate([ 86 | ( 87 | ListUpdate([ 88 | (ListUpdate([('ab', 1)]), 1) 89 | ]), 90 | 0 91 | ), 92 | ( 93 | ListUpdate([ 94 | (ListUpdate([('cd', 0)]), 1) 95 | ]), 96 | 2 97 | ), 98 | ]), 99 | [[None, [None, 'ab']], None, [None, ['cd']]], 100 | ), 101 | ) 102 | def test_list_update_to_accumulate(self, list_update, result): 103 | self.assertListEqual(list_update.to_result(), result) 104 | 105 | def test_nested(self): 106 | u1 = Update('a') 107 | u2 = Update(u1) 108 | self.assertEqual(u1.to_result(), u2.to_result()) 109 | 110 | def test_addable(self): 111 | l = _S('a') 112 | self.assertEqual(l.to_result(), 'a') 113 | l += 'b' 114 | self.assertEqual(l.to_result(), 'ab') 115 | l += _S('c') + 'd' 116 | self.assertEqual(l.to_result(), 'abcd') 117 | 118 | 119 | if __name__ == '__main__': 120 | absltest.main() 121 | -------------------------------------------------------------------------------- /onetwo/backends/backends_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines the base Backend class.""" 16 | 17 | import dataclasses 18 | 19 | 20 | @dataclasses.dataclass 21 | class Backend: 22 | """Interface for a class that registers its method. 23 | 24 | Attributes: 25 | name: An optional name for this backend instance. 26 | """ 27 | 28 | # A name for this backend instance. 29 | name: str = dataclasses.field( 30 | # We set `init=False` to prevent `name` from becoming an `__init__` 31 | # argument with a default value. If `name` were an `__init__` 32 | # argument, subclasses would not be able to define positional 33 | # arguments, because Python does not allow positional arguments to 34 | # follow arguments with default values. 35 | init=False, 36 | default='', 37 | ) 38 | 39 | # Name under which the methods are registered by default. 40 | _default_name: str = dataclasses.field( 41 | init=False, 42 | default='default', 43 | ) 44 | 45 | # Not an abstract method since by default it is ok not to register anything. 46 | def register(self, name: str | None = None): 47 | """Add the relevant methods to the registry. 48 | 49 | It is necessary for a child class to override this method in order to 50 | specify which methods it wants to expose/register. 51 | In particular this can be used to call `configure` or `get_variant` on 52 | builtins or to directly register methods of this object under particular 53 | names. 54 | For example the body of this method can look like: 55 | ``` 56 | registry = routing.function_registry 57 | # Register self.my_fn with the provided or default name prepended. 58 | registry[f'{name or self._default_name}.my_fn'] = self.my_fn 59 | # Configure some builtin. 60 | llm.generate_text.configure( 61 | self.my_generate_text, 62 | temperature=default_temperature 63 | ) 64 | # Configure a variant of the builtin with a different default parameter. 65 | variant = llm.generate_text.get_variant( 66 | self.my_generate_text 67 | temperature=0.0 68 | ) 69 | # Register the variant under a special name. 70 | registry['llm.zero_temp'] = variant 71 | ``` 72 | 73 | Args: 74 | name: An optional argument to use as prefix of the names in the registry, 75 | if None, the _default_name attribute may be used (it can be set in a 76 | child class declaration). 77 | """ 78 | 79 | 80 | def truncate_reply(reply: str, stop_sequences: list[str]) -> str: 81 | """Returns the truncated reply not including any stop sequence. 82 | 83 | Example: 84 | If reply is 'abcdefg' and stop_sequences is ['f', 'c'], then the truncated 85 | reply is 'ab'. 86 | 87 | Args: 88 | reply: The original reply to be truncated. 89 | stop_sequences: The substrings to truncate the reply at. The stop sequences 90 | are not included in the truncated reply. 91 | """ 92 | if not stop_sequences: 93 | return reply 94 | 95 | # Some stop sequences may be overlapping, and we want to guarantee that we 96 | # have the shortest possible prefix after truncation. 97 | # So if we have a reply 'abcdefg' and use as stop sequences ['f', 'def'] 98 | # we want to return 'abc'. This means we have to compute separately the 99 | # truncation for each stop sequence and pick the smallest obtained prefix. 100 | truncated_replies = set() 101 | for stop_sequence in stop_sequences: 102 | truncated_replies.add(reply.split(stop_sequence)[0]) 103 | 104 | # min will sort alphabetically, which, since we only have prefixes of a common 105 | # string, will result in returning the shortest one. 106 | return min(truncated_replies) 107 | -------------------------------------------------------------------------------- /onetwo/builtins/tool_use.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Definitions of built-in functions and utilities for tool use. 16 | 17 | A tool can be an arbitrary function that we want to expose to the LLM so that 18 | it can call it. 19 | """ 20 | 21 | from typing import Any 22 | 23 | from onetwo.builtins import builtins_base 24 | from onetwo.core import constants 25 | from onetwo.core import executing 26 | from onetwo.core import routing 27 | from onetwo.core import utils 28 | 29 | 30 | @builtins_base.Builtin 31 | def run_tool( 32 | tool_name: str, tool_args: tuple[Any, ...], tool_kwargs: dict[str, Any] 33 | ) -> Any: 34 | """Interface of the run_tool built-in function. 35 | 36 | Runs a tool and returns the result. 37 | 38 | Args: 39 | tool_name: The name of the tool to run. Depending on the implementation, 40 | this could either be the name of a function in the function registry, or 41 | could be another tool name that is indirectly mapped to such a function. 42 | tool_args: Position args to pass to the tool function. 43 | tool_kwargs: Keyword args to pass to the tool function. 44 | 45 | Returns: 46 | The return value of the tool function. 47 | """ 48 | del tool_name, tool_args, tool_kwargs 49 | raise NotImplementedError( 50 | 'The implementation should be provided at runtime by calling `configure`' 51 | ' or `get_variant`. This function cannot be called directly.' 52 | ) 53 | 54 | 55 | @executing.make_executable # pytype: disable=wrong-arg-types 56 | async def default_run_tool( 57 | tool_name: str, tool_args: tuple[Any, ...], tool_kwargs: dict[str, Any] 58 | ) -> Any: 59 | """Default implementation of run_tool which calls function_registry.""" 60 | if tool_name == constants.ERROR_STRING: 61 | # ERROR_STRING as the tool_name is a special case, where we are expected 62 | # to simply echo the error message stored in the tool argument. 63 | if len(tool_args) != 1: 64 | raise ValueError( 65 | 'When tool_name is ERROR_STRING, we expect there to be exactly one' 66 | ' argument containing the detailed error message (e.g., an error' 67 | ' that occurred when parsing the LLM response to determine the' 68 | f' tool call). Instead found {len(tool_args)} arguments:' 69 | f' {tool_name=}, {tool_args=}, {tool_kwargs=}' 70 | ) 71 | return tool_args[0] 72 | 73 | try: 74 | tool_function = routing.function_registry.get(tool_name) 75 | if tool_function is None: 76 | raise ValueError( 77 | f'Function {tool_name} is not registered in the function_registry' 78 | f' ({routing.function_registry=}).' 79 | ) 80 | return await utils.call_and_maybe_await( 81 | tool_function, *tool_args, **tool_kwargs 82 | ) 83 | 84 | except ValueError as e: 85 | return f'{constants.ERROR_STRING}: {e}' 86 | 87 | 88 | def reset_defaults(): 89 | """Resets default implementations for all builtins in this file.""" 90 | # Keep all module level `some_builtin.configure(...)` commands in this method. 91 | run_tool.configure(default_run_tool) 92 | 93 | reset_defaults() 94 | # TODO: Define an additional `run_tool_text` builtin that takes as 95 | # input: 96 | # (1) a string like `f('a', 'b')` or `f(x, 'b')` or `y = f(x, 'b')` that can be 97 | # parsed into a tool call (possibly with variable references and/or variable 98 | # assignment) 99 | # (2) a context represented as a dictionary of variable values 100 | # (e.g., {'x': 'a'}) 101 | # and which returns as output a string containing the tool's return value or 102 | # error message in a form appropriate for presenting to an LLM in a prompt, 103 | # along with a dictionary of context updates in the case of a variable 104 | # assignment (e.g., {'y': 'ab'})). This `run_tool_text` builtin can further be 105 | # wrapped as a composable or as a callback in a Jinja prompt template. 106 | -------------------------------------------------------------------------------- /onetwo/backends/onetwo_api_manual_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Sequence 16 | import os 17 | import pprint 18 | import time 19 | 20 | from absl import app 21 | from absl import flags 22 | from onetwo.backends import onetwo_api 23 | from onetwo.builtins import llm 24 | from onetwo.core import executing 25 | from onetwo.core import sampling 26 | 27 | _ENDPOINT = flags.DEFINE_string( 28 | 'endpoint', 29 | default='http://localhost:9876', 30 | help='Endpoint to use.', 31 | ) 32 | _CACHE_DIR = flags.DEFINE_string( 33 | 'cache_dir', 34 | default='/tmp/onetwo_api', 35 | help='Directory where the cache will be stored.', 36 | ) 37 | _LOAD_CACHE = flags.DEFINE_bool( 38 | 'load_cache_file', 39 | default=False, 40 | help='Whether we should read the cache stored in file.' 41 | ) 42 | _PRINT_DEBUG = flags.DEFINE_bool( 43 | 'print_debug', 44 | default=False, 45 | help='Debug logging.' 46 | ) 47 | 48 | 49 | def main(argv: Sequence[str]) -> None: 50 | del argv 51 | fname = os.path.join(_CACHE_DIR.value, 'onetwo_api.json') 52 | backend = onetwo_api.OneTwoAPI( 53 | endpoint=_ENDPOINT.value, 54 | cache_filename=fname, 55 | batch_size=4, 56 | ) 57 | backend.register() 58 | if _LOAD_CACHE.value: 59 | print('Loading cache from file %s', fname) 60 | load_start = time.time() 61 | backend.load_cache() 62 | load_end = time.time() 63 | print('Spent %.4fsec loading cache.' % (load_end - load_start)) 64 | 65 | print('1. A single generate query.') 66 | prompt_text = """ 67 | Question: Natural logarithm of $e^12$? 68 | Reasoning: "Natural" logarithm means logarithm to the base of $e$. For 69 | example, natural logarithm of $10$ means exponent to which $e$ must be 70 | raised to produce 10. 71 | Answer: 12. 72 | Question: Differentiate $\\frac{1}{\\log(x)}$. 73 | """ 74 | res = executing.run(llm.generate_text( # pytype: disable=wrong-keyword-args 75 | prompt=prompt_text, 76 | stop=['\n\n'], 77 | )) 78 | if _PRINT_DEBUG.value: 79 | print('Returned value(s):') 80 | pprint.pprint(res) 81 | print('1.1 Same query to see if it has been cached.') 82 | res = executing.run(llm.generate_text( # pytype: disable=wrong-keyword-args 83 | prompt=prompt_text, 84 | stop=['\n\n'], 85 | )) 86 | if _PRINT_DEBUG.value: 87 | print('Returned value(s):') 88 | pprint.pprint(res) 89 | print('1.2 Same query but different parameters, run requests again.') 90 | res = executing.run(llm.generate_text( # pytype: disable=wrong-keyword-args 91 | prompt=prompt_text, 92 | temperature=0., 93 | stop=['\n\n'], 94 | )) 95 | if _PRINT_DEBUG.value: 96 | print('Returned value(s):') 97 | pprint.pprint(res) 98 | print('2. Repeated generate request.') 99 | exe = executing.par_iter(sampling.repeat( 100 | executable=llm.generate_text( # pytype: disable=wrong-keyword-args 101 | prompt='Today is', temperature=0.5, stop=['.']), 102 | num_repeats=5, 103 | )) 104 | res = executing.run(exe) 105 | if _PRINT_DEBUG.value: 106 | print('Returned value(s):') 107 | pprint.pprint(res) 108 | print('3. Three batched generate queries.') 109 | exe = executing.par_iter([ 110 | llm.generate_text(prompt='In summer', stop=['.']), # pytype: disable=wrong-keyword-args 111 | llm.generate_text(prompt='In winter', max_tokens=32), # pytype: disable=wrong-keyword-args 112 | llm.generate_text(prompt='In autumn'), # pytype: disable=wrong-keyword-args 113 | ]) 114 | res = executing.run(exe) 115 | if _PRINT_DEBUG.value: 116 | print('Returned value(s):') 117 | pprint.pprint(res) 118 | 119 | if not _LOAD_CACHE.value: 120 | start = time.time() 121 | backend.save_cache(overwrite=True) 122 | print('Took %.4fsec saving cache to %s.' % (time.time() - start, fname)) 123 | 124 | print('PASS') 125 | 126 | 127 | if __name__ == '__main__': 128 | app.run(main) 129 | -------------------------------------------------------------------------------- /onetwo/core/core_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for OneTwo unit tests.""" 16 | 17 | import collections 18 | from collections.abc import Mapping 19 | import io 20 | import json 21 | import pprint 22 | from typing import Any, Literal 23 | import unittest 24 | from onetwo.core import results 25 | import PIL.Image 26 | 27 | 28 | class CounterAssertions(unittest.TestCase): 29 | """Mixin class for counter assertions.""" 30 | 31 | # pylint: disable=invalid-name 32 | def assertCounterEqual( 33 | self, 34 | counter_first: collections.Counter[str], 35 | counter_second: collections.Counter[str], 36 | ) -> None: 37 | # Remove zero values. 38 | first = counter_first - collections.Counter() 39 | second = counter_second - collections.Counter() 40 | message = f'A - B contains: {pprint.pformat(first - second)}\n' 41 | message += f'B - A contains: {pprint.pformat(second - first)}' 42 | return self.assertEqual(dict(first), dict(second), message) 43 | 44 | 45 | def maybe_read_file(filepath: str) -> str: 46 | """Returns the contents of the file if it exists, or else empty string.""" 47 | try: 48 | with open(filepath) as f: 49 | file_contents = f.read() 50 | return file_contents 51 | except IOError: 52 | return '' 53 | 54 | 55 | def maybe_read_json(filepath: str) -> Mapping[str, Any] | None: 56 | """Returns the file contents as JSON, or None if there is a problem.""" 57 | file_contents = maybe_read_file(filepath) 58 | try: 59 | return json.loads(file_contents) 60 | except json.JSONDecodeError: 61 | return None 62 | 63 | 64 | class MockTimer: 65 | """Mock timer for use in unit tests.""" 66 | 67 | def __init__(self): 68 | self._current_time = 0 69 | 70 | def __call__(self) -> float: 71 | self._current_time += 1 72 | return float(self._current_time) 73 | 74 | 75 | def reset_fields( 76 | er: list[results.ExecutionResult] | results.ExecutionResult, 77 | reset_values: Mapping[str, Any], 78 | ): 79 | """Recursively resset the given fields in the given execution result(s). 80 | 81 | Args: 82 | er: The ExecutionResult or list thereof to reset fields in. 83 | reset_values: Mapping field_name to the value to set when resetting. 84 | 85 | Returns: 86 | None. 87 | """ 88 | if isinstance(er, list): 89 | for sub_er in er: 90 | reset_fields(sub_er, reset_values) 91 | return 92 | 93 | for name, value in reset_values.items(): 94 | setattr(er, name, value) 95 | 96 | for stage in er.stages: 97 | reset_fields(stage, reset_values) 98 | 99 | 100 | def reset_times( 101 | er: list[results.ExecutionResult] | results.ExecutionResult, 102 | ) -> None: 103 | """Resets the start and end times in the given execution result(s).""" 104 | reset_fields(er, {'start_time': 0.0, 'end_time': 0.0}) 105 | 106 | 107 | def create_test_pil_image( 108 | fmt: Literal['PNG', 'JPEG', 'GIF'] = 'PNG', 109 | mode: Literal['RGB', 'RGBA', 'L', 'P'] = 'RGB', 110 | size: tuple[int, int] = (1, 1), 111 | ) -> PIL.Image.Image: 112 | """Returns a PIL Image in the given format, for testing purposes. 113 | 114 | Args: 115 | fmt: The target image format. 116 | mode: The image mode. 117 | size: The image dimensions (width, height). 118 | """ 119 | img = PIL.Image.new(mode, size) 120 | buffer = io.BytesIO() 121 | # GIF requires saving the palette. 122 | save_kwargs = ( 123 | {'format': fmt, 'save_all': True} 124 | if fmt == 'GIF' 125 | else {'format': fmt} 126 | ) 127 | try: 128 | # Save and reload the image to attempt to set the format attribute. 129 | img.save(buffer, **save_kwargs) 130 | buffer.seek(0) 131 | reloaded_img = PIL.Image.open(buffer) 132 | if reloaded_img.format is None: 133 | # PIL might not always preserve format. Inject it for testing purpose. 134 | # This might happen for simple modes like 'L' depending on PIL version. 135 | reloaded_img.format = fmt 136 | return reloaded_img 137 | except Exception as e: 138 | raise ValueError( 139 | f'Failed to create test image with format {fmt}: {e}' 140 | ) from e 141 | -------------------------------------------------------------------------------- /onetwo/backends/openai_mock.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A module that mocks the OpenAI library, for testing purposes. 16 | 17 | This document explains how to use the chat completion method: 18 | https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models 19 | """ 20 | 21 | from collections.abc import Mapping, Sequence 22 | from typing import Final, NamedTuple, TypeVar 23 | 24 | import pydantic 25 | 26 | 27 | _DEFAULT_REPLY: Final[str] = 'Hello' 28 | _DEFAULT_SCORE: Final[float] = -0.1 29 | _T = TypeVar('_T', bound=pydantic.BaseModel) 30 | 31 | 32 | class OpenAIMessage(NamedTuple): 33 | role: str 34 | content: str 35 | parsed: pydantic.BaseModel | None = None 36 | refusal: str | None = None 37 | 38 | 39 | class ChatCompletionTokenLogprob(NamedTuple): 40 | token: str 41 | logprob: float 42 | 43 | 44 | class ChoiceLogProbs(NamedTuple): 45 | content: Sequence[ChatCompletionTokenLogprob] 46 | 47 | 48 | class Choice(NamedTuple): 49 | index: int 50 | message: OpenAIMessage 51 | logprobs: ChoiceLogProbs 52 | finish_reason: str 53 | 54 | 55 | class ChatCompletion(NamedTuple): 56 | choices: Sequence[Choice] 57 | 58 | 59 | class OpenAI: 60 | """A mock of the OpenAI class that provides a chat.completions.create method. 61 | 62 | See 63 | https://github.com/openai/openai-python/blob/f0bdef04611a24ed150d19c4d180aacab3052704/src/openai/_client.py#L49 64 | """ 65 | 66 | def __init__(self, api_key: str | None = None): 67 | del api_key 68 | self._reply = _DEFAULT_REPLY 69 | 70 | @property 71 | def chat(self): 72 | return self 73 | 74 | @property 75 | def completions(self): 76 | return self 77 | 78 | def create( 79 | self, 80 | model: str, 81 | messages: Sequence[Mapping[str, str]], 82 | **kwargs, 83 | ) -> ChatCompletion: 84 | del model, messages 85 | samples = 1 86 | if 'n' in kwargs: 87 | samples = kwargs['n'] 88 | return ChatCompletion( 89 | choices=[ 90 | Choice( 91 | index=i, 92 | message=OpenAIMessage(role='assistant', content=self._reply), 93 | logprobs=ChoiceLogProbs( 94 | content=[ 95 | ChatCompletionTokenLogprob( 96 | token='a', logprob=_DEFAULT_SCORE 97 | ) 98 | ] 99 | ), 100 | finish_reason='stop', 101 | ) 102 | for i in range(samples) 103 | ] 104 | ) 105 | 106 | def parse( 107 | self, 108 | *, 109 | model: str, 110 | messages: Sequence[Mapping[str, str]], 111 | response_format: type[_T], 112 | **kwargs, 113 | ) -> ChatCompletion: 114 | """Mocks client.chat.completions.parse for structured outputs.""" 115 | del model, messages, kwargs 116 | parsed_object = None 117 | refusal_message = None 118 | 119 | if isinstance(response_format, type) and issubclass( 120 | response_format, pydantic.BaseModel 121 | ): 122 | try: 123 | mock_data = { 124 | 'name': 'Mock Test City', 125 | 'population': 500000, 126 | } 127 | # Create a validated instance from the mock data 128 | parsed_object = response_format(**mock_data) 129 | except Exception as e: # pylint: disable=broad-except 130 | refusal_message = f'Mock failed to construct {response_format}: {e}' 131 | else: 132 | refusal_message = ( 133 | '`response_format` must be a pydantic.BaseModel subclass for ' 134 | f'structured parsing, but got {response_format}.' 135 | ) 136 | 137 | mock_message = OpenAIMessage( 138 | role='assistant', 139 | content=self._reply, 140 | parsed=parsed_object, 141 | refusal=refusal_message, 142 | ) 143 | return ChatCompletion( 144 | choices=[ 145 | Choice( 146 | index=0, 147 | message=mock_message, 148 | logprobs=ChoiceLogProbs( 149 | content=[ 150 | ChatCompletionTokenLogprob( 151 | token='a', logprob=_DEFAULT_SCORE 152 | ) 153 | ] 154 | ), 155 | finish_reason='stop', 156 | ) 157 | ] 158 | ) 159 | -------------------------------------------------------------------------------- /onetwo/backends/model_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """OneTwo Model Server.""" 16 | 17 | import logging 18 | import sys 19 | import traceback 20 | from typing import Any 21 | 22 | import fastapi 23 | from onetwo.builtins import llm 24 | from onetwo.core import batching 25 | 26 | 27 | _Body = fastapi.Body 28 | 29 | 30 | def _get_http_exception(exception: Exception) -> Exception: 31 | error_message = 'An exception {} occurred. Arguments: {}.'.format( 32 | type(exception).__name__, exception.args 33 | ) 34 | logging.info( 35 | '%s\\nTraceback: %s', error_message, traceback.format_exc() 36 | ) 37 | # Converts all exceptions to HTTPException. 38 | return fastapi.HTTPException(status_code=500, detail=error_message) 39 | 40 | 41 | class ModelServer: 42 | """Model server that wraps llm builtins and exposes them as API calls. 43 | 44 | See run_model_server.py for an example of how to use this class. 45 | """ 46 | 47 | def __init__(self): 48 | """Initializes a fastapi application and sets the configs.""" 49 | 50 | logging.basicConfig( 51 | format='%(asctime)s: %(message)s', 52 | datefmt='%m/%d/%Y %I:%M:%S %p', 53 | level=logging.INFO, 54 | stream=sys.stdout, 55 | ) 56 | 57 | # We disable batching on the server (it may still be enabled on the client). 58 | # This is specific to the way we run onetwo in an async application. In 59 | # general, in client applications it is best to call onetwo.run(). 60 | batching._enable_batching.set(False) 61 | 62 | self._app = fastapi.FastAPI() 63 | 64 | @self._app.post('/tokenize') 65 | async def _tokenize( 66 | content: str = _Body(..., embed=True) # Ellipsis: required parameter. 67 | ) -> list[int]: 68 | """Wraps llm.tokenize.""" 69 | # We disable batching on the server (it may still be enabled on the 70 | # client). This is specific to the way we run onetwo in an async 71 | # application. In general, in client applications it is best to call 72 | # onetwo.run(). 73 | batching._enable_batching.set(False) # pylint: disable=protected-access 74 | 75 | try: 76 | res = await llm.tokenize(content) # pytype: disable=wrong-arg-count 77 | except Exception as exception: # pylint: disable=broad-exception-caught 78 | raise _get_http_exception(exception) from exception 79 | return res 80 | 81 | @self._app.post('/generate_text') 82 | async def _generate_text( 83 | prompt: str = _Body(...), # Ellipsis: required parameter. 84 | temperature: float = _Body(default=None), 85 | max_tokens: int = _Body(default=None), 86 | include_details: bool = _Body(default=False), 87 | ) -> tuple[str, dict[str, Any]]: 88 | """Wraps llm.generate_text.""" 89 | # We disable batching on the server (it may still be enabled on the 90 | # client). This is specific to the way we run onetwo in an async 91 | # application. In general, in client applications it is best to call 92 | # onetwo.run(). 93 | batching._enable_batching.set(False) # pylint: disable=protected-access 94 | try: 95 | res = await llm.generate_text( # pytype: disable=wrong-keyword-args 96 | prompt=prompt, 97 | temperature=temperature, 98 | max_tokens=max_tokens, 99 | include_details=include_details, 100 | ) 101 | except Exception as exception: # pylint: disable=broad-exception-caught 102 | raise _get_http_exception(exception) from exception 103 | if include_details: 104 | return res 105 | else: 106 | return res, {} 107 | 108 | @self._app.post('/count_tokens') 109 | async def _count_tokens( 110 | content: str = _Body(..., embed=True) # Ellipsis: required parameter. 111 | ) -> int: 112 | """Wraps llm.count_tokens.""" 113 | # We disable batching on the server (it may still be enabled on the 114 | # client). This is specific to the way we run onetwo in an async 115 | # application. In general, in client applications it is best to call 116 | # onetwo.run(). 117 | batching._enable_batching.set(False) # pylint: disable=protected-access 118 | try: 119 | res = await llm.count_tokens(content) # pytype: disable=wrong-arg-count 120 | except Exception as exception: # pylint: disable=broad-exception-caught 121 | raise _get_http_exception(exception) from exception 122 | return res 123 | 124 | self._app.add_api_route( 125 | path='/health', 126 | endpoint=self._health, 127 | methods=['GET'], 128 | ) 129 | 130 | async def __call__(self, scope, receive, send): 131 | await self._app(scope, receive, send) 132 | 133 | def _health(self): 134 | """Executes a health check.""" 135 | return {} 136 | -------------------------------------------------------------------------------- /onetwo/agents/agents_test_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import TypeAlias 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from onetwo.agents import agents_base 22 | from onetwo.agents import agents_test_utils 23 | from onetwo.core import executing 24 | 25 | 26 | StringAgentState = agents_base.UpdateListState[str, str] 27 | 28 | 29 | _SU: TypeAlias = agents_base.ScoredUpdate[str] 30 | _ULS: TypeAlias = agents_base.UpdateListState[str, _SU] 31 | 32 | 33 | class DistributionAgentTest(parameterized.TestCase): 34 | 35 | @parameterized.named_parameters( 36 | ('single_end', {'hello': 1.0}, 'hello', {'$': 1.0}), 37 | ('impossible', {'hello': 1.0}, 'x', {'': 1.0}), 38 | ('single_start', {'hello': 1.0}, '', {'h': 1.0}), 39 | ('double_start', {'hello': 0.7, 'hallo': 0.3}, 'h', {'e': 0.7, 'a': 0.3}), 40 | ('double_end', {'hello': 0.7, 'hallo': 0.3}, 'hello', {'$': 1.0}), 41 | ( 42 | 'triple', 43 | {'hello': 0.1, 'helly': 0.2, 'hallo': 0.2, 'world': 0.5}, 44 | 'h', 45 | {'e': 0.6, 'a': 0.4}, 46 | ), 47 | ) 48 | def test_next_step_distribution(self, words, prefix, expected): 49 | async def wrapper(words: dict[str, float], prefix: str) -> list[_SU]: 50 | agent = agents_test_utils.DistributionAgentForTest(words) 51 | state = await agent.initialize_state(prefix) # pytype: disable=wrong-arg-count 52 | return await agent.get_next_step_distribution(state) # pytype: disable=wrong-arg-count 53 | 54 | res = executing.run(wrapper(words, prefix)) 55 | res_as_dict = { 56 | scored_update.update: round(scored_update.score, 2) 57 | for scored_update in res 58 | } 59 | self.assertDictEqual(expected, res_as_dict) 60 | 61 | @parameterized.named_parameters( 62 | ('start', '', 1.0), 63 | ('impossible', 'x', 0.0), 64 | ('incomplete', 'h', 0.5), 65 | ('hello', 'hello', 0.3), 66 | ('hallo', 'hallo', 0.2), 67 | ('hello$', 'hello$', 0.1), 68 | ) 69 | def test_score_state(self, state, expected): 70 | score = agents_test_utils.DistributionAgentForTest( 71 | {'hello': 0.1, 'hello_world': 0.2, 'hallo': 0.2, 'world': 0.5} 72 | ).score_state(state) 73 | self.assertEqual( 74 | expected, 75 | round(score, 2), 76 | ) 77 | 78 | @parameterized.named_parameters( 79 | ('start', '', False), 80 | ('impossible', 'x', True), 81 | ('incomplete', 'h', False), 82 | ('hello', 'hello', False), 83 | ('hallo', 'hallo', False), 84 | ('hello$', 'hello$', True), 85 | ) 86 | def test_is_finished(self, state, expected): 87 | self.assertEqual( 88 | expected, 89 | agents_test_utils.DistributionAgentForTest( 90 | {'hello': 0.1, 'hello_world': 0.2, 'hallo': 0.2, 'world': 0.5} 91 | ).is_finished(state), 92 | ) 93 | 94 | @parameterized.named_parameters( 95 | ('empty', '', ['hello$', 'hallo$', 'hello_world$', 'world$']), 96 | ('h', 'h', ['hello$', 'hallo$', 'hello_world$']), 97 | ('he', 'he', ['hello$', 'hello_world$']), 98 | ('hello', 'hello', ['hello$', 'hello_world$']), 99 | ('hello$', 'hello$', ['hello$']), 100 | ('x', 'x', ['x']), 101 | ) 102 | def test_execute(self, prefix, expected): 103 | agent = agents_test_utils.DistributionAgentForTest( 104 | {'hello': 0.1, 'hallo': 0.2, 'hello_world': 0.4, 'world': 0.3} 105 | ) 106 | res = executing.run(agent(inputs=prefix)) # pytype: disable=wrong-keyword-args 107 | self.assertIn(res, expected) 108 | 109 | 110 | class StringAgentTest(parameterized.TestCase): 111 | 112 | def test_sample_next_step(self): 113 | agent = agents_test_utils.StringAgent(sequence=['a', 'b', 'c', 'd']) 114 | state = executing.run(agent.initialize_state(inputs='test')) # pytype: disable=wrong-keyword-args 115 | 116 | result = executing.run( 117 | agent.sample_next_step(state=state, num_candidates=2) # pytype: disable=wrong-keyword-args 118 | ) 119 | with self.subTest('first_call_returns_first_steps_in_sequence'): 120 | self.assertEqual(result, ['a', 'b']) 121 | 122 | result = executing.run( 123 | agent.sample_next_step(state=state, num_candidates=2) # pytype: disable=wrong-keyword-args 124 | ) 125 | with self.subTest('second_call_returns_next_steps_in_sequence'): 126 | self.assertEqual(result, ['c', 'd']) 127 | 128 | @parameterized.named_parameters( 129 | ('max_length_equals_seq_length', 3, ['a', 'a', 'b'], 'a a b'), 130 | ('max_length_greater_than_seq_length', 4, ['a', 'a', 'b'], 'a a b'), 131 | ('max_length_less_than_seq_length', 2, ['a', 'a', 'b'], 'a a'), 132 | ) 133 | def test_execute(self, max_length, sequence, expected_result): 134 | agent = agents_test_utils.StringAgent( 135 | max_length=max_length, sequence=sequence 136 | ) 137 | result = executing.run(agent(inputs='test')) # pytype: disable=wrong-keyword-args 138 | self.assertEqual(expected_result, result) 139 | 140 | 141 | if __name__ == '__main__': 142 | absltest.main() 143 | -------------------------------------------------------------------------------- /onetwo/agents/agents_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Agent implementations for use in tests.""" 16 | 17 | import dataclasses 18 | from typing import TypeAlias 19 | 20 | from onetwo.agents import agents_base 21 | from onetwo.agents import distribution 22 | from onetwo.core import executing 23 | 24 | _SU: TypeAlias = agents_base.ScoredUpdate[str] 25 | _ULS: TypeAlias = agents_base.UpdateListState[str, _SU] 26 | 27 | 28 | @dataclasses.dataclass 29 | class DistributionAgentForTest( 30 | distribution.DistributionAgent[str, str, str, _SU, None] 31 | ): 32 | """Given a distribution over strings, form a distribution over next chars.""" 33 | 34 | distribution: dict[str, float] = dataclasses.field(default_factory=dict) 35 | 36 | @executing.make_executable(copy_self=False, non_copied_args=['environment']) 37 | async def initialize_state( 38 | self, inputs: str, environment: None = None 39 | ) -> str: 40 | """Overridden from base class (Agent).""" 41 | return inputs 42 | 43 | def extract_output(self, state: str) -> str: 44 | """Overridden from base class (Agent).""" 45 | return state 46 | 47 | def is_finished(self, state: str) -> bool: 48 | """Overridden from base class (Agent).""" 49 | # We reached a final state if we have the end-of-sequence token `$`. 50 | if state.endswith('$'): 51 | return True 52 | # But we also consider the state to be final if it cannot be reached. 53 | found = False 54 | for word in self.distribution: 55 | if word.startswith(state): 56 | found = True 57 | return not found 58 | 59 | def score_state(self, state: str) -> float: 60 | """Returns the probability of reaching this state from an empty state.""" 61 | sum_probabilities = 0.0 62 | for word, prob in self.distribution.items(): 63 | if (word + '$').startswith(state): 64 | sum_probabilities += prob 65 | return sum_probabilities 66 | 67 | @executing.make_executable(copy_self=False) 68 | async def get_next_step_distribution( 69 | self, state: str, environment: None = None 70 | ) -> list[_SU]: 71 | """Overridden from base class (DistributionAgent).""" 72 | if self.is_finished(state): 73 | # If we are in a final state, we return a distribution with score 1 for 74 | # the empty update (which leaves the state unchanged). 75 | return [agents_base.ScoredUpdate(update='', score=1.0)] 76 | next_letter_probs = {} 77 | for word, score in self.distribution.items(): 78 | if word.startswith(state): 79 | if len(word) > len(state): 80 | next_letter = word[len(state)] 81 | else: 82 | next_letter = '$' # End of sequence token. 83 | if next_letter not in next_letter_probs: 84 | next_letter_probs[next_letter] = score 85 | else: 86 | next_letter_probs[next_letter] += score 87 | # Normalize the distribution. 88 | total_prob = sum(next_letter_probs.values()) 89 | if total_prob > 0.0: 90 | for k in next_letter_probs: 91 | next_letter_probs[k] /= total_prob 92 | else: 93 | # If the total probability is 0, we assume we cannot continue hence the 94 | # next update is necessarily empty. 95 | return [agents_base.ScoredUpdate(update='', score=1.0)] 96 | # Convert the disrtibution to ScoredUpdates. 97 | return [ 98 | agents_base.ScoredUpdate(update=k, score=v) 99 | for k, v in next_letter_probs.items() 100 | ] 101 | 102 | 103 | _StringAgentState: TypeAlias = agents_base.UpdateListState[str, str] 104 | 105 | 106 | @dataclasses.dataclass 107 | class StringAgent( 108 | agents_base.SingleSampleAgent[str, str, _StringAgentState, str, None] 109 | ): 110 | """Simple test agent, whose input / updates are strings. 111 | 112 | Its output is a concatenation of the update strings, separate by space. 113 | 114 | Attributes: 115 | max_length: Maximum length of the agent's state (i.e., of its update list). 116 | If specified, then will finish when this length is reached. If None, then 117 | will by default run forever. 118 | sequence: A sequence of strings to be used by the agent to produce samples. 119 | """ 120 | 121 | max_length: int = 5 122 | sequence: list[str] = dataclasses.field(default_factory=list) 123 | 124 | @executing.make_executable(copy_self=False, non_copied_args=['environment']) 125 | async def initialize_state( 126 | self, inputs: str, environment: None = None 127 | ) -> _StringAgentState: 128 | return _StringAgentState(inputs=inputs) 129 | 130 | def extract_output(self, state: _StringAgentState) -> str: 131 | """Overridden from base class (Agent).""" 132 | return ' '.join(state.updates) 133 | 134 | def is_finished(self, state: _StringAgentState) -> bool: 135 | """Overridden from base class (Agent).""" 136 | return len(state.updates) >= self.max_length or not self.sequence 137 | 138 | @executing.make_executable(copy_self=False) 139 | async def _sample_single_next_step( 140 | self, state: _StringAgentState, environment=None 141 | ) -> str: 142 | """Overridden from base class (SingleSampleAgent).""" 143 | return self.sequence.pop(0) 144 | -------------------------------------------------------------------------------- /onetwo/builtins/composables.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Composable versions of builtin functions.""" 16 | 17 | from collections.abc import Sequence 18 | from typing import Any, TypeVar 19 | 20 | from onetwo.builtins import llm 21 | from onetwo.builtins import prompt_templating 22 | from onetwo.core import composing 23 | from onetwo.core import content as content_lib 24 | 25 | 26 | _T = TypeVar('_T') 27 | 28 | store = composing.store 29 | section_start = composing.section_start 30 | section_end = composing.section_end 31 | 32 | 33 | @composing.make_composable # pytype: disable=wrong-arg-types 34 | def f( 35 | context: composing.Context, content: str, role: content_lib.RoleType = None 36 | ) -> content_lib.Chunk: 37 | """Composable version of the string.format function that uses context vars.""" 38 | try: 39 | content = content.format(**context.variables) 40 | except KeyError as e: 41 | raise ValueError( 42 | f'Could not format {content} with context' 43 | f' {context.variables}, some variables were not found.' 44 | ) from e 45 | return content_lib.Chunk(content, role=role) 46 | 47 | 48 | @composing.make_composable # pytype: disable=wrong-arg-types 49 | def c( 50 | context: composing.Context, content: Any, role: content_lib.RoleType = None 51 | ) -> content_lib.ChunkList: 52 | """Composable chunk created from content.""" 53 | del context 54 | if isinstance(content, content_lib.Chunk): 55 | if role: 56 | content.role = role 57 | return content_lib.ChunkList([content]) 58 | elif isinstance(content, content_lib.ChunkList): 59 | if role: 60 | for chunk in content: 61 | chunk.role = role 62 | return content 63 | else: 64 | return content_lib.ChunkList([content_lib.Chunk(content, role=role)]) 65 | 66 | 67 | @composing.make_composable # pytype: disable=wrong-arg-types 68 | async def j( 69 | context: composing.Context, 70 | template: str, 71 | name: str = 'JinjaTemplate', 72 | role: content_lib.RoleType = None, 73 | ) -> content_lib.Chunk: 74 | """Composable version of a jinja formatted prompt using context vars.""" 75 | result = await prompt_templating.JinjaTemplateWithCallbacks( 76 | name=name, text=template 77 | ).render(**context.variables) 78 | # We extract the output variables from the template and store them into the 79 | # context. 80 | for key, value in result.items(): 81 | if key != prompt_templating.PROMPT_PREFIX: 82 | context[key] = value 83 | return content_lib.Chunk(result[prompt_templating.PROMPT_PREFIX], role=role) 84 | 85 | 86 | @composing.make_composable # pytype: disable=wrong-arg-types 87 | def generate_text(context: composing.Context, **kwargs) -> ...: 88 | """Composable version of the llm.generate_text function.""" 89 | return llm.generate_text(context.prefix, **kwargs) # pytype: disable=wrong-arg-count 90 | 91 | 92 | @composing.make_composable # pytype: disable=wrong-arg-types 93 | async def chat( 94 | context: composing.Context, 95 | **kwargs, 96 | ) -> content_lib.ChunkList: 97 | """Composable version of the llm.chat function.""" 98 | return content_lib.ChunkList([ 99 | content_lib.Chunk( 100 | await llm.chat(context.to_messages(), **kwargs), # pytype: disable=wrong-arg-count 101 | role=content_lib.PredefinedRole.MODEL, 102 | ) 103 | ]) 104 | 105 | 106 | @composing.make_join_composable # pytype: disable=wrong-arg-types 107 | async def select( 108 | context: composing.Context, 109 | options: Sequence[tuple[str, composing.Context]], 110 | ) -> tuple[str, composing.Context]: 111 | """Composable select function. 112 | 113 | Args: 114 | context: Context of execution. 115 | options: Sequence of options. Each option is a pair (text, context). 116 | 117 | Returns: 118 | The pair (text, context) that is selected among the possible options. For 119 | example, this could be the one with the highest score. 120 | """ 121 | common_prefix = context.prefix 122 | text_options = [text for text, _ in options] 123 | value, index, _ = await llm.select( # pytype: disable=wrong-arg-count 124 | common_prefix, text_options, include_details=True 125 | ) 126 | return value, options[index][1] 127 | 128 | 129 | @composing.make_composable # pytype: disable=wrong-arg-types 130 | def generate_object( 131 | context: composing.Context, cls: type[Any], **kwargs 132 | ) -> ...: 133 | """Composable version of the llm.generate_object function.""" 134 | return llm.generate_object(context.prefix, cls, **kwargs) # pytype: disable=wrong-arg-count 135 | 136 | 137 | @composing.make_composable # pytype: disable=wrong-arg-types 138 | async def instruct( 139 | context: composing.Context, assistant_prefix: str | None = None, **kwargs 140 | ) -> ...: 141 | """Composable version of the llm.instruct function.""" 142 | # TODO: We are awaiting llm.instruct which means we cannot 143 | # iterate through the result if we are doing streaming. We should instead 144 | # wrap this into an ExecutableWithPostprocessing which adds the 145 | # assistant_prefix, or change the prefix directly and call instruct with 146 | # the unchanged prefix (requires to deep-copy the prefix). 147 | result = await llm.instruct(context.prefix, assistant_prefix, **kwargs) # pytype: disable=wrong-arg-count 148 | if assistant_prefix is None: 149 | return result 150 | else: 151 | return assistant_prefix + result 152 | -------------------------------------------------------------------------------- /onetwo/agents/critics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Sequence 16 | import dataclasses 17 | from typing import TypeAlias 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from onetwo.agents import agents_base 22 | from onetwo.agents import critics 23 | from onetwo.agents import distribution 24 | from onetwo.core import executing 25 | 26 | 27 | _SU: TypeAlias = agents_base.ScoredUpdate[str] 28 | _ULS: TypeAlias = agents_base.UpdateListState[str, _SU] 29 | 30 | 31 | class ScorerForTest(critics.ScoringFunction[str, float]): 32 | 33 | @executing.make_executable(copy_self=False) 34 | async def __call__(self, state: str, update: float) -> float: 35 | """Assigns as a score the update value.""" 36 | return update 37 | 38 | 39 | class RankerForTest(critics.RankingFunction[str, float]): 40 | 41 | @executing.make_executable(copy_self=False) 42 | async def __call__( 43 | self, states_and_updates: Sequence[tuple[str, float]] 44 | ) -> Sequence[int]: 45 | """Returns the indices of the states sorted by the update value.""" 46 | sorted_states_and_updates = sorted( 47 | enumerate(states_and_updates), key=lambda x: x[1][1], reverse=True 48 | ) 49 | return [i for i, _ in sorted_states_and_updates] 50 | 51 | 52 | class SelectorForTest(critics.SelectingFunction[str, float]): 53 | 54 | @executing.make_executable(copy_self=False) 55 | async def __call__( 56 | self, states_and_updates: Sequence[tuple[str, float]] 57 | ) -> int: 58 | """Returns the index of the state with the highest update value.""" 59 | return max(enumerate(states_and_updates), key=lambda x: x[1][1])[0] 60 | 61 | 62 | @dataclasses.dataclass 63 | class DistributionAgentForTest( 64 | distribution.DistributionAgent[ 65 | str, str, str, agents_base.ScoredUpdate[str], None 66 | ] 67 | ): 68 | distribution: list[tuple[str, float]] = dataclasses.field( 69 | default_factory=list 70 | ) 71 | 72 | @executing.make_executable(copy_self=False, non_copied_args=['environment']) 73 | async def initialize_state(self, inputs: str) -> str: 74 | """Overridden from base class (Agent).""" 75 | return inputs 76 | 77 | def extract_output(self, state: str) -> str: 78 | """Overridden from base class (Agent).""" 79 | return state 80 | 81 | def is_finished(self, state: str) -> bool: 82 | """Overridden from base class (Agent).""" 83 | return bool(state) 84 | 85 | @executing.make_executable(copy_self=False) 86 | async def get_next_step_distribution( 87 | self, state: str, environment: None = None 88 | ) -> list[agents_base.ScoredUpdate[str]]: 89 | """Overridden from base class (DistributionAgent).""" 90 | return [ 91 | agents_base.ScoredUpdate(update=d[0], score=d[1]) 92 | for d in self.distribution 93 | ] 94 | 95 | 96 | class CriticsTest(parameterized.TestCase): 97 | 98 | def test_scorer_to_ranker(self): 99 | scorer = ScorerForTest() 100 | converted_ranker = critics.ranker_from_scorer(scorer) 101 | ranker = RankerForTest() 102 | states_and_updates = [ 103 | ('a', 1.0), 104 | ('b', 2.0), 105 | ('c', 3.0), 106 | ('d', 4.0), 107 | ('e', 5.0), 108 | ] 109 | res = executing.run(ranker(states_and_updates)) # pytype: disable=wrong-arg-count 110 | res2 = executing.run(converted_ranker(states_and_updates)) # pytype: disable=wrong-arg-count 111 | self.assertEqual(res, res2) 112 | self.assertEqual(res, [4, 3, 2, 1, 0]) 113 | 114 | def test_selector_to_ranker(self): 115 | selector = SelectorForTest() 116 | converted_ranker = critics.ranker_from_selector(selector) 117 | ranker = RankerForTest() 118 | states_and_updates = [ 119 | ('a', 1.0), 120 | ('b', 2.0), 121 | ('c', 3.0), 122 | ('d', 4.0), 123 | ('e', 5.0), 124 | ] 125 | res = executing.run(ranker(states_and_updates)) # pytype: disable=wrong-arg-count 126 | res2 = executing.run(converted_ranker(states_and_updates)) # pytype: disable=wrong-arg-count 127 | self.assertEqual(res, res2) 128 | self.assertEqual(res, [4, 3, 2, 1, 0]) 129 | 130 | def test_score_from_update(self): 131 | """Use an agent with distribution to score its states.""" 132 | distrib = [('a', 0.1), ('b', 0.3), ('c', 0.6)] 133 | dist_agent = DistributionAgentForTest(distribution=distrib) 134 | dist_map = {k: v for k, v in distrib} 135 | scoring_function = critics.ScoreFromUpdates() 136 | state = 'a' 137 | scores = [] 138 | expected_scores = [] 139 | 140 | async def wrapper(): 141 | nonlocal scores 142 | updates = await dist_agent.sample_next_step(state=state, num_candidates=3) # pytype: disable=wrong-keyword-args 143 | for update in updates: 144 | scores.append( 145 | await scoring_function(state, update) 146 | ) 147 | expected_scores.append(dist_map[update.update]) 148 | 149 | executing.run(wrapper()) 150 | self.assertEqual(scores, expected_scores) 151 | 152 | def test_score_from_update_list(self): 153 | """Use an agent with distribution to score its states.""" 154 | scoring_function = critics.ScoreFromUpdateList() 155 | state = _ULS('', [_SU('a', 0.1), _SU('b', 0.3), _SU('c', 0.6)]) 156 | update = _SU('d', 0.1) 157 | 158 | res = executing.run(scoring_function(state, update)) 159 | self.assertEqual(res, 1.1) 160 | 161 | 162 | if __name__ == '__main__': 163 | absltest.main() 164 | -------------------------------------------------------------------------------- /onetwo/agents/tasks/game_of_24_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl.testing import absltest 16 | from onetwo.agents import iterative_thought 17 | from onetwo.agents.tasks import game_of_24 18 | from onetwo.backends import backends_test_utils 19 | from onetwo.core import executing 20 | 21 | 22 | # Default reply for LLMForTest to return when it receives a prompt that it was 23 | # not expecting. 24 | DEFAULT_REPLY = 'UNKNOWN_PROMPT' 25 | 26 | 27 | class GameOf24Test(absltest.TestCase): 28 | 29 | def test_game_of_24_iterative_thought_prompt(self): 30 | # Some typical Game of 24 prompt inputs. 31 | description = game_of_24.GAME_OF_24_DESCRIPTION 32 | few_shots = game_of_24.GAME_OF_24_EXAMPLES 33 | state = iterative_thought.IterativeThoughtState( 34 | inputs='8 12 4 5', updates=['thought1'] 35 | ) 36 | prompt = iterative_thought.IterativeThoughtPromptJ2() 37 | 38 | # Now we define a test LLM to use in place of the actual LLM. To avoid the 39 | # need to hard-code here all of the expected requests and simulated replies, 40 | # however, we will just depend on a single default reply. 41 | llm_backend = backends_test_utils.LLMForTest( 42 | default_reply=DEFAULT_REPLY, 43 | default_score=0.0, 44 | ) 45 | llm_backend.register() 46 | 47 | # Now we execute the prompt and verify that the prompt contained the 48 | # expected content. (Although we don't verify all of the prompt formatting, 49 | # these assertions should be sufficient to catch many basic bugs where we 50 | # omitted a for-loop, or failed to include some of the fields due to a typo, 51 | # etc.) 52 | next_step, result = executing.run( 53 | prompt(description=description, few_shots=few_shots, state=state), 54 | enable_tracing=True, 55 | ) 56 | prefix = result.stages[0].outputs['prefix'] 57 | 58 | with self.subTest('prompt_should_contain_the_task_description'): 59 | self.assertIn(description, prefix) 60 | 61 | with self.subTest('prompt_should_contain_the_exemplar_inputs'): 62 | self.assertIn(few_shots[0].inputs, prefix) 63 | self.assertIn(few_shots[-1].inputs, prefix) 64 | 65 | with self.subTest('prompt_should_contain_the_exemplar_steps_so_far'): 66 | self.assertIn(few_shots[0].updates[0], prefix) 67 | self.assertIn(few_shots[-1].updates[-1], prefix) 68 | 69 | with self.subTest('prompt_should_contain_the_actual_inputs'): 70 | self.assertIn(state.inputs, prefix) 71 | 72 | if state.updates: 73 | with self.subTest('prompt_should_contain_the_actual_steps_so_far'): 74 | self.assertIn(state.updates[0], prefix) 75 | self.assertIn(state.updates[-1], prefix) 76 | 77 | with self.subTest('should_return_the_llm_reply_as_next_step'): 78 | self.assertEqual(DEFAULT_REPLY, next_step) 79 | 80 | def test_game_of_24_iterative_thought_proposer_prompt(self): 81 | # Some typical Game of 24 propose prompt inputs. 82 | description = game_of_24.GAME_OF_24_DESCRIPTION 83 | few_shots = game_of_24.GAME_OF_24_PROPOSER_EXAMPLES 84 | state = iterative_thought.IterativeThoughtState( 85 | inputs='8 12 4 5', updates=['thought1'] 86 | ) 87 | prompt = iterative_thought.IterativeThoughtProposerPromptJ2() 88 | 89 | # Now we define a test LLM to use in place of the actual LLM. To avoid the 90 | # need to hard-code here all of the expected requests and simulated replies, 91 | # however, we will just depend on a single default reply. 92 | llm_backend = backends_test_utils.LLMForTest( 93 | default_reply='ta\ntb', 94 | default_score=0.0, 95 | ) 96 | llm_backend.register() 97 | 98 | expected_next_steps = ['ta', 'tb'] 99 | 100 | # Now we execute the prompt and verify that the prompt contained the 101 | # expected content. (Although we don't verify all of the prompt formatting, 102 | # these assertions should be sufficient to catch many basic bugs where we 103 | # omitted a for-loop, or failed to include some of the fields due to a typo, 104 | # etc.) 105 | next_steps, result = executing.run( 106 | prompt(description=description, few_shots=few_shots, state=state), 107 | enable_tracing=True, 108 | ) 109 | prefix = result.stages[0].outputs['prefix'] 110 | 111 | with self.subTest('prompt_should_contain_the_task_description'): 112 | self.assertIn(description, prefix) 113 | 114 | with self.subTest('prompt_should_contain_the_exemplar_inputs'): 115 | self.assertIn(few_shots[0].state.inputs, prefix) 116 | self.assertIn(few_shots[-1].state.inputs, prefix) 117 | 118 | with self.subTest('prompt_should_contain_the_exemplar_state'): 119 | if few_shots[0].state.updates: 120 | self.assertIn(few_shots[0].state.updates[0], prefix) 121 | if few_shots[-1].state.updates: 122 | self.assertIn(few_shots[-1].state.updates[-1], prefix) 123 | 124 | with self.subTest('prompt_should_contain_the_exemplar_next_steps'): 125 | self.assertIn(few_shots[0].next_steps[0], prefix) 126 | self.assertIn(few_shots[-1].next_steps[-1], prefix) 127 | 128 | with self.subTest('prompt_should_contain_the_actual_inputs'): 129 | self.assertIn(state.inputs, prefix) 130 | 131 | if state.updates: 132 | with self.subTest('prompt_should_contain_the_actual_steps_so_far'): 133 | self.assertIn(state.updates[0], prefix) 134 | self.assertIn(state.updates[-1], prefix) 135 | 136 | with self.subTest('should_return_the_parsed_llm_reply_as_next_steps'): 137 | self.assertEqual(expected_next_steps, next_steps) 138 | 139 | if __name__ == '__main__': 140 | absltest.main() 141 | -------------------------------------------------------------------------------- /onetwo/stdlib/code_execution/python_execution_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for OneTwo unit tests involving Python execution.""" 16 | 17 | import collections 18 | from collections.abc import Mapping, Sequence 19 | import dataclasses 20 | import datetime 21 | from typing import Any, Callable 22 | import unittest 23 | 24 | from onetwo.core import executing 25 | from onetwo.stdlib.code_execution import python_execution 26 | 27 | # Aliases for brevity. 28 | _SandboxResult = python_execution.SandboxResult 29 | 30 | 31 | class SandboxResultAssertions(unittest.TestCase): 32 | """Mixin class for SandboxResult assertions.""" 33 | 34 | # pylint: disable=invalid-name 35 | def assertSandboxResultEqualIgnoringTiming( 36 | self, 37 | expected_result: _SandboxResult, 38 | actual_result: _SandboxResult, 39 | ) -> None: 40 | # Remove timing-related content. 41 | expected_without_timing = dataclasses.replace(expected_result, timing=None) 42 | actual_without_timing = dataclasses.replace(actual_result, timing=None) 43 | self.assertEqual( 44 | expected_without_timing, 45 | actual_without_timing, 46 | 'SandboxResult differed by more than just' 47 | f' timing.\nExpected:\n{expected_result!r}\nActual:\n{actual_result!r}', 48 | ) 49 | 50 | 51 | @dataclasses.dataclass 52 | class PythonSandboxForTest(python_execution.PythonSandbox): 53 | """Mock PythonSandbox. 54 | 55 | Attributes: 56 | hook_objects: Objects containing modifiable state of the hook functions. 57 | (Required as part of the PythonSandbox interface.) 58 | reply_by_request: Mapping from request to reply, or to a sequence of replies 59 | in case we want to return different results on the 1st call vs. the 2nd 60 | call, etc. 61 | default_reply: Default reply if not found in reply_by_request. 62 | requests: All `run` requests that were received, in the order received. 63 | (Same format as `unexpected_requests` below.) 64 | unexpected_requests: Requests that were not found in the corresponding 65 | mappings (i.e., requests for which we ended up falling back to returning 66 | the `default_reply`). 67 | """ 68 | 69 | hook_objects: Mapping[str, Any] = dataclasses.field(default_factory=dict) 70 | 71 | # Attributes for controlling the replies to be returned. 72 | reply_by_request: Mapping[str, _SandboxResult | Sequence[_SandboxResult]] = ( 73 | dataclasses.field(default_factory=dict) 74 | ) 75 | default_reply: _SandboxResult = dataclasses.field( 76 | default_factory=_SandboxResult 77 | ) 78 | 79 | # Attributes used for tracking the actual requests / replies (for assertions). 80 | requests: list[str] = dataclasses.field(init=False, default_factory=list) 81 | unexpected_requests: list[str] = dataclasses.field( 82 | init=False, default_factory=list 83 | ) 84 | 85 | # Attributes used for tracking the actual requests / replies (internal). 86 | _num_run_calls_by_request: collections.Counter[str] = ( 87 | dataclasses.field(init=False, default_factory=collections.Counter) 88 | ) 89 | 90 | def is_stateful(self) -> bool: 91 | """See base class (PythonSandbox).""" 92 | return True 93 | 94 | @executing.make_executable # pytype: disable=wrong-arg-types 95 | def run(self, code: str) -> _SandboxResult: 96 | """See base class (PythonSandbox).""" 97 | self.requests.append(code) 98 | 99 | # By request. 100 | if code in self.reply_by_request: 101 | reply = self.reply_by_request[code] 102 | if isinstance(reply, str): 103 | # Single reply specified. Always return it. 104 | return reply 105 | else: 106 | # Sequence of replies specified. Return the next (until we run out). 107 | reply_index = self._num_run_calls_by_request[code] 108 | self._num_run_calls_by_request[code] += 1 109 | if reply_index < len(reply): 110 | return reply[reply_index] 111 | 112 | # Default. 113 | self.unexpected_requests.append(code) 114 | return self.default_reply 115 | 116 | async def set_variables(self, **variables: Any) -> None: 117 | """See base class (PythonSandbox).""" 118 | raise NotImplementedError() 119 | 120 | async def get_variables(self, *names: str) -> Mapping[str, Any]: 121 | """See base class (PythonSandbox).""" 122 | raise NotImplementedError() 123 | 124 | def get_hook_object(self, key: str) -> Any: 125 | """See base class (PythonSandbox).""" 126 | return self.hook_objects.get(key) 127 | 128 | 129 | @dataclasses.dataclass 130 | class PythonSandboxForTestFactory(python_execution.PythonSandboxFactory): 131 | """Mock PythonSandboxFactory that returns a hard-coded PythonSandboxForTest. 132 | 133 | Attributes: 134 | default_sandbox: Sandbox to be returned by default on calls to 135 | `create_sandbox`. Note that in order to satisfy behavior expectations of 136 | `PythonPlanningAgent`, the factory will automatically overwrite the 137 | `hook_objects` member of this default sandbox each time before returning 138 | it from `create_sandbox`. (This should be fine in cases where we only need 139 | to return one sandbox; for tests in which we expect multiple sandboxes to 140 | be created, however, we may need more fine-grained configuration 141 | controlling the sandboxes to return.) 142 | """ 143 | 144 | default_sandbox: PythonSandboxForTest = dataclasses.field( 145 | default_factory=PythonSandboxForTest 146 | ) 147 | 148 | def create_sandbox( 149 | self, 150 | *, 151 | timeout: datetime.timedelta = datetime.timedelta(seconds=10), 152 | imports: Sequence[str] | str = tuple(), 153 | hooks: Mapping[str, Callable[..., Any]] | None = None, 154 | hook_objects: Mapping[str, Any] | None = None, 155 | allow_restarts: bool = False, 156 | ) -> PythonSandboxForTest: 157 | # If necessary, we can add here more detailed configurations controlling the 158 | # sandbox to return, e.g., depending on the parameters specified, or on the 159 | # number of times that `create_sandbox` has been called. For now we just 160 | # reuse the single `default_sandbox`. 161 | self.default_sandbox.hook_objects = hook_objects or {} 162 | return self.default_sandbox 163 | -------------------------------------------------------------------------------- /onetwo/builtins/prompt_templating.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library for executing a prompt template.""" 16 | 17 | from collections.abc import Callable, Mapping 18 | import copy 19 | import dataclasses 20 | import functools 21 | from typing import Any, TypeVar 22 | 23 | import immutabledict 24 | from onetwo.builtins import callbacks 25 | from onetwo.core import executing 26 | from onetwo.core import templating 27 | 28 | 29 | _DRY_RUN_PREFIX_VAR = templating._DRY_RUN_PREFIX_VAR # pylint: disable=protected-access 30 | PROMPT_PREFIX = templating.PROMPT_PREFIX 31 | 32 | _T = TypeVar('_T') 33 | _EMPTY_MAP = immutabledict.immutabledict({}) 34 | 35 | 36 | @dataclasses.dataclass 37 | class JinjaTemplateWithCallbacks(templating.JinjaTemplate): 38 | """A Jinja2 template augmented with additional LLM-specific callbacks.""" 39 | 40 | def __post_init__(self): 41 | """See parent class.""" 42 | super().__post_init__() 43 | self.register_callback('llm', callbacks.llm_callback, pass_context=True) 44 | # Also registering the llm callback as 'generate_text' for compatibility 45 | # with builtins and composables. 46 | # TODO: deprecate the llm version and make generate_text support 47 | # dry run. 48 | self.register_callback( 49 | 'generate_text', callbacks.generate_text, pass_context=True 50 | ) 51 | self.register_callback( 52 | 'generate_texts', callbacks.generate_texts, pass_context=True 53 | ) 54 | self.register_callback('choose', callbacks.choose, pass_context=True) 55 | self.register_callback( 56 | 'generate_object', callbacks.generate_object, pass_context=True 57 | ) 58 | 59 | def _postprocess_iterable_reply(self, iterable_reply: Any) -> str: 60 | if isinstance(iterable_reply, tuple) and len(iterable_reply) == 2: 61 | # This is the case where we got a reply with details. We just use the 62 | # first part. 63 | # Note that this is specific to the generate_text reply so it may 64 | # not work for arbitrary streaming callbacks. 65 | return iterable_reply[0] 66 | else: 67 | return super()._postprocess_iterable_reply(iterable_reply) 68 | 69 | @executing.make_executable # pytype: disable=wrong-arg-types 70 | async def dry_run( 71 | self, 72 | inputs: Mapping[str, Any], 73 | llm_default_reply: str = 'default', 74 | generate_object_default_reply_map: Mapping[type[_T], _T] = _EMPTY_MAP, 75 | mock_callbacks: ( 76 | list[tuple[str, Callable[[Any], Any], bool]] | None 77 | ) = None, 78 | ) -> Mapping[str, Any]: 79 | """Dry runs the prompt and returns prefixes sent to the LLM. 80 | 81 | This method renders the jinja2 prompt and gets the prefixes that 82 | are sent to the language model without actually executing the LLM requests. 83 | By default this method mocks the `llm` and `choose` operations. 84 | The mocked `llm` simply returns the `llm_default_reply` provided by the 85 | user, while the mocked `choose` chooses the first `top_k` provided options 86 | and populates scores with all `0.0`. 87 | 88 | The user can also mock other callbacks. This can be useful when there is a 89 | callback that sends LLM requests within itself, eg. by defining and 90 | executing its own PromptTemplate instance. In this case the callback 91 | functions provided by the user can store prefixes in the 92 | `_DRY_RUN_PREFIX_VAR` output variable. 93 | 94 | Args: 95 | inputs: Inputs to the prompt template. 96 | llm_default_reply: String that is used as a reply from `llm` operation. 97 | generate_object_default_reply_map: map from type to instance for default 98 | values calling generate_object 99 | mock_callbacks: List of (callback_name, callback_fn, pass_context) for 100 | additional callbacks to be mocked. 101 | 102 | Returns: 103 | A dict with at least one key `PROMPT_PREFIX` that contains the final 104 | rendered string of the entire prompt. If the template has `llm()` calls 105 | the dict contains an `llm` key that stores a list of string-valued 106 | rendered prefixes sent to the `llm()` operation. If template has 107 | `choose()` calls the dict contains a `choose` key that stores a list of 108 | string-valued rendered prefixes sent to the `choose()` operation. 109 | The output may contain other keys if the user provides additional 110 | mock callbacks that write to the `_DRY_RUN_PREFIX_VAR` output variable. 111 | """ 112 | # To mock some of the callbacks we modify the `_callbacks` attribute 113 | # of the PromptTemplateJ2 instance. To avoid subtle issues (for example: 114 | # parallel async execution of multiple dry_run coroutines of the same 115 | # PromptTemplateJ2 instance) we create a copy of the instance. 116 | mock_prompt = copy.deepcopy(self) 117 | # By default we mock `choose` and `llm` operations. 118 | mock_prompt.register_callback( 119 | 'llm', 120 | functools.partial( 121 | callbacks.mock_llm_callback, default_reply=llm_default_reply 122 | ), 123 | pass_context=True, 124 | ) 125 | mock_prompt.register_callback( 126 | 'choose', callbacks.mock_choose_callback, pass_context=True 127 | ) 128 | mock_prompt.register_callback( 129 | 'generate_object', 130 | functools.partial( 131 | callbacks.mock_generate_object_callback, 132 | default_type_to_object_map=generate_object_default_reply_map, 133 | ), 134 | pass_context=True, 135 | ) 136 | # We also mock any other callbacks provided by the user. 137 | if mock_callbacks is not None: 138 | for callback_name, callback_fn, callback_pass_context in mock_callbacks: 139 | mock_prompt.register_callback( 140 | name=callback_name, 141 | function=callback_fn, 142 | pass_context=callback_pass_context, 143 | ) 144 | outputs = await mock_prompt.render(**inputs) 145 | try: 146 | result = outputs[_DRY_RUN_PREFIX_VAR] 147 | except KeyError as exc: 148 | raise ValueError( 149 | 'No DRY_RUN found in outputs. Please make sure that the template has' 150 | ' a llm() call.' 151 | ) from exc 152 | result[PROMPT_PREFIX] = outputs[PROMPT_PREFIX] 153 | return result 154 | -------------------------------------------------------------------------------- /onetwo/agents/distribution_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Sequence 16 | import dataclasses 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from onetwo.agents import agents_base 21 | from onetwo.agents import distribution 22 | from onetwo.core import executing 23 | 24 | 25 | StringAgentState = agents_base.UpdateListState[str, str] 26 | 27 | 28 | @dataclasses.dataclass 29 | class StringDistributionAgent( 30 | distribution.DistributionAgent[ 31 | str, str, StringAgentState, agents_base.ScoredUpdate[str], None 32 | ] 33 | ): 34 | """Simple test agent, whose input / updates are hard-coded strings. 35 | 36 | The update strings are of the form '1', '2', '3', etc. 37 | 38 | Its output is a concatenation of the update strings, separate by space. 39 | 40 | Attributes: 41 | max_length: Maximum length of the agent's state (i.e., of its update list). 42 | If specified, then will finish when this length is reached. If None, then 43 | will by default run forever. 44 | """ 45 | 46 | end_of_sequence: int = 0 47 | vocab_size: int = 10 48 | distribution: list[tuple[str, float]] = dataclasses.field( 49 | default_factory=list 50 | ) 51 | 52 | @executing.make_executable(copy_self=False, non_copied_args=['environment']) 53 | async def initialize_state( 54 | self, inputs: str, environment: None = None 55 | ) -> StringAgentState: 56 | return StringAgentState(inputs=inputs) 57 | 58 | def extract_output(self, state: StringAgentState) -> str: 59 | """Overridden from base class (Agent).""" 60 | return ' '.join(state.updates) 61 | 62 | def is_finished(self, state: StringAgentState) -> bool: 63 | """Overridden from base class (Agent).""" 64 | if not state.updates: 65 | return False 66 | return state.updates[-1] == self.end_of_sequence 67 | 68 | @executing.make_executable(copy_self=False) 69 | async def get_next_step_distribution( 70 | self, state: StringAgentState, environment: None = None 71 | ) -> list[agents_base.ScoredUpdate[str]]: 72 | if self.distribution: 73 | return [ 74 | agents_base.ScoredUpdate(update=d[0], score=d[1]) 75 | for d in self.distribution 76 | ] 77 | else: 78 | # Return a uniform distribution over the vocabulary. 79 | return [ 80 | agents_base.ScoredUpdate(update=str(i), score=1.0 / self.vocab_size) 81 | for i in range(self.vocab_size) 82 | ] 83 | 84 | 85 | def _scored_updates_to_tuples( 86 | scored_updates: Sequence[agents_base.ScoredUpdate[str]], 87 | ) -> list[tuple[str, float]]: 88 | return [(s.update, s.score) for s in scored_updates] 89 | 90 | 91 | class DistributionAgentTest(parameterized.TestCase): 92 | 93 | def test_get_next_step_distribution(self): 94 | agent = StringDistributionAgent(vocab_size=10) 95 | 96 | with self.subTest('distribution_is_uniform'): 97 | dist = executing.run( 98 | agent.get_next_step_distribution( # pytype: disable=wrong-keyword-args 99 | state=agent.initialize_state(inputs='1') # pytype: disable=wrong-keyword-args 100 | ) 101 | ) 102 | self.assertSequenceEqual( 103 | [(str(i), 0.1) for i in range(10)], _scored_updates_to_tuples(dist) 104 | ) 105 | 106 | with self.subTest('samples_are_correct'): 107 | samples = executing.run( 108 | agent.sample_next_step( # pytype: disable=wrong-keyword-args 109 | state=agent.initialize_state(inputs='1'), num_candidates=5 # pytype: disable=wrong-keyword-args 110 | ) 111 | ) 112 | samples = [s.update for s in samples] 113 | self.assertContainsSubset(set(samples), set(str(i) for i in range(10))) 114 | 115 | with self.subTest('samples_and_scores_are_correct'): 116 | samples_with_scores = executing.run( 117 | agent.sample_next_step( # pytype: disable=wrong-keyword-args 118 | state=agent.initialize_state(inputs='1'), num_candidates=5 # pytype: disable=wrong-keyword-args 119 | ) 120 | ) 121 | self.assertContainsSubset( 122 | set(_scored_updates_to_tuples(samples_with_scores)), 123 | set((str(i), 0.1) for i in range(10)), 124 | ) 125 | 126 | 127 | class ReweightedDistributionAgentTest(parameterized.TestCase): 128 | 129 | @parameterized.named_parameters( 130 | ('top_k', None, None, 2, [('1', 0.62), ('2', 0.0), ('3', 0.38)]), 131 | ('top_p', None, 0.1, None, [('1', 1.0), ('2', 0.0), ('3', 0.0)]), 132 | ('top_k_top_p', None, 0.1, 2, [('1', 1.0), ('2', 0.0), ('3', 0.0)]), 133 | ('temp_0', 0.0, None, None, [('1', 1.0), ('2', 0.0), ('3', 0.0)]), 134 | ('temp_1', 1.0, None, None, [('1', 0.5), ('2', 0.2), ('3', 0.3)]), 135 | ('temp_5', 5.0, None, None, [('1', 0.37), ('2', 0.3), ('3', 0.33)]), 136 | ('temp_100', 100.0, None, None, [('1', 0.33), ('2', 0.33), ('3', 0.33)]), 137 | ('temp_1_top_k', 1.0, None, 2, [('1', 0.63), ('2', 0.0), ('3', 0.38)]), 138 | ('temp_0_top_p', 0.0, 0.5, None, [('1', 1.0), ('2', 0.0), ('3', 0.0)]), 139 | ( 140 | 'temp_100_top_p', 141 | 100.0, 142 | 0.5, 143 | None, 144 | [('1', 0.5), ('2', 0.0), ('3', 0.5)], 145 | ), 146 | ) 147 | def test_get_distribution( 148 | self, 149 | temperature: float, 150 | top_p: float, 151 | top_k: int, 152 | expected_distribution: list[tuple[str, float]], 153 | ): 154 | inner_agent = StringDistributionAgent( 155 | distribution=[('1', 0.5), ('2', 0.2), ('3', 0.3)] 156 | ) 157 | outer_agent = distribution.ReweightedDistributionAgent( 158 | inner_agent=inner_agent, 159 | temperature=temperature, 160 | top_p=top_p, 161 | top_k=top_k, 162 | ) 163 | state = outer_agent.initialize_state('test') # pytype: disable=wrong-arg-count 164 | result = executing.run(outer_agent.get_next_step_distribution(state=state)) # pytype: disable=wrong-keyword-args 165 | # We round off the distribution to make the comparison easier. 166 | result = [(s.update, round(s.score, 2)) for s in result] 167 | # Also to avoid issues when comparing floats, we convert everything to 168 | # strings. 169 | self.assertEqual(str(expected_distribution), str(result)) 170 | 171 | 172 | if __name__ == '__main__': 173 | absltest.main() 174 | -------------------------------------------------------------------------------- /onetwo/core/sampling_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pprint 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from onetwo.core import caching 20 | from onetwo.core import executing 21 | from onetwo.core import sampling 22 | 23 | 24 | @executing.make_executable # pytype: disable=wrong-arg-types 25 | async def process(request: str, suffix: str = '') -> tuple[str, str]: 26 | key = caching.context_sampling_key.get() 27 | return (request+suffix, key) 28 | 29 | 30 | class SamplingTest(parameterized.TestCase): 31 | 32 | def test_repeat(self): 33 | executable = process('test') 34 | executables = sampling.repeat(executable, 3) 35 | 36 | res = executing.run(executing.par_iter(executables)) 37 | 38 | with self.subTest('different_sampling_keys'): 39 | self.assertEqual(res, [('test', ''), ('test', '1'), ('test', '2')]) 40 | 41 | def test_repeat_multiple_times(self): 42 | executable = process('test') 43 | executables = sampling.repeat(executable, 3) 44 | res = list(executing.run(executing.par_iter(executables))) 45 | executables = sampling.repeat(executable, 3, start_index=3) 46 | res += executing.run(executing.par_iter(executables)) 47 | executables = sampling.repeat(executable, 3, start_index=6) 48 | res += executing.run(executing.par_iter(executables)) 49 | res_keys = [key for _, key in res] 50 | 51 | with self.subTest('different_sampling_keys'): 52 | self.assertEqual(res_keys, [''] + [str(i) for i in range(1, 9)]) 53 | 54 | def test_nested_repeat(self): 55 | """We test the nesting of repeats, together with the dynamic keys.""" 56 | # Whether we create the repeated executable before executing them (i.e. 57 | # statically) or while executing them (i.e. dynamically), the effect on 58 | # sampling keys should be the same. 59 | 60 | # Static creation 61 | static_executable = executing.serial( 62 | sampling.repeat_and_execute(process('test1'), 3), 63 | sampling.repeat_and_execute(process('test2'), 3), 64 | ) 65 | 66 | # Dynamic creation 67 | @executing.make_executable # pytype: disable=wrong-arg-types 68 | async def dynamic_executable(): 69 | return await executing.serial( 70 | sampling.repeat_and_execute(process('test1'), 3), 71 | sampling.repeat_and_execute(process('test2'), 3), 72 | ) 73 | 74 | executable1 = sampling.repeat_and_execute(static_executable, 2) 75 | executable2 = sampling.repeat_and_execute(dynamic_executable(), 2) 76 | res1 = executing.run(executable1) 77 | res2 = executing.run(executable2) 78 | 79 | expected_results = [ 80 | [ 81 | [('test1', ''), ('test1', '1'), ('test1', '2')], 82 | [('test2', ''), ('test2', '1'), ('test2', '2')], 83 | ], 84 | [ 85 | [('test1', '1#0'), ('test1', '1#1'), ('test1', '1#2')], 86 | [('test2', '1#0'), ('test2', '1#1'), ('test2', '1#2')], 87 | ], 88 | ] 89 | 90 | with self.subTest('static'): 91 | self.assertEqual( 92 | res1, 93 | expected_results, 94 | msg=pprint.pformat(res1), 95 | ) 96 | 97 | with self.subTest('dynamic'): 98 | self.assertEqual( 99 | res2, 100 | expected_results, 101 | msg=pprint.pformat(res2), 102 | ) 103 | 104 | def test_update(self): 105 | sample_size = 2 106 | 107 | def update_result(result, sample_id): 108 | return result[0], result[1], sample_id, sample_size 109 | 110 | executable = sampling.repeat_and_execute( 111 | process('test'), sample_size, update_result_fn=update_result 112 | ) 113 | res = executing.run(executable) 114 | with self.subTest('should_update_results'): 115 | for i in range(sample_size): 116 | self.assertEqual(res[i], ('test', str(i) if i else '', i, sample_size)) 117 | 118 | def test_streaming(self): 119 | executable = executing.par_iter(sampling.repeat(process('test'), 3)) 120 | 121 | with executing.safe_stream(executable) as iterator: 122 | results = sum(iterator, start=executing.Update()).to_result() 123 | 124 | self.assertEqual( 125 | results, 126 | [('test', ''), ('test', '1'), ('test', '2')], 127 | ) 128 | 129 | def test_repeat_sampler(self): 130 | sampler = sampling.Repeated(process) 131 | # Note that we pass in both a positional and a keyword argument to verify 132 | # that the sampler is correctly passing them through. 133 | res = executing.run(sampler('test', suffix='-a', num_samples=3)) # pytype: disable=wrong-arg-count 134 | expected = [('test-a', ''), ('test-a', '1'), ('test-a', '2')] 135 | self.assertEqual(expected, res, res) 136 | 137 | def test_round_robin_sampler(self): 138 | @executing.make_executable # pytype: disable=wrong-arg-types 139 | async def strategy1(request, **kwargs): 140 | return await process(f'{request}-1', **kwargs) 141 | 142 | @executing.make_executable # pytype: disable=wrong-arg-types 143 | async def strategy2(request, **kwargs): 144 | return await process(f'{request}-2', **kwargs) 145 | 146 | sampler = sampling.RoundRobin([ 147 | sampling.Repeated(strategy1), 148 | sampling.Repeated(strategy2), 149 | ]) 150 | 151 | # Note that we pass in both a positional and a keyword argument to verify 152 | # that the sampler is correctly passing them through. 153 | res = executing.run(sampler('test', suffix='-a', num_samples=5)) # pytype: disable=wrong-arg-count 154 | expected = [ 155 | ('test-1-a', ''), 156 | ('test-2-a', '1'), 157 | ('test-1-a', '2'), 158 | ('test-2-a', '3'), 159 | ('test-1-a', '4'), 160 | ] 161 | with self.subTest('default_start_index_starts_at_0_with_first_sampler'): 162 | self.assertEqual(expected, res, res) 163 | 164 | # This time we omit the suffix kwarg for simplicity. 165 | res = executing.run(sampler(request='test', num_samples=5, start_index=1)) # pytype: disable=wrong-keyword-args 166 | expected = [ 167 | ('test-2', '1'), 168 | ('test-1', '2'), 169 | ('test-2', '3'), 170 | ('test-1', '4'), 171 | ('test-2', '5'), 172 | ] 173 | with self.subTest('start_index_1_starts_at_1_with_second_sampler'): 174 | self.assertEqual(expected, res, res) 175 | 176 | 177 | if __name__ == '__main__': 178 | absltest.main() 179 | -------------------------------------------------------------------------------- /onetwo/agents/tasks/game_of_24.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Data config for applying IterativeThought to the Game of 24 problem.""" 16 | 17 | import textwrap 18 | 19 | from onetwo.agents import iterative_thought 20 | 21 | 22 | GAME_OF_24_DESCRIPTION = textwrap.dedent("""\ 23 | Use numbers and basic arithmetic operations (+ - * /) to obtain 24. 24 | Each number can be only used once, but results can be used in subsequent 25 | steps. 26 | """) 27 | 28 | 29 | # Exemplars for use with IterativeThoughtAgent. 30 | GAME_OF_24_EXAMPLES = [ 31 | iterative_thought.IterativeThoughtState( 32 | inputs='10 5 2 11', 33 | updates=[ 34 | '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)', 35 | '10 5 22, we try 10 / 5 = 2 (remaining: 22 2)', 36 | '22 2, we try 22 + 2 = 24 (remaining: 24, success: we got 24)', 37 | ], 38 | ), 39 | iterative_thought.IterativeThoughtState( 40 | inputs='10 5 2 11', 41 | updates=[ 42 | '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)', 43 | '10 5 22, we try 22 - 5 = 17 (remaining: 10 17)', 44 | '10 17, we try 10 + 17 = 27 ' 45 | + '(remaining: 27, failure: we did not get 24)', 46 | ], 47 | ), 48 | iterative_thought.IterativeThoughtState( 49 | inputs='10 5 2 11', 50 | updates=[ 51 | '10 5 2 11, we try 11 - 10 = 1 (remaining: 5 2 1)', 52 | '5 2 1, we try 2 + 1 = 3 (remaining: 5 3)', 53 | '5 3, we try 5 * 3 = 15 ' 54 | + '(remaining: 15, failure: we did not get 24)', 55 | ], 56 | ), 57 | iterative_thought.IterativeThoughtState( 58 | inputs='2 6 5 3', 59 | updates=[ 60 | '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)', 61 | '2 6 2, we try 2 * 6 = 12 (remaining: 2 12)', 62 | '2 12, we try 2 * 12 = 24 (remaining: 24, success: we got 24)', 63 | ], 64 | ), 65 | iterative_thought.IterativeThoughtState( 66 | inputs='2 6 5 3', 67 | updates=[ 68 | '2 6 5 3, we try 3 * 2 = 6 (remaining: 5 6 6)', 69 | '5 6 6, we try 6 + 6 = 12 (remaining: 5 12)', 70 | '5 12, we try 12 + 5 = 17 ' 71 | + '(remaining: 17, failure: we did not get 24)', 72 | ], 73 | ), 74 | iterative_thought.IterativeThoughtState( 75 | inputs='2 6 5 3', 76 | updates=[ 77 | '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)', 78 | '2 6 2, we try 6 / 2 = 3 (remaining: 2 3)', 79 | '2 3, we try 2 * 3 = 6 ' 80 | + '(remaining: 6, failure: we did not get 24)', 81 | ], 82 | ), 83 | iterative_thought.IterativeThoughtState( 84 | inputs='2 6 5 3', 85 | updates=[ 86 | '2 6 5 3, we try 2 * 5 = 10 (remaining: 6 3 10)', 87 | '6 3 10, we try 6 * 10 = 60 (remaining: 3 60)', 88 | '3 60, we try 60 / 3 = 20 ' 89 | + '(remaining: 20, failure: we did not get 24)', 90 | ], 91 | ), 92 | ] 93 | 94 | 95 | # Exemplars for use with IterativeThoughtProposerAgent. 96 | GAME_OF_24_PROPOSER_EXAMPLES = [ 97 | iterative_thought.IterativeThoughtProposerExemplar( 98 | iterative_thought.IterativeThoughtState( 99 | inputs='10 5 2 11', 100 | updates=[], 101 | ), 102 | next_steps=[ 103 | '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)', 104 | '10 5 2 11, we try 11 - 10 = 1 (remaining: 5 2 1)', 105 | ], 106 | ), 107 | iterative_thought.IterativeThoughtProposerExemplar( 108 | iterative_thought.IterativeThoughtState( 109 | inputs='10 5 2 11', 110 | updates=[ 111 | '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)', 112 | ], 113 | ), 114 | next_steps=[ 115 | '10 5 22, we try 10 / 5 = 2 (remaining: 22 2)', 116 | '10 5 22, we try 22 - 5 = 17 (remaining: 10 17)', 117 | ], 118 | ), 119 | iterative_thought.IterativeThoughtProposerExemplar( 120 | iterative_thought.IterativeThoughtState( 121 | inputs='10 5 2 11', 122 | updates=[ 123 | '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)', 124 | '10 5 22, we try 10 / 5 = 2 (remaining: 22 2)', 125 | ], 126 | ), 127 | next_steps=[ 128 | '22 2, we try 22 + 2 = 24 (remaining: 24, success: we got 24)', 129 | ], 130 | ), 131 | iterative_thought.IterativeThoughtProposerExemplar( 132 | iterative_thought.IterativeThoughtState( 133 | inputs='10 5 2 11', 134 | updates=[ 135 | '10 5 2 11, we try 2 * 11 = 22 (remaining: 10 5 22)', 136 | '10 5 22, we try 22 - 5 = 17 (remaining: 10 17)', 137 | ], 138 | ), 139 | next_steps=[ 140 | '10 17, we try 10 + 17 = 27 ' 141 | + '(remaining: 27, failure: we did not get 24)', 142 | ], 143 | ), 144 | iterative_thought.IterativeThoughtProposerExemplar( 145 | iterative_thought.IterativeThoughtState( 146 | inputs='2 6 5 3', 147 | updates=[], 148 | ), 149 | next_steps=[ 150 | '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)', 151 | '2 6 5 3, we try 3 * 2 = 6 (remaining: 5 6 6)', 152 | '2 6 5 3, we try 2 * 5 = 10 (remaining: 6 3 10)', 153 | ], 154 | ), 155 | iterative_thought.IterativeThoughtProposerExemplar( 156 | iterative_thought.IterativeThoughtState( 157 | inputs='2 6 5 3', 158 | updates=[ 159 | '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)', 160 | ], 161 | ), 162 | next_steps=[ 163 | '2 6 2, we try 6 / 2 = 3 (remaining: 2 3)', 164 | '2 6 2, we try 2 * 6 = 12 (remaining: 2 12)', 165 | ], 166 | ), 167 | iterative_thought.IterativeThoughtProposerExemplar( 168 | iterative_thought.IterativeThoughtState( 169 | inputs='2 6 5 3', 170 | updates=[ 171 | '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)', 172 | '2 6 2, we try 2 * 6 = 12 (remaining: 2 12)', 173 | ], 174 | ), 175 | next_steps=[ 176 | '2 12, we try 2 * 12 = 24 (remaining: 24, success: we got 24)', 177 | ], 178 | ), 179 | iterative_thought.IterativeThoughtProposerExemplar( 180 | iterative_thought.IterativeThoughtState( 181 | inputs='2 6 5 3', 182 | updates=[ 183 | '2 6 5 3, we try 5 - 3 = 2 (remaining: 2 6 2)', 184 | '2 6 2, we try 6 / 2 = 3 (remaining: 2 3)', 185 | ], 186 | ), 187 | next_steps=[ 188 | '2 3, we try 2 * 3 = 6 (remaining: 6, failure: we did not get 24)', 189 | ], 190 | ), 191 | ] 192 | -------------------------------------------------------------------------------- /onetwo/core/updating.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library for objects representing updates of results. 16 | 17 | When executing a long computation, we may want to monitor its progress. The 18 | Update class and its subclasses allow one to receive periodic updates on the 19 | progress and compose these updates into a result representing everything done so 20 | far. 21 | 22 | Assuming we have an iterator over updates of an appropriate type (subclass of 23 | `Update` class with relevant implementation of `__add__` and `to_result`), we 24 | can compose a result in the following way: 25 | 26 | ``` 27 | updates = Update() 28 | for update in process_yielding_updates: 29 | updates += update # `update` of type MyUpdate. 30 | print('Current result: ', updates.to_result()) 31 | 32 | print('Final result', updates.to_result()) 33 | ``` 34 | 35 | Note that the + operator is overloaded, so one can also use this syntax: 36 | 37 | ``` 38 | updates = Update() 39 | for update in process_yielding_updates: 40 | updates += update 41 | print('Current result: ', updates.to_result()) 42 | ``` 43 | Or one can use the `sum` operator as well: 44 | 45 | ``` 46 | final_result = sum(process_yielding_updates, start=Update()).to_result() 47 | ``` 48 | """ 49 | 50 | from __future__ import annotations 51 | 52 | import dataclasses 53 | from typing import Generic, Protocol, TypeVar, cast 54 | 55 | 56 | _T = TypeVar('_T') 57 | 58 | 59 | class _Addable(Protocol): 60 | """Protocol for objects that can be added together with a + operator.""" 61 | 62 | def __add__(self: _T, other: _T) -> _T: ... 63 | 64 | 65 | _Tadd = TypeVar('_Tadd', bound=_Addable) 66 | 67 | 68 | @dataclasses.dataclass 69 | class Update(Generic[_T]): 70 | """Updates that can be accumulated and used to get the result. 71 | 72 | Basic implementation where payload and result are of the same type. 73 | 74 | Attributes: 75 | payload: Content of the update. 76 | """ 77 | 78 | payload: _T | None = None 79 | 80 | def __add__(self, other: Update[_T] | _T) -> Update[_T]: 81 | """Incorporate this update. By default we overwrite all previous updates.""" 82 | if isinstance(other, Update): 83 | return other 84 | else: 85 | return Update(other) 86 | 87 | def to_result(self) -> _T | None: 88 | """Produce a final result from the accumulated updates.""" 89 | if isinstance(self.payload, Update): 90 | return cast(Update[_T], self.payload).to_result() 91 | else: 92 | return self.payload 93 | 94 | def to_simplified_result(self) -> _T: 95 | """Produces a simplified result from the accumulated updates.""" 96 | if isinstance(self.payload, Update): 97 | return cast(Update[_T], self.payload).to_simplified_result() 98 | else: 99 | return self.payload 100 | 101 | 102 | @dataclasses.dataclass 103 | class AddableUpdate(Generic[_Tadd], Update[_Tadd]): 104 | """Update class for adding objects together with a + operator. 105 | 106 | This could be used for lists, but unlike the ListUpdate the order is not 107 | necessarily preserved: if several processes yield such updates in parallel 108 | there is no guarantee that the final list will have retain the order of the 109 | processes for example. 110 | """ 111 | 112 | payload: _Tadd | None = None 113 | 114 | def __add__( 115 | self, other: AddableUpdate[_Tadd] | _Tadd 116 | ) -> AddableUpdate[_Tadd]: 117 | """See base class.""" 118 | if isinstance(other, AddableUpdate): 119 | self.payload += other.payload 120 | else: 121 | self.payload += other 122 | return self 123 | 124 | 125 | @dataclasses.dataclass 126 | class ListUpdate(Generic[_T], Update[_T]): 127 | """Update class for maintaining a list. 128 | 129 | End result is of type list[T]. Every intermediate update (payload) provides 130 | a list of values and indices where these values need to go in the final 131 | result. 132 | 133 | For instance, ListUpdate([('ab', 10), (Update('b'), 23)]) tells us that final 134 | result is of type list['str'] and that elements in the final result with 135 | indices 10 and 23 should be set to 'ab' and 'b' respectively. 136 | 137 | We can also maintain nested lists. For example: 138 | ListUpdate([ 139 | (ListUpdate([32, 1]), 0), 140 | (ListUpdate(10, 2), 1), 141 | ]) 142 | tells us that final result is of type list[list[int]] and after this update 143 | (assuming it is the only update we accumulated) the final result will be: 144 | [[None, 32], [None, None, 10]]. 145 | 146 | Attributes: 147 | payload: List of tuples [value_or_update, index] that contain a value 148 | (possibly wrapped in an Update) together with index where this value 149 | should go in the final result. 150 | """ 151 | payload: list[ 152 | tuple[Update[_T] | _T, int] 153 | ] = dataclasses.field(default_factory=list) 154 | 155 | def __add__(self, other: ListUpdate[_T]) -> ListUpdate[_T]: 156 | """See base class.""" 157 | for update_or_value, index in other.payload: 158 | # Indices accumulated so far. 159 | accumulated_indices = [u[1] for u in self.payload] 160 | if index in accumulated_indices: 161 | index_in_payload = [u[1] for u in self.payload].index(index) 162 | found_update_or_value = self.payload[index_in_payload][0] 163 | if isinstance(found_update_or_value, ListUpdate): 164 | # Nested list case. 165 | self.payload[index_in_payload] = ( 166 | found_update_or_value + update_or_value, index 167 | ) 168 | else: 169 | # If the content is not a nested list update we just use the value to 170 | # replace the current element. 171 | self.payload[index_in_payload] = (update_or_value, index) 172 | else: 173 | # If we never saw the index before, just append it to the payload. 174 | self.payload.append((update_or_value, index)) 175 | return self 176 | 177 | def to_result(self) -> list[_T]: 178 | """See base class.""" 179 | if not self.payload: 180 | return [] 181 | largest_index = max([u[1] for u in self.payload]) 182 | result = [None] * (largest_index + 1) 183 | for update_or_value, index in self.payload: 184 | if isinstance(update_or_value, Update): 185 | result[index] = update_or_value.to_result() 186 | else: 187 | result[index] = update_or_value 188 | return result 189 | 190 | def to_simplified_result(self) -> list[_T]: 191 | """See base class.""" 192 | if not self.payload: 193 | return [] 194 | largest_index = max([u[1] for u in self.payload]) 195 | result = [None] * (largest_index + 1) 196 | for update_or_value, index in self.payload: 197 | if isinstance(update_or_value, Update): 198 | result[index] = update_or_value.to_simplified_result() 199 | else: 200 | result[index] = update_or_value 201 | return [r for r in result if r is not None] 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OneTwo 2 | 3 | TL;DR: OneTwo is a Python library designed to simplify interactions with large 4 | (language and multimodal) foundation models, primarily aimed at researchers 5 | in *prompting* and *prompting strategies*. 6 | 7 | Foundation Models are increasingly being used in complex scenarios with multiple 8 | back-and-forth interactions between the model and some traditional code 9 | (possibly generated by the model itself). This leads to the emergence of 10 | programs that combine ML models with traditional code. The goal of the OneTwo 11 | library is to enable the creation and execution of such programs. 12 | It is designed for researchers (and developers) who want to explore how to get 13 | the best out of foundational models in a situation where it is not necessarily 14 | possible to change their weights (i.e. perform fine-tuning). 15 | 16 | Some properties of OneTwo that are particularly impactful for researcher 17 | productivity include the following: 18 | 19 | - **Model-agnostic:** Provides a uniform API to access different models that 20 | can easily be swapped and compared. 21 | - **Flexible:** Supports implementation of arbitrarily complex computation 22 | graphs involving combinations of sequential and parallel operations, 23 | including interleaving of calls to foundation models and to other tools. 24 | - **Efficient:** Automatically optimizes request batching and other details of 25 | model server interactions under-the-hood for maximizing throughput, while 26 | allowing prompting strategies to be implemented straightforwardly, as if 27 | they were dealing with just single requests. 28 | - **Reproducible:** Automatically caches requests/replies for easy stop-and-go 29 | or replay of experiments. 30 | 31 | ### Features 32 | * **Uniform API**: OneTwo defines a set of primitives or so-called built-in 33 | functions representing the common ways to interact with foundation models. 34 | Since different open-source models or public APIs may have different 35 | capabilities and expose calls with different parameters, OneTwo attempts to 36 | provide a uniform way to access all of those, implementing some best-effort 37 | conversion to guarantee that a OneTwo program will run on any model. 38 | * **Convenient syntax**: Different syntaxes are proposed to enable writing 39 | prompts or sequences of prompts conveniently and in a modular fashion. 40 | One can use formatted strings or jinja templates or can assemble the prompt 41 | as a concatenation of (possibly custom-defined) function calls. 42 | * **Experimentation**: OneTwo offers functionality to easily run 43 | experiments on datasets and collect metrics, with automatic caching of the 44 | results so that minor modifications of the workflow can be replayed with 45 | minimal impact, without having to perform the same requests to the models, 46 | while at the same time being aware of the sampling behaviour (e.g. when 47 | different samples are needed for a given request). 48 | * **Tool use**: OneTwo supports different ways of calling tools from a model, 49 | including running model-generated Python code involving tool calls in a 50 | sandbox. 51 | * **Agents**: OneTwo provides abstractions for defining agents, i.e. functions 52 | that perform some complex action in a step-by-step manner, while updating some 53 | internal state. These agents can be combined naturally and used within 54 | generic optimization algorithms. 55 | * **Execution model**: The execution model takes care of the details of 56 | orchestrating the calls to the models and tools. It offers the following 57 | functionality: 58 | * **Asynchronous execution**: Thanks to the use of asynchronous execution, the 59 | actual calls to the models and tools can be performed in parallel. However, 60 | the user can define a complex workflow in an intuitive and functional manner 61 | without having to think of the details of the execution. 62 | * **Batching**: Similarly, whenever the backends support multiple 63 | simultaneous requests, or batched requests, the OneTwo library will leverage 64 | this and group the requests for maximizing throughput. 65 | * **Smart Caching**: In addition to making it easy to replay all or part of an 66 | experiment, the caching provided by the library handles the case of 67 | multiple random samples obtained from one model, keeping track of how many 68 | samples have been obtained for each request, so that no result is wasted 69 | and actual requests are minimized while the user is tuning their workflow. 70 | 71 | ## Quick start 72 | 73 | ### Installation 74 | 75 | You may want to install the package in a virtual environment, in which case you 76 | will need to start a virtual environment with the following command: 77 | 78 | ```shell 79 | python3 -m venv PATH_TO_DIRECTORY_FOR_VIRTUAL_ENV 80 | # Activate it. 81 | . PATH_TO_DIRECTORY_FOR_VIRTUAL_ENV/bin/activate 82 | ``` 83 | 84 | Once you no longer need it, this virtual environment can be deleted with the 85 | following command: 86 | 87 | ```shell 88 | deactivate 89 | ``` 90 | 91 | Install the package: 92 | 93 | ```shell 94 | pip install git+https://github.com/google-deepmind/onetwo 95 | ``` 96 | 97 | To start using it, import it with: 98 | 99 | ```python 100 | from onetwo import ot 101 | ``` 102 | 103 | ### Running unit tests 104 | 105 | In order to run the tests, first clone the repository: 106 | 107 | ```shell 108 | git clone https://github.com/google-deepmind/onetwo 109 | ``` 110 | 111 | Then from the cloned directory you can invoke `pytest`: 112 | 113 | ```shell 114 | pytest onetwo/core 115 | ``` 116 | 117 | However, doing `pytest onetwo` will not work as pytest collects all the test 118 | names without keeping their directory of origin so there may be name clashes, so 119 | you have to loop through the subdirectories. 120 | 121 | 122 | ## Tutorial and Documentation 123 | 124 | This 125 | [Colab](https://colab.research.google.com/github/google-deepmind/onetwo/blob/main/colabs/tutorial.ipynb) 126 | is a good starting point and demonstrates most of the features available in 127 | onetwo. 128 | 129 | Some background on the basic concepts of the library can be found here: 130 | [Basics](docs/basics.md). 131 | 132 | Some of the frequently asked questions are discussed here: [FAQ](docs/faq.md). 133 | 134 | ## Citing OneTwo 135 | 136 | To cite this repository: 137 | 138 | ```bibtex 139 | @software{onetwo2024github, 140 | author = {Olivier Bousquet and Nathan Scales and Nathanael Sch{\"a}rli and Ilya Tolstikhin}, 141 | title = {{O}ne{T}wo: {I}nteracting with {L}arge {M}odels}, 142 | url = {https://github.com/google-deepmind/onetwo}, 143 | version = {0.3.0}, 144 | year = {2024}, 145 | } 146 | ``` 147 | 148 | In the above BibTeX entry, names are in alphabetical order, the version number 149 | is intended to be the one returned by `ot.__version__` (i.e., the latest version 150 | mentioned in [version.py](version.py) and in the [CHANGELOG](CHANGELOG.md), 151 | and the year corresponds to the project's open-source release. 152 | 153 | ## License 154 | 155 | Copyright 2024 DeepMind Technologies Limited 156 | 157 | This code is licensed under the Apache License, Version 2.0 (the \"License\"); 158 | you may not use this file except in compliance with the License. You may obtain 159 | a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. 160 | 161 | Unless required by applicable law or agreed to in writing, software distributed 162 | under the License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR 163 | CONDITIONS OF ANY KIND, either express or implied. See the License for the 164 | specific language governing permissions and limitations under the License. 165 | 166 | ## Disclaimer 167 | 168 | This is not an official Google product. -------------------------------------------------------------------------------- /onetwo/core/routing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Registry implementation. 16 | 17 | The execution code emits various kinds of requests and they should be 18 | processed by different types of backends. 19 | A backend may process several kinds of requests (e.g. LLM backends may process 20 | both `generate_text` and `score_text` requests). 21 | """ 22 | 23 | from __future__ import annotations 24 | from collections.abc import Callable, MutableMapping, AsyncIterator 25 | import contextvars 26 | import copy 27 | import dataclasses 28 | from typing import Any, TypeAlias, TypeVar, final, Generic 29 | 30 | from onetwo.core import executing 31 | from onetwo.core import updating 32 | 33 | # This context variable contains the global function registry that can be used 34 | # anywhere in the code. 35 | # It should be accessed via the function_registry wrapper. 36 | _function_registry_var = contextvars.ContextVar[dict]( 37 | 'registry', default=dict() 38 | ) 39 | 40 | 41 | _RegistryEntry: TypeAlias = Callable[..., Any] 42 | _RegistryData: TypeAlias = tuple[dict[str, Any], dict[str, Any]] 43 | _T = TypeVar('_T') 44 | 45 | 46 | class RegistryReference: 47 | """Parent class to indicate that a Registry entry is a reference object. 48 | 49 | See routing._Registry.copy() for additional details. 50 | 51 | This is used to indicate when to actually copy the entry in the registry. 52 | Indeed, when creating a copy of the registry, we only do a shallow copy since 53 | the registered entries might be complex objects (e.g. a bound method that is 54 | bound to a complex object) that we don't want to copy. 55 | But we may also register some shallow objects that hold references to a 56 | method for example. 57 | So we use this parent class to indicate to _Registry.copy that it can 58 | copy it. 59 | """ 60 | 61 | 62 | class _Registry(MutableMapping[str, _RegistryEntry]): 63 | """Registry to store the mapping between names and functions.""" 64 | 65 | @executing.make_executable # pytype: disable=wrong-arg-types 66 | async def __call__(self, destination: str, *args, **kwargs) -> Any: 67 | # We call the registry function. 68 | result = _function_registry_var.get()[destination](*args, **kwargs) 69 | # If the result is an Executable, it will be executed, due to the 70 | # make_executable decorator which executes results by default. 71 | return result 72 | 73 | def __getitem__(self, key: str): 74 | if key not in _function_registry_var.get(): 75 | raise KeyError( 76 | f'Key "{key}" not registered in registry:' 77 | f' {_function_registry_var.get().keys()}' 78 | ) 79 | return _function_registry_var.get()[key] 80 | 81 | def __setitem__(self, key: str, item: _RegistryEntry): 82 | _function_registry_var.get()[key] = item 83 | 84 | def __delitem__(self, key: str): 85 | del _function_registry_var.get()[key] 86 | 87 | def __iter__(self): 88 | return iter(_function_registry_var.get()) 89 | 90 | def __len__(self): 91 | return len(_function_registry_var.get()) 92 | 93 | def __repr__(self): 94 | return repr(_function_registry_var.get()) 95 | 96 | 97 | def _copy_also_references(registry: dict[str, Any]) -> dict[str, Any]: 98 | """Returns a copy of the registry. 99 | 100 | This performs a shallow copy, but the entries that are references, i.e. 101 | of type executing.RegistryReference will be copied. 102 | This allows in particular to get the expected behaviour when copying 103 | a registry that contains builtin functions that have been configured, 104 | as the configuration is stored in a reference object and needs to be copied. 105 | 106 | Args: 107 | registry: The registry to copy. 108 | 109 | Returns: 110 | A copy of the registry. 111 | """ 112 | registry_copy = dict() 113 | for key, entry in registry.items(): 114 | registry_copy[key] = ( 115 | copy.copy(entry) 116 | if isinstance(entry, RegistryReference) 117 | else entry 118 | ) 119 | return registry_copy 120 | 121 | 122 | function_registry = _Registry() 123 | config_registry = contextvars.ContextVar[dict]( 124 | 'config_registry', default=dict() 125 | ) 126 | 127 | 128 | def copy_registry() -> _RegistryData: 129 | return _copy_also_references(_function_registry_var.get()), copy.copy( 130 | config_registry.get() 131 | ) 132 | 133 | 134 | def set_registry(registry: _RegistryData) -> None: 135 | _function_registry_var.set(registry[0]) 136 | config_registry.set(registry[1]) 137 | 138 | 139 | @dataclasses.dataclass 140 | class _RegistryDataWrapper( 141 | Generic[_T], executing.Executable[_T] 142 | ): 143 | """Wraps an executable with the registry data. 144 | 145 | Attributes: 146 | wrapped: Executable to be wrapped. 147 | registry: Registry to use when executing this Executable. 148 | """ 149 | 150 | wrapped: executing.Executable[_T] 151 | registry: _RegistryData 152 | 153 | @final 154 | async def _aiterate( 155 | self, iteration_depth: int = 1 156 | ) -> AsyncIterator[updating.Update[_T]]: 157 | """Yields the intermediate values and calls the final_value_callback.""" 158 | with RegistryContext(self.registry): 159 | it = self.wrapped.with_depth(iteration_depth).__aiter__() 160 | while True: 161 | with RegistryContext(self.registry): 162 | try: 163 | update = await it.__anext__() 164 | except StopAsyncIteration: 165 | break 166 | yield update 167 | 168 | @final 169 | async def _aexec(self) -> _T: 170 | """Iterate this value until done (including calling final_value_callback). 171 | 172 | Returns: 173 | The final value given by the AsyncIterator _inner(). 174 | """ 175 | with RegistryContext(self.registry): 176 | result = await self.wrapped 177 | return result 178 | 179 | 180 | def with_current_registry( 181 | executable: executing.Executable, 182 | ) -> _RegistryDataWrapper: 183 | """Wraps an executable and attaches the current registry to it.""" 184 | registry = copy_registry() 185 | return _RegistryDataWrapper(executable, registry) 186 | 187 | 188 | def with_registry( 189 | executable: executing.Executable, registry: _RegistryData 190 | ) -> _RegistryDataWrapper: 191 | """Wraps an executable and attaches the given registry to it.""" 192 | return _RegistryDataWrapper(executable, registry) 193 | 194 | 195 | @dataclasses.dataclass 196 | class RegistryContext: 197 | """Context Manager to update the registry locally. 198 | 199 | Attributes: 200 | registry: An optional registry to use in this context manager. 201 | If None is provided, makes a "local" copy of the function_registry 202 | that can be modified within the context. 203 | """ 204 | registry: _RegistryData | None = None 205 | 206 | def __enter__(self): 207 | if self.registry is None: 208 | # We create a copy of the current function_registry. 209 | updated = _copy_also_references(_function_registry_var.get()) 210 | updated_config = copy.copy(config_registry.get()) 211 | else: 212 | updated, updated_config = self.registry 213 | # We replace the current function_registry by the copy. 214 | self._token = _function_registry_var.set(updated) 215 | self._token_config = config_registry.set(updated_config) 216 | 217 | def __exit__(self, exc_type, exc_val, exc_tb): 218 | # We reset the function_registry to its value prior to entering the 219 | # context manager. 220 | _function_registry_var.reset(self._token) 221 | config_registry.reset(self._token_config) 222 | -------------------------------------------------------------------------------- /onetwo/core/executing_with_context_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | import dataclasses 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from onetwo.core import executing 21 | from onetwo.core import executing_with_context 22 | from onetwo.core import updating 23 | from typing_extensions import override 24 | 25 | 26 | def _collect_stream( 27 | executable: executing.Executable, 28 | depth: int = 1, 29 | ) -> list[str]: 30 | res = [] 31 | with executing.safe_stream(executable, iteration_depth=depth) as stream: 32 | updates = updating.Update[str]() 33 | for update in stream: 34 | updates += update 35 | result = updates.to_result() 36 | assert isinstance(result, str) or isinstance(result, list) 37 | res.append(result) 38 | return res 39 | 40 | 41 | @dataclasses.dataclass 42 | class ContextForTest: 43 | content: str = dataclasses.field(default_factory=str) 44 | 45 | 46 | @dataclasses.dataclass 47 | class ExecutableWithContextForTest( 48 | executing_with_context.ExecutableWithContext[ContextForTest, str] 49 | ): 50 | content: str = dataclasses.field(default_factory=str) 51 | 52 | @override 53 | def initialize_context(self, *args, **kwargs) -> ContextForTest: 54 | return ContextForTest(*args) 55 | 56 | @classmethod 57 | @override 58 | def wrap(cls, other: str) -> ExecutableWithContextForTest: 59 | return ExecutableWithContextForTest(content=other) 60 | 61 | @classmethod 62 | @override 63 | def get_result(cls, context: ContextForTest) -> str: 64 | return context.content 65 | 66 | @override 67 | @executing.make_executable # pytype: disable=wrong-arg-types 68 | async def execute( 69 | self, 70 | context: ContextForTest, 71 | ) -> str: 72 | # Add the content of this node to the context. 73 | context.content += self.content 74 | # Return the content between quotes. 75 | return f'"{self.content}"' 76 | 77 | 78 | @dataclasses.dataclass 79 | class SerialContextForTest: 80 | content: list[str] = dataclasses.field(default_factory=list) 81 | 82 | 83 | @dataclasses.dataclass 84 | class SerialExecutableWithContextForTest( 85 | executing_with_context.SerialExecutableWithContext[ 86 | SerialContextForTest, list[str] 87 | ] 88 | ): 89 | 90 | @override 91 | @classmethod 92 | def empty_result(cls) -> list[str]: 93 | return [] 94 | 95 | @override 96 | def initialize_context(self, *args, **kwargs) -> SerialContextForTest: 97 | return SerialContextForTest(list(*args)) 98 | 99 | @classmethod 100 | @override 101 | def wrap(cls, other: str) -> SingleStepSerialExecutableWithContextForTest: 102 | return SingleStepSerialExecutableWithContextForTest(content=other) 103 | 104 | @classmethod 105 | @override 106 | def get_result(cls, context: SerialContextForTest) -> str: 107 | return ''.join(context.content) 108 | 109 | 110 | @dataclasses.dataclass 111 | class SingleStepSerialExecutableWithContextForTest( 112 | SerialExecutableWithContextForTest 113 | ): 114 | content: str = dataclasses.field(default_factory=str) 115 | 116 | @override 117 | @executing.make_executable # pytype: disable=wrong-arg-types 118 | def execute(self, context: ContextForTest) -> list[str]: 119 | # Add the content of this node to the context. 120 | context.content += self.content 121 | # Return the content between quotes. 122 | return [f'"{self.content}"'] 123 | 124 | @override 125 | @executing.make_executable # pytype: disable=wrong-arg-types 126 | def iterate( 127 | self, context: ContextForTest, iteration_depth: int = 1 128 | ) -> list[str]: 129 | del iteration_depth 130 | # Add the content of this node to the context. 131 | context.content += self.content 132 | # Return the content between quotes. 133 | return [f'"{self.content}"'] 134 | 135 | 136 | class ExecutingWithContextTest(parameterized.TestCase): 137 | 138 | def test_execute_simple(self): 139 | e = ExecutableWithContextForTest(content='hello') 140 | with self.subTest('run_directly'): 141 | res = executing.run(e) 142 | self.assertEqual(res, 'hello') 143 | 144 | with self.subTest('result_has_quotes'): 145 | res = executing.run(e.execute(ContextForTest())) 146 | self.assertEqual(res, '"hello"') 147 | 148 | with self.subTest('run_with_arguments'): 149 | res = executing.run(e('prefix ')) 150 | self.assertEqual(res, 'prefix hello') 151 | 152 | with self.subTest('run_with_resetting'): 153 | res = executing.run(e()) 154 | self.assertEqual(res, 'hello') 155 | 156 | with self.subTest('stream'): 157 | res = _collect_stream(e()) 158 | self.assertListEqual(res, ['hello']) 159 | 160 | with self.subTest('stream_iterate'): 161 | res = _collect_stream(e.iterate(ContextForTest())) 162 | self.assertListEqual(res, ['"hello"']) 163 | 164 | def test_execute_serial(self): 165 | e = SerialExecutableWithContextForTest() 166 | e += 'hello ' 167 | e += 'world' 168 | 169 | with self.subTest('run_directly'): 170 | res = executing.run(e) 171 | self.assertEqual(res, 'hello world') 172 | 173 | with self.subTest('result_is_a_list_with_quotes'): 174 | res = executing.run(e.execute(SerialContextForTest())) 175 | self.assertListEqual(res, ['"hello "', '"world"']) 176 | 177 | with self.subTest('stored_result_is_correct_after_run'): 178 | _ = executing.run(e) 179 | result_field = e._result 180 | assert isinstance(result_field, list) 181 | self.assertListEqual(result_field, ['"hello "', '"world"']) 182 | 183 | with self.subTest('stored_result_is_correct_after_stream'): 184 | _ = _collect_stream(e()) 185 | self.assertListEqual(e._result, ['"hello "', '"world"']) 186 | 187 | with self.subTest('run_with_arguments'): 188 | res = executing.run(e(['prefix '])) 189 | self.assertEqual(res, 'prefix hello world') 190 | 191 | with self.subTest('result_with_arguments'): 192 | res = executing.run(e.execute(SerialContextForTest(['prefix ']))) 193 | self.assertEqual(res, ['"hello "', '"world"']) 194 | 195 | with self.subTest('run_with_resetting'): 196 | res = executing.run(e()) 197 | self.assertEqual(res, 'hello world') 198 | 199 | with self.subTest('stream_iterate_returns_lists'): 200 | res = _collect_stream(e.iterate(SerialContextForTest())) 201 | self.assertListEqual(res, [['"hello "'], ['"hello "', '"world"']]) 202 | 203 | def test_right_addition(self): 204 | e = 'hello ' 205 | e += SerialExecutableWithContextForTest() 206 | e += 'world' 207 | 208 | with self.subTest('run_directly'): 209 | res = executing.run(e) 210 | self.assertEqual(res, 'hello world') 211 | 212 | with self.subTest('result_is_a_list_with_quotes'): 213 | res = executing.run(e.execute(SerialContextForTest())) 214 | self.assertListEqual(res, ['"hello "', '"world"']) 215 | 216 | with self.subTest('run_with_arguments'): 217 | res = executing.run(e(['prefix '])) 218 | self.assertEqual(res, 'prefix hello world') 219 | 220 | with self.subTest('run_with_resetting'): 221 | res = executing.run(e()) 222 | self.assertEqual(res, 'hello world') 223 | 224 | with self.subTest('stream_iterate'): 225 | res = _collect_stream(e.iterate(SerialContextForTest())) 226 | self.assertListEqual(res, [['"hello "'], ['"hello "', '"world"']]) 227 | 228 | 229 | if __name__ == '__main__': 230 | absltest.main() 231 | -------------------------------------------------------------------------------- /onetwo/builtins/llm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Technical details and utilities for llm builtins.""" 16 | 17 | from __future__ import annotations 18 | 19 | 20 | from collections.abc import Mapping, Sequence 21 | import enum 22 | from typing import cast, Any, TypeAlias 23 | 24 | from onetwo.core import content as content_lib 25 | 26 | _ChunkList: TypeAlias = content_lib.ChunkList 27 | 28 | 29 | class TokenHealingOption(enum.Enum): 30 | """Options for token healing. 31 | 32 | Token healing is a technique that can often significantly improve generation 33 | results (or rather help to avoid erroneous results). Consider prompts ending 34 | with a whitespace, e.g., "Theory of ". Depending on the model and the 35 | tokenizer, this string is likely to be encoded into a token sequence where the 36 | last element is a whitespace token (WT) corresponding to a single whitespace " 37 | ". The model will then generate next tokens, following WT. An issue with this 38 | scenario is that words are often tokenized together with the leading 39 | whitespace. Likely, there is " relativity" token (RT) and for the majority of 40 | sentenses where the word "relativity" (case-sensitive!) occured, it was 41 | encoded into RT. So the model may be very "familiar" with RT, but not with 42 | "relativity" token (no space). Unfortunately, a separate whitespace token 43 | already appears in the end of the sequence, so the model can't generate RT 44 | (indeed, two whitespaces in a row is very uncommon, so the model won't do it) 45 | and instead generates something else (that often looks odd). This phenomenon 46 | is more general and can occur with other tokens in the end of the prompt as 47 | well. For more details refer to https://github.com/guidance-ai/guidance. 48 | """ 49 | NONE = 'NONE' 50 | # In token healing we replace the last token of the prompt and send the prompt 51 | # to completion. In the ideal scenario where constrained decoding is available 52 | # we make sure that the string representation of the first token that we 53 | # generate contains the string representation of the removed token as a 54 | # prefix. This way we guarantee that the completion matches user's 55 | # expectations and we avoid token boundary artifacts at the same time. If the 56 | # constrained decoding is not available for the model we sample completions 57 | # in hope that the constraint will be satisfied and roll back to vanilla 58 | # completion in case of failure. 59 | TOKEN_HEALING = 'TOKEN_HEALING' 60 | # In space healing we remove trailing whitespaces from the end of the prompt 61 | # and send the prompt to completion. In case the prompt has trailing 62 | # whitespaces we also remove any leading whitespaces from the generated text. 63 | # This may be considered as a less general ad-hoc implementation for the 64 | # token healing. 65 | SPACE_HEALING = 'SPACE_HEALING' 66 | 67 | 68 | def space_heal_reply( 69 | reply: str | tuple[str, Mapping[str, Any]] 70 | ) -> str | tuple[str, Mapping[str, Any]]: 71 | """Possibly remove spaces from the beginning of the reply. 72 | 73 | Args: 74 | reply: Reply that we want to space heal. Reply is either a string or a tuple 75 | of a string and a dictionary of additional information. 76 | 77 | Returns: 78 | Reply with leading whitespaces removed. In case of a tuple, the additional 79 | information is preserved. 80 | """ 81 | if isinstance(reply, str): 82 | reply: str = reply.lstrip(' ') 83 | elif isinstance(reply, tuple): 84 | reply, details = reply 85 | reply = reply.lstrip(' ') 86 | reply = (reply, details) 87 | else: 88 | raise ValueError(f'Unexpected type of the reply:{type(reply)}') 89 | return reply 90 | 91 | 92 | def maybe_heal_prompt( 93 | *, 94 | original_prompt: str | _ChunkList, 95 | healing_option: TokenHealingOption, 96 | ) -> _ChunkList: 97 | """Maybe heal the prompt for further generation. 98 | 99 | Args: 100 | original_prompt: Prompt that we may want to heal. 101 | healing_option: Healing option that we want to use. 102 | 103 | Returns: 104 | In case we use space healing, we remove trailing whitespaces from the 105 | prompt. The token healing is not supported yet. Otherwise, we return the 106 | prompt as is. 107 | """ 108 | if isinstance(original_prompt, str): 109 | original_prompt = _ChunkList([original_prompt]) 110 | healed_prompt: _ChunkList = original_prompt 111 | if healing_option == TokenHealingOption.SPACE_HEALING: 112 | # Remove trailing whitespaces. 113 | healed_prompt = healed_prompt.rstrip(' ') 114 | elif healing_option == TokenHealingOption.TOKEN_HEALING: 115 | # TODO: Support token healing. 116 | raise NotImplementedError('Token healing is not supported yet.') 117 | return healed_prompt 118 | 119 | 120 | def maybe_heal_reply( 121 | *, 122 | reply_text: str, 123 | original_prompt: str | _ChunkList, 124 | healing_option: TokenHealingOption, 125 | ) -> str: 126 | """Maybe heal the reply. 127 | 128 | Args: 129 | reply_text: Reply text that we may want to heal. 130 | original_prompt: The prompt that was used to generate the reply before any 131 | type of healing was applied to it. 132 | healing_option: Healing option that was used to generate the reply. 133 | 134 | Returns: 135 | In case we used space healing and the original prompt ended with a 136 | whitespace, we return the reply with leading whitespaces removed. The token 137 | healing is not supported yet. Otherwise, we return the reply as is. 138 | """ 139 | if ( 140 | healing_option == TokenHealingOption.SPACE_HEALING 141 | and original_prompt.endswith(' ') 142 | ): 143 | # We need to remove spaces from the beginning of the replies. 144 | return cast(str, space_heal_reply(reply_text)) 145 | elif healing_option == TokenHealingOption.TOKEN_HEALING: 146 | # TODO: Support token healing. 147 | raise NotImplementedError('Token healing is not supported yet.') 148 | return reply_text 149 | 150 | 151 | def maybe_heal_prompt_and_targets( 152 | *, 153 | original_prompt: str | _ChunkList, 154 | original_targets: Sequence[str], 155 | healing_option: TokenHealingOption, 156 | ) -> tuple[_ChunkList, Sequence[str]]: 157 | """Maybe heal the prompt and targets for further scoring. 158 | 159 | Args: 160 | original_prompt: Prompt that we may want to heal. 161 | original_targets: Targets that we may want to heal. 162 | healing_option: Healing option that we want to use. 163 | 164 | Returns: 165 | In case we use space healing, we remove trailing whitespaces from the 166 | prompt and add leading whitespaces to the targets when necessary. The token 167 | healing is not supported yet. Otherwise, we return the prompt and targets as 168 | is. 169 | """ 170 | if isinstance(original_prompt, str): 171 | original_prompt = _ChunkList([original_prompt]) 172 | healed_prompt: _ChunkList = original_prompt 173 | healed_targets = list(original_targets) 174 | if healing_option == TokenHealingOption.SPACE_HEALING: 175 | if healed_prompt.endswith(' '): 176 | # Remove trailing whitespaces. 177 | healed_prompt = healed_prompt.rstrip(' ') 178 | # Add leading whitespaces to the targets if necessary. 179 | for i, target in enumerate(healed_targets): 180 | if not target.startswith(' '): 181 | healed_targets[i] = ' ' + target 182 | elif healing_option == TokenHealingOption.TOKEN_HEALING: 183 | # TODO: Support token healing. 184 | raise NotImplementedError('Token healing is not supported yet.') 185 | return healed_prompt, healed_targets 186 | -------------------------------------------------------------------------------- /docs/basics.md: -------------------------------------------------------------------------------- 1 | # OneTwo Basics 2 | 3 | One of the key principles behind the OneTwo library is to enable the creation 4 | of complex flows involving several calls to foundation models and possibly other 5 | tools. 6 | For ease of experimentation, it is important to easily change the backends or 7 | their configuration and run the same flow on two backends/configurations, e.g. 8 | when doing comparisons. 9 | 10 | The bottleneck is often the multiple RPC requests that need to happen. This 11 | makes fast iterations or experimenting on many examples slow and tedious. In 12 | order to reduce this bottleneck, there are two strategies that are implemented 13 | in the OneTwo library: 14 | 15 | 1. **Caching**: The result of the calls to the models are cached, which enables 16 | one to very quickly replay a flow or an experiment which may have partially 17 | executed (e.g. failed in the middle of execution). For example, if you have a 18 | complex flow and want to add just one extra step, rerunning the whole thing 19 | amounts to reading everything from cache and only executing for real that one 20 | last step. 21 | 1. **Asynchronous Execution**: While some of the model calls might need to be 22 | chained serially, there are many situations when you may want to execute some 23 | calls in parallel (e.g. talking to different backends, running an experiment on 24 | many examples, or having a step in your flow where several independent tasks are 25 | performed). A natural way to do that is to use asynchronous programming, or 26 | multi-threading. 27 | 28 | ## Builtins 29 | 30 | In order to use a uniform language, we define a number of "built-in" functions 31 | representing the basic operations one may want to perform using a model. 32 | 33 | - `llm.generate_text()` - Generate raw text. 34 | - `llm.generate_object()` - Generate and parse text into a Python object. 35 | - `llm.select()` - Choose among alternatives. 36 | - `llm.instruct()` - Generate answer to instructions. 37 | - `llm.chat()` - Generate text in a multi-turn dialogue. 38 | 39 | We also decouple these functions from their implementations. 40 | So you can use them to define your **prompting strategy**, without specifying 41 | which **model** or which **model parameters** you want to use, and only specify 42 | those later. 43 | 44 | ## Executables 45 | 46 | Many of the basic functions provided by the library actually return what we call 47 | *Executables*. For example: 48 | 49 | ```python 50 | from onetwo.builtins import llm 51 | 52 | e = llm.generate_text( 53 | 'Q: What are three not so well known cities in france?\nA:', 54 | stop=['Q:'], 55 | max_tokens=20, 56 | ) 57 | # > Troyes, Dijon, Annecy. 58 | ``` 59 | 60 | Now this `e` variable has type `Executable[str]` and needs to be *executed* to 61 | produce the final result. This happens by calling `ot.run()`: 62 | 63 | ```python 64 | from onetwo import ot 65 | 66 | result = ot.run(e) 67 | ``` 68 | The benefit of this two-step process is that one can define possibly complex 69 | execution flows in a natural pythonic way, and decouple the definition of the 70 | flow from the actual backends that are used to execute it. 71 | 72 | ## Function Registry 73 | 74 | Specifying which backend is used to actually perform a built-in function like 75 | `llm.generate_text` is done when calling the `register()` method on a backend. 76 | This method registers the various function calls that the backend supports into 77 | a global function registry. 78 | 79 | You can temporarily override this registry if you want the calls to 80 | `llm.generate_text` to be routed elsewhere. 81 | 82 | For example, 83 | 84 | ```python 85 | backend = .... 86 | backend.register() 87 | 88 | def fake_generate_text(prompt: str | ChunkList): 89 | return prompt 90 | 91 | with ot.RegistryContext(): 92 | llm.generate_text.configure(fake_generate_text) 93 | print(ot.run(llm.generate_text('test'))) 94 | ``` 95 | 96 | As another example, assume you have two different backends, then it is possible 97 | to create two distinct registries and pick one of those at execution time: 98 | 99 | ```python 100 | backend1 = ... 101 | backend2 = ... 102 | 103 | with ot.RegistryContext(): 104 | backend1.register() 105 | registry1 = ot.copy_registry() 106 | 107 | with ot.RegistryContext(): 108 | backend2.register() 109 | registry2 = ot.copy_registry() 110 | ``` 111 | 112 | ```python 113 | ot.run(ot.with_registry(e, registry1)) 114 | ot.run(ot.with_registry(e, registry2)) 115 | ``` 116 | 117 | ## Asynchronous execution 118 | 119 | While it may take a bit of time to get used to `asyncio` if you never used it, 120 | we tried to make it as simple as possible. 121 | 122 | So if you need to perform two sequential calls to an LLM, you can of course run 123 | one after the other: 124 | 125 | ```python 126 | result = ot.run(llm.generate_text('Q: What is the southernmost city in France? A:')) 127 | result2 = ot.run(llm.generate_text(f'Q: Who is the mayor of {result}? A:')) 128 | ``` 129 | 130 | But a better way is to create a single Executable by combining them into a 131 | function decorated with `@ot.make_executable`: 132 | 133 | ```python 134 | @ot.make_executable 135 | async def f(*any_arguments): 136 | del any_arguments # This example does not use arguments. 137 | result = await llm.generate_text('Q: What is the southernmost city in France? A:') 138 | result2 = await llm.generate_text(f'Q: Who is the mayor of {result}? A:') 139 | return result2 140 | 141 | result = ot.run(f()) 142 | ``` 143 | 144 | Indeed, the `ot.run()` function will actually block on execution and will 145 | only return when the LLM has produced the output, while when an async function 146 | is created with multiple await calls in its body, the execution of this function 147 | can be interleaved with the execution of other async functions. This will be 148 | beneficial when creating complex workflows with multiple calls as they can be 149 | scheduled automatically in an optimal way. For example, if we were to repeatedly 150 | call the `f` function on different inputs the inner generate_text calls could be 151 | interleaved (see next section). 152 | 153 | Functions decorated with `@ot.make_executable` return `Executable` objects 154 | when called. I.e., `f()` is of type `Executable` and can be executed with 155 | `ot.run()`. 156 | 157 | ## Combining executables 158 | 159 | We provide in particular a way to combine executables in parallel: 160 | 161 | ```python 162 | e1 = llm.generate_text('Q: What is the southernmost city in France? A:') 163 | e2 = llm.generate_text('Q: What is the southernmost city in Spain? A:') 164 | e = ot.parallel(e1, e2) 165 | results = ot.run(e) 166 | ``` 167 | 168 | The creation of the `Executable` `e` as a parallel composition of two 169 | executables indicates that one does not depend on the output of the other and 170 | the calls to the LLM can thus be performed in parallel, assuming that the 171 | backend supports it. Typically a backend will be a proxy to a remote server that 172 | may support multiple simultaneous calls from different threads, or that may 173 | support sending requests in batches. In this case, the execution will 174 | automatically take advantage of this functionality to speed things up and not 175 | having to wait for the first `generate_text` call to return before performing 176 | the second one. 177 | 178 | ## Composing multi-step prompts 179 | 180 | While using `async`, `await`, and `ot.parallel` lets you create arbitrary 181 | complex flows using standard Python, there are cases where one might want a 182 | simpler way to specify basic or typical combinations. We thus also provide ways 183 | of composing prompts that can execute in multiple steps (i.e. involving multiple 184 | calls to a model). 185 | 186 | We support two different syntaxes for that: 187 | 188 | - Prompt templates in jinja2 templating language; 189 | - Composition via the `+` operator. 190 | 191 | ```python 192 | from onetwo.builtins import composables as c 193 | 194 | template = c.j("""\ 195 | What is the southernmost city in France? {{ generate_text() }} 196 | Who is its mayor? {{ generate_text() }} 197 | """) 198 | result = ot.run(template) 199 | ``` 200 | 201 | ```python 202 | e = 'What is the southernmost city in France?' + c.generate_text() + \ 203 | 'Who is its mayor?' + c.generate_text() 204 | result = ot.run(e) 205 | ``` -------------------------------------------------------------------------------- /onetwo/core/executing_impl_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from onetwo.agents import agents_test_utils 20 | from onetwo.core import executing 21 | from onetwo.core import executing_impl 22 | from onetwo.core import tracing 23 | 24 | 25 | def ordinary_function(x, y): 26 | return x + y 27 | 28 | 29 | async def async_function(x, y): 30 | return x + y 31 | 32 | 33 | @executing.make_executable # pytype: disable=wrong-arg-types 34 | def executable_function(x, y): 35 | return x + y 36 | 37 | 38 | @executing.make_executable # pytype: disable=wrong-arg-types 39 | async def async_executable_function(x, y): 40 | return x + y 41 | 42 | 43 | @tracing.trace # pytype: disable=wrong-arg-types 44 | def traced_ordinary_function(x, y): 45 | return x + y 46 | 47 | 48 | @tracing.trace # pytype: disable=wrong-arg-types 49 | async def traced_async_function(x, y): 50 | return x + y 51 | 52 | 53 | @executing.make_executable 54 | @tracing.trace # pytype: disable=wrong-arg-types 55 | def executable_traced_function(x, y): 56 | return x + y 57 | 58 | 59 | @tracing.trace 60 | @executing.make_executable # pytype: disable=wrong-arg-types 61 | def traced_executable_function(x, y): 62 | return x + y 63 | 64 | 65 | def executable_function_that_does_not_use_make_executable_decorator(x, y): 66 | return executing.serial(executable_function(x, y), executable_function(x, y)) 67 | 68 | 69 | class C: 70 | def ordinary_method(self, x, y): 71 | return x + y 72 | 73 | async def async_method(self, x, y): 74 | return x + y 75 | 76 | @executing.make_executable # pytype: disable=wrong-arg-types 77 | def executable_method(self, x, y): 78 | return x + y 79 | 80 | @executing.make_executable # pytype: disable=wrong-arg-types 81 | async def async_executable_method(self, x, y): 82 | return x + y 83 | 84 | 85 | class ExecutingImplTest(parameterized.TestCase): 86 | 87 | def test_set_decorated_with_make_executable(self): 88 | def f(x): 89 | return x 90 | 91 | with self.subTest('false_before_setting'): 92 | self.assertFalse(executing_impl.is_decorated_with_make_executable(f)) 93 | 94 | executing_impl.set_decorated_with_make_executable(f) 95 | 96 | with self.subTest('true_after_setting'): 97 | self.assertTrue(executing_impl.is_decorated_with_make_executable(f)) 98 | 99 | @parameterized.named_parameters( 100 | ('ordinary_function', ordinary_function, False), 101 | ('async_function', async_function, False), 102 | ('executable_function', executable_function, True), 103 | ('async_executable_function', async_executable_function, True), 104 | ('traced_ordinary_function', traced_ordinary_function, False), 105 | ('traced_async_function', traced_async_function, False), 106 | ('executable_traced_function', executable_traced_function, True), 107 | ('traced_executable_function', traced_executable_function, True), 108 | ('agent', agents_test_utils.StringAgent(), True), 109 | ('ordinary_method', C().ordinary_method, False), 110 | ('async_method', C().async_method, False), 111 | ('executable_method', C().executable_method, True), 112 | ('async_executable_method', C().async_executable_method, True), 113 | ( 114 | 'executable_function_that_does_not_use_make_executable_decorator', 115 | executable_function_that_does_not_use_make_executable_decorator, 116 | False, 117 | ), 118 | ) 119 | def test_is_decorated_with_make_executable(self, function, expected_result): 120 | self.assertEqual( 121 | expected_result, 122 | executing_impl.is_decorated_with_make_executable(function), 123 | ) 124 | 125 | @parameterized.named_parameters( 126 | ('ordinary_function', ordinary_function, False), 127 | ('async_function', async_function, True), 128 | ('executable_function', executable_function, True), 129 | ('async_executable_function', async_executable_function, True), 130 | ('traced_ordinary_function', traced_ordinary_function, False), 131 | ('traced_async_function', traced_async_function, True), 132 | ('executable_traced_function', executable_traced_function, True), 133 | ('traced_executable_function', traced_executable_function, True), 134 | ('agent', agents_test_utils.StringAgent(), True), 135 | ('ordinary_method', C().ordinary_method, False), 136 | ('async_method', C().async_method, True), 137 | ('executable_method', C().executable_method, True), 138 | ('async_executable_method', C().async_executable_method, True), 139 | # TODO: The following case should ideally return True but does 140 | # not currently, as it is difficult to predict that the function returns 141 | # an Executable in this case without actually calling it. 142 | ( 143 | 'executable_function_that_does_not_use_make_executable_decorator', 144 | executable_function_that_does_not_use_make_executable_decorator, 145 | False, 146 | ), 147 | ) 148 | def test_returns_awaitable(self, function, expected_result): 149 | self.assertEqual( 150 | expected_result, 151 | executing_impl.returns_awaitable(function), 152 | ) 153 | 154 | @parameterized.named_parameters( 155 | ('ordinary_function', ordinary_function), 156 | ('async_function', async_function), 157 | ('executable_function', executable_function), 158 | ('async_executable_function', async_executable_function), 159 | ('traced_ordinary_function', traced_ordinary_function), 160 | ('traced_async_function', traced_async_function), 161 | ('executable_traced_function', executable_traced_function), 162 | ('traced_executable_function', traced_executable_function), 163 | ('ordinary_method', C().ordinary_method), 164 | ('async_method', C().async_method), 165 | ('executable_method', C().executable_method), 166 | ('async_executable_method', C().async_executable_method), 167 | ) 168 | def test_call_and_maybe_await(self, function): 169 | result = executing.run( 170 | executing_impl.call_and_maybe_await(function, 'x', 'y') 171 | ) 172 | with self.subTest('positional_args'): 173 | self.assertEqual('xy', result) 174 | 175 | result = executing.run( 176 | executing_impl.call_and_maybe_await(function, x='x', y='y') 177 | ) 178 | with self.subTest('keyword_args'): 179 | self.assertEqual('xy', result) 180 | 181 | partial_function = functools.partial(function, x='x') 182 | result = executing.run( 183 | executing_impl.call_and_maybe_await(partial_function, y='y') 184 | ) 185 | with self.subTest('partial_function'): 186 | self.assertEqual('xy', result) 187 | 188 | def test_call_and_maybe_await_agent(self): 189 | function = agents_test_utils.StringAgent(sequence=['x', 'y']) 190 | result = executing.run( 191 | executing_impl.call_and_maybe_await(function, 'some_input') 192 | ) 193 | self.assertEqual('x y', result) 194 | 195 | def test_call_and_maybe_await_executable_function_that_does_not_use_make_executable_decorator( 196 | self, 197 | ): 198 | # Note that this case is handled correctly, even though `returns_awaitable` 199 | # does not return the correct value for this function. 200 | function = executable_function_that_does_not_use_make_executable_decorator 201 | result = executing.run( 202 | executing_impl.call_and_maybe_await(function, 'x', 'y') 203 | ) 204 | self.assertEqual(['xy', 'xy'], result) 205 | 206 | 207 | if __name__ == '__main__': 208 | absltest.main() 209 | -------------------------------------------------------------------------------- /onetwo/stdlib/code_execution/python_execution_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper methods, to be used by implementations.""" 16 | 17 | from __future__ import annotations 18 | 19 | import ast 20 | import datetime 21 | import importlib 22 | import json 23 | import logging 24 | import textwrap 25 | from typing import Any, Callable, TypeAlias 26 | 27 | from onetwo.stdlib.code_execution import python_execution 28 | 29 | # Aliases 30 | _ExecutionStatus: TypeAlias = python_execution.ExecutionStatus 31 | _SandboxStatus: TypeAlias = python_execution.SandboxStatus 32 | _SandboxResult: TypeAlias = python_execution.SandboxResult 33 | _SandboxResultTiming: TypeAlias = python_execution.SandboxResultTiming 34 | 35 | # Constants for hook communication 36 | DATA_KEY: str = 'data' 37 | EXCEPTION_KEY: str = 'exception' 38 | EX_ARGS_KEY: str = 'args' 39 | EX_CLASS_KEY: str = 'exception_class' 40 | EX_HOOK_NAME_KEY: str = 'hook_name' 41 | 42 | 43 | def current_timing( 44 | *, start: datetime.datetime, base: datetime.datetime 45 | ) -> _SandboxResultTiming: 46 | """Returns a timing object based on the current time and given start/base. 47 | 48 | Args: 49 | start: Time the execution request was sent to the sandbox. 50 | base: Time of last interaction with the sandbox (i.e., the later of the 51 | start time, and the time of the last callback to a hook function). 52 | """ 53 | now = datetime.datetime.now() 54 | return _SandboxResultTiming( 55 | since_start=now - start, since_last_interaction=now - base 56 | ) 57 | 58 | 59 | def adjust_code_to_set_final_expression_value( 60 | code: str, 61 | variable_name: str = 'result', 62 | default_value: Any = None, 63 | ) -> str: 64 | r"""Returns adjusted code that store the final expression value in a variable. 65 | 66 | Ignores trailing comment lines when determining the "final expression value". 67 | 68 | If the provided code is not valid Python code, then will return the original 69 | code unchanged. (Ideally, the caller should perform validation of the code 70 | prior to calling this function.) 71 | 72 | Examples (assuming variable_name = 'result'): 73 | * 'x = 2\nx + 3' => 'x = 2\nresult = x + 3'. 74 | * 'x = 2\ndel x' => 'x = 2\ndel x\nresult = None'. 75 | * 'x = 2\n# Comment' => 'result = x = 2\n# Comment'. 76 | 77 | Args: 78 | code: String containing the original unadjusted Python code. 79 | variable_name: The variable name into which to store the final expression 80 | value. 81 | default_value: Default value to assign to `variable_name` in the case where 82 | the actual final expression value was undefined. 83 | """ 84 | dedented_code = textwrap.dedent(code.rstrip()) 85 | lines = dedented_code.split('\n') 86 | 87 | # Special case when `code` is empty or consists of only whitespace. 88 | if len(lines) == 1 and not lines[0].strip(): 89 | lines = [] 90 | 91 | try: 92 | parse_tree = ast.parse(dedented_code) 93 | except Exception: # pylint: disable=broad-exception-caught 94 | logging.warning( 95 | 'Failed to parse Python code:\n```\n%s\n```', 96 | dedented_code, 97 | exc_info=True, 98 | ) 99 | return code 100 | 101 | result_idx = None 102 | if parse_tree and parse_tree.body: 103 | # ast.AST.lineno is 1-indexed. We subtract 1 to make it 0-based. 104 | last_statement = parse_tree.body[-1] 105 | # The only Python statements that are compatible with variable assignment 106 | # (as far as we are aware...) are expressions and assignments, e.g.: 107 | # * Expression: `2 + 3` ==> `result = 2 + 3` 108 | # * Assignment: `y = x + 2` ==> `result = y = x + 2` 109 | # Other Python statements that are incompatible with variable assignment 110 | # include compound statements (`if`, `while`, `for`, etc.) and various 111 | # simple non-expression / non-assignment statements (e.g., `assert`, `del`, 112 | # `import`, `raise`, etc.). 113 | # For background, see: https://docs.python.org/3/reference/index.html 114 | if isinstance(last_statement, ast.Assign) or isinstance( 115 | last_statement, ast.Expr 116 | ): 117 | result_idx = last_statement.lineno - 1 118 | 119 | if result_idx is None: 120 | # In this case, since the last statement does not return a value that can be 121 | # assigned to the result variable, we just set the result equal to None. 122 | lines.append(f'{variable_name} = {default_value}') 123 | else: 124 | # If the last statement does return a value, then we can simply set the 125 | # result variable equal to that. 126 | lines[result_idx] = f'{variable_name} = ' + lines[result_idx] 127 | 128 | return '\n'.join(lines) 129 | 130 | 131 | def parse_sandbox_result_json( 132 | result_str: str, 133 | start_time: datetime.datetime, 134 | base_time: datetime.datetime, 135 | ) -> _SandboxResult | None: 136 | """Tries to parse a JSON string into a SandboxResult object. 137 | 138 | Args: 139 | result_str: The string output from the sandbox, expected to be JSON. 140 | start_time: The time execution started. 141 | base_time: The time of the last interaction. 142 | 143 | Returns: 144 | A _SandboxResult if result_str is valid JSON and has the expected 145 | structure (contains at least 'stdout'), otherwise None. 146 | """ 147 | timing = current_timing(start=start_time, base=base_time) 148 | 149 | try: 150 | # Note that this uses the JSON decoder with potential 151 | # implementation-specific customizations. 152 | result_dict = json.loads(result_str) 153 | except json.JSONDecodeError: 154 | logging.warning('Failed to parse result string:\n```\n%s\n```', result_str) 155 | return None 156 | 157 | if isinstance(result_dict, dict) and 'stdout' in result_dict: 158 | # Received a full SandboxResult in dict form, as expected. 159 | return _SandboxResult( 160 | final_expression_value=result_dict.get('final_expression_value'), 161 | stdout=result_dict.get('stdout', ''), 162 | sandbox_status=_SandboxStatus( 163 | result_dict.get('sandbox_status', 'AFTER_RUNNING_CODE') 164 | ), 165 | execution_status=_ExecutionStatus( 166 | result_dict.get('execution_status', 'SUCCESS') 167 | ), 168 | status_message=result_dict.get('status_message', ''), 169 | failure_details=python_execution.parse_failure_details(result_dict), 170 | timing=timing, 171 | ) 172 | else: 173 | # JSON was valid, but not the expected SandboxResult structure. 174 | return None 175 | 176 | 177 | def create_sandbox_hook_callable( 178 | hook_name: str, conn: Any 179 | ) -> Callable[..., Any]: 180 | """Creates a callable to be used inside a sandbox for a specific hook. 181 | 182 | This function is intended to be run *inside* the sandboxed process. 183 | Args: 184 | hook_name: The name of the hook. 185 | conn: The connection object to the main process. 186 | 187 | Returns: 188 | A wrapper function that communicates with the main process. 189 | """ 190 | 191 | def hook_wrapper(*args, **kwargs): 192 | message = {'hook': hook_name, 'args': args, 'kwargs': kwargs} 193 | conn.send(json.dumps(message)) 194 | result_str = conn.recv() 195 | result = json.loads(result_str) 196 | 197 | ex = result.get(EXCEPTION_KEY) 198 | if ex: 199 | error_class_name = ex.get(EX_CLASS_KEY, 'RuntimeError') 200 | error_args = ex.get(EX_ARGS_KEY, ()) 201 | try: 202 | error_module = importlib.import_module('builtins') 203 | error_class = getattr(error_module, error_class_name) 204 | except AttributeError: 205 | error_class = RuntimeError 206 | error = error_class(*error_args) 207 | setattr(error, '_hook_name', ex.get(EX_HOOK_NAME_KEY)) 208 | raise error 209 | return result.get(DATA_KEY) 210 | 211 | return hook_wrapper 212 | -------------------------------------------------------------------------------- /onetwo/core/executing_impl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Core content from `executing.py` factored out to avoid circular dependencies. 16 | 17 | The intention is to keep the content in this file to the minimum needed to avoid 18 | circular dependencies between `executing.py` and other low-level libraries like 19 | `utils.py` or `batching.py`. In particular, we should minimize the dependencies 20 | from this file to other OneTwo libraries. 21 | 22 | The existence of this file should be treated as an implementation detail from 23 | the perspective of general users of OneTwo. So rather than importing this file 24 | directly, most code (both inside and outside of the OneTwo core) should instead 25 | import `executing.py` and/or `utils.py`, which contain public aliases to the 26 | classes and functions defined here. 27 | """ 28 | 29 | from __future__ import annotations 30 | 31 | import abc 32 | from collections.abc import AsyncIterator, Awaitable, Callable, Generator 33 | import dataclasses 34 | import functools 35 | import inspect 36 | from typing import Any, final, Generic, ParamSpec, TypeVar 37 | 38 | from onetwo.core import updating 39 | 40 | 41 | # Basic type variables that need to be specified when using this library. 42 | _Result = TypeVar('_Result') 43 | _Args = ParamSpec('_Args') 44 | 45 | _Update = updating.Update 46 | 47 | 48 | class Executable( 49 | Generic[_Result], 50 | Awaitable[_Result], 51 | AsyncIterator[_Update[_Result]], 52 | metaclass=abc.ABCMeta, 53 | ): 54 | """Interface for a process that can be executed step by step. 55 | 56 | Executable supports both `await` and `async for` statements. If the underlying 57 | implementation supports only `await`, we wrap it into the async iterator that 58 | yields only that one item. Similarly, if the underlying implementation is an 59 | async iterator and only supports `async for`, we implement `await` by manually 60 | iterating through all of the values and returning the final one. 61 | 62 | When awaited, executable returns a value of type Result. However, when using 63 | `async for`, executable produces instances of class (or subclasses of) 64 | `updating.Update`. Class Update helps maintaining (`accumulate`) intermediate 65 | resuts and obtaining the final result (`to_result`) based on accumulated 66 | information. 67 | """ 68 | 69 | @abc.abstractmethod 70 | async def _aexec(self) -> _Result: 71 | """Implementation as an Awaitable.""" 72 | 73 | @abc.abstractmethod 74 | async def _aiterate( 75 | self, iteration_depth: int = 1 76 | ) -> AsyncIterator[_Update[_Result]]: 77 | """Implementation as an AsyncIterator with configurable depth.""" 78 | yield _Update() # For correct typing 79 | 80 | @final 81 | def with_depth(self, iteration_depth: int) -> AsyncIterator[_Update[_Result]]: 82 | return self._aiterate(iteration_depth) 83 | 84 | @final 85 | def __await__(self) -> Generator[Any, Any, _Result]: 86 | """Method that is called when using `await executable`.""" 87 | result = yield from self._aexec().__await__() 88 | return result 89 | 90 | @final 91 | def __aiter__(self) -> AsyncIterator[_Update[_Result]]: 92 | """Method that is called when using `async for executable`.""" 93 | return self._aiterate().__aiter__() 94 | 95 | @final 96 | async def __anext__(self) -> _Update[_Result]: 97 | """Method that is called when using `async for executable`.""" 98 | return await self._aiterate().__anext__() 99 | 100 | 101 | @dataclasses.dataclass 102 | class ExecutableWithPostprocessing( 103 | Generic[_Result], Executable[_Result] 104 | ): 105 | """An executable with a callback executed at the end of the processing. 106 | 107 | One can define two callbacks, one for when the executable is executed with 108 | `await` and one when the executable is iterated through with `async for`. 109 | The latter one is optional. 110 | 111 | Attributes: 112 | wrapped: Executable to augment with postprocessing. 113 | postprocessing_callback: Callback to call at the end of the processing (in 114 | case we call the executable with `await`). 115 | update_callback: Callback to call after each update from the 116 | wrapped Executable (in case we call the executable with `async for`). If 117 | this is None, we will use the postprocessing_callback on 118 | `update.to_result()` (hence converting the updates into results, calling 119 | the callback and converting this back into an Update to be yielded). 120 | """ 121 | 122 | wrapped: Executable[_Result] 123 | postprocessing_callback: Callable[[_Result], _Result] 124 | update_callback: Callable[[_Update[_Result]], _Update[_Result]] | None = None 125 | 126 | @final 127 | async def _aiterate( 128 | self, iteration_depth: int = 1 129 | ) -> AsyncIterator[_Update[_Result]]: 130 | """Yields the intermediate values and calls the final_value_callback.""" 131 | updates = _Update() 132 | async for update in self.wrapped.with_depth(iteration_depth): 133 | updates += update 134 | if self.update_callback is not None: 135 | yield self.update_callback(update) 136 | else: 137 | yield _Update(self.postprocessing_callback(updates.to_result())) 138 | 139 | @final 140 | async def _aexec(self) -> _Result: 141 | """Iterate this value until done (including calling final_value_callback). 142 | 143 | Returns: 144 | The final value given by the AsyncIterator _inner(). 145 | """ 146 | result = await self.wrapped 147 | return self.postprocessing_callback(result) 148 | 149 | 150 | def set_decorated_with_make_executable(f: Callable[..., Any]) -> None: 151 | """Marks the callable as being decorated with @executing.make_executable. 152 | 153 | Args: 154 | f: An arbitrary callable. 155 | """ 156 | f.decorated_with_make_executable = True 157 | 158 | 159 | def is_decorated_with_make_executable(f: Callable[..., Any]) -> bool: 160 | """Returns whether the callable is decorated with @executing.make_executable. 161 | 162 | Args: 163 | f: An arbitrary callable. 164 | """ 165 | # Special handling for partial functions. 166 | while isinstance(f, functools.partial): 167 | f = f.func 168 | 169 | # Special handling for callable objects (e.g., sub-classes of Agent). 170 | if ( 171 | hasattr(f, '__call__') 172 | and not inspect.isfunction(f) 173 | and not inspect.ismethod(f) 174 | ): 175 | f = f.__call__ 176 | 177 | return getattr(f, 'decorated_with_make_executable', False) 178 | 179 | 180 | def returns_awaitable(f: Callable[..., Any]) -> bool: 181 | """Returns whether the callable returns something that is awaitable. 182 | 183 | Note that while using this function is more reliable than calling 184 | `inspect.iscoroutinefunction` directly, there are still some cases that it 185 | does not catch, such as when `f` is manually implemented to return an 186 | Executable, without using the `@executing.make_executable` decorator. To 187 | catch these cases, it is recommended to use `call_and_maybe_await` where 188 | possible, rather than depending on `returns_awaitable`. 189 | 190 | Args: 191 | f: An arbitrary callable. 192 | """ 193 | if ( 194 | hasattr(f, '__call__') 195 | and not inspect.isfunction(f) 196 | and not inspect.ismethod(f) 197 | and not isinstance(f, functools.partial) 198 | ): 199 | f = f.__call__ 200 | 201 | return inspect.iscoroutinefunction(f) or is_decorated_with_make_executable(f) 202 | 203 | 204 | async def call_and_maybe_await( 205 | f: Callable[_Args, _Result | Awaitable[_Result] | Executable[_Result]], 206 | *args: _Args.args, 207 | **kwargs: _Args.kwargs, 208 | ) -> _Result: 209 | """"Calls the callable and awaits the result if appropriate.""" 210 | if returns_awaitable(f): 211 | result = await f(*args, **kwargs) 212 | else: 213 | result = f(*args, **kwargs) 214 | # The below condition is needed in case `f` was manually implemented to return 215 | # an Executable, without using the `@executing.make_executable` decorator. 216 | if isinstance(result, Executable): 217 | result = await result 218 | return result 219 | -------------------------------------------------------------------------------- /onetwo/agents/critics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Interfaces and utilities for scoring, ranking and selecting updates. 16 | 17 | These can be used in combination with agents to implement optimization 18 | algorithms. 19 | """ 20 | 21 | import abc 22 | from collections.abc import Sequence 23 | from typing import Generic, TypeAlias, TypeVar 24 | 25 | from onetwo.agents import agents_base 26 | from onetwo.core import executing 27 | 28 | 29 | # Type used to represent a state. 30 | _S = TypeVar('_S') 31 | 32 | # Type used to represent an incremental update of a state. 33 | _U = TypeVar('_U') 34 | 35 | 36 | class ScoringFunction(Generic[_S, _U], metaclass=abc.ABCMeta): 37 | """Interface for a scoring functions.""" 38 | 39 | @executing.make_executable(copy_self=False) 40 | @abc.abstractmethod 41 | async def __call__(self, state: _S, update: _U) -> float: 42 | """Returns an absolute score of the update given the current state. 43 | 44 | The score represents the quality of the pair (state, update), i.e. how good 45 | the state after update, which means how good is the result of doing 46 | `state + update` (since updates can be added to states via the `+` operator 47 | to form new states). 48 | Ideally it should be comparable across different (state, update) pairs, but 49 | there could be cases where the starting state is identical across different 50 | updates (e.g. we are comparing different updates to the same state) in which 51 | case it is okay to return a score that does not take the state into account, 52 | but this is not recommended. 53 | 54 | Args: 55 | state: The current state. 56 | update: An incremental update of the state. 57 | 58 | Returns: 59 | A float representing the score of the state after update. 60 | The score is not normalized and can take any (positive or negative) value. 61 | Higher means better. 62 | """ 63 | return 0.0 64 | 65 | 66 | class RankingFunction(Generic[_S, _U], metaclass=abc.ABCMeta): 67 | """Interface for a ranking function.""" 68 | 69 | @executing.make_executable(copy_self=False) 70 | @abc.abstractmethod 71 | async def __call__( 72 | self, states_and_updates: Sequence[tuple[_S, _U]] 73 | ) -> list[int]: 74 | """Ranks a list of states and updates and returns indices. 75 | 76 | The order is from best to worst. 77 | When considering a pair (state, update), the value of the pair corresponds 78 | to how good the state after update is, i.e. the value of `state + update`. 79 | The order should thus make sense even when comparing pairs with different 80 | starting states. 81 | 82 | Args: 83 | states_and_updates: A list of states and updates to be ranked. 84 | 85 | Returns: 86 | A list of indices corresponding to the ranking of the states and updates. 87 | The indices refer to the order in the input list. 88 | """ 89 | return [] 90 | 91 | 92 | class SelectingFunction(Generic[_S, _U], metaclass=abc.ABCMeta): 93 | """Interface for a selecting function.""" 94 | 95 | @executing.make_executable(copy_self=False) 96 | @abc.abstractmethod 97 | async def __call__( 98 | self, states_and_updates: Sequence[tuple[_S, _U]] 99 | ) -> int: 100 | """Returns the index of the best (state, update) pair. 101 | 102 | When considering a pair (state, update), the value of the pair corresponds 103 | to how good the state after update is, i.e. the value of `state + update`. 104 | The notion of "best" should thus make sense even when comparing pairs with 105 | different starting states. 106 | 107 | Args: 108 | states_and_updates: A list of states and updates to select from. 109 | 110 | Returns: 111 | The index of the selected pair in the input list. 112 | """ 113 | return 0 114 | 115 | 116 | def ranker_from_scorer( 117 | scorer: ScoringFunction[_S, _U] 118 | ) -> RankingFunction[_S, _U]: 119 | """Converts a ScoringFunction into a RankingFunction. 120 | 121 | The RankingFunction ranks the states and updates by decreasing score. 122 | 123 | Args: 124 | scorer: A ScoringFunction. 125 | 126 | Returns: 127 | A RankingFunction. 128 | """ 129 | @executing.make_executable # pytype: disable=wrong-arg-types 130 | async def ranker(states_and_updates: Sequence[tuple[_S, _U]]) -> list[int]: 131 | executables = [scorer(s[0], s[1]) for s in states_and_updates] # pytype: disable=wrong-arg-count 132 | scores = await executing.par_iter(executables) 133 | sorted_updates = sorted( 134 | enumerate(scores), key=lambda x: x[1], reverse=True 135 | ) 136 | return [update[0] for update in sorted_updates] 137 | 138 | return ranker 139 | 140 | 141 | async def _select_k_best( 142 | states_and_updates: Sequence[tuple[_S, _U]], 143 | critic: SelectingFunction[_S, _U], 144 | k: int, 145 | ) -> list[int]: 146 | """Selects the k best states and updates.""" 147 | current_list = list(states_and_updates) 148 | selected = [] 149 | while len(selected) < k and current_list: 150 | if len(current_list) == 1: 151 | # If there is only one element left, we can stop the loop. 152 | selected.append(0) 153 | break 154 | else: 155 | best_index = await critic(current_list) # pytype: disable=wrong-arg-count 156 | current_list.pop(best_index) 157 | selected.append(best_index) 158 | return selected 159 | 160 | 161 | def ranker_from_selector( 162 | selector: SelectingFunction[_S, _U] 163 | ) -> RankingFunction[_S, _U]: 164 | """Converts a SelectingFunction into a RankingFunction. 165 | 166 | The RankingFunction repeatedly calls the SelectingFunction to select the best 167 | state/update pair and removes it from the list. 168 | 169 | Args: 170 | selector: A SelectingFunction. 171 | 172 | Returns: 173 | A RankingFunction. 174 | """ 175 | 176 | @executing.make_executable # pytype: disable=wrong-arg-types 177 | async def ranker(states_and_updates: Sequence[tuple[_S, _U]]) -> list[int]: 178 | return await _select_k_best( 179 | states_and_updates, selector, len(states_and_updates) 180 | ) 181 | 182 | return ranker 183 | 184 | 185 | class ScoreFromUpdates( 186 | Generic[_S, _U], ScoringFunction[_S, agents_base.ScoredUpdate[_U]] 187 | ): 188 | """Scoring function that extracts the score directly from the update.""" 189 | 190 | @executing.make_executable # pytype: disable=wrong-arg-types 191 | async def __call__( 192 | self, state: _S, update: agents_base.ScoredUpdate[_U] 193 | ) -> float: 194 | return update.score 195 | 196 | 197 | # Type used to represent inputs. 198 | _I = TypeVar('_I') 199 | 200 | _ListOfScoredUpdates: TypeAlias = agents_base.UpdateListState[ 201 | _I, agents_base.ScoredUpdate[_U] 202 | ] 203 | _ScoredUpdate: TypeAlias = agents_base.ScoredUpdate[_U] 204 | 205 | 206 | class ScoreFromUpdateList( 207 | Generic[_I, _U], ScoringFunction[_ListOfScoredUpdates, _ScoredUpdate] 208 | ): 209 | """Scoring function that extracts the score from an update list.""" 210 | 211 | @executing.make_executable # pytype: disable=wrong-arg-types 212 | async def __call__( 213 | self, state: _ListOfScoredUpdates, update: _ScoredUpdate 214 | ) -> float: 215 | """Sums the scores of the updates in the list and the new update. 216 | 217 | This scoring function can be used when for example we have an agent 218 | that associates a score to each of the updates it performs. In this case 219 | the score of a state can be the sum of the scores of all the updates. 220 | For example if we have a sequential distribution as the agent, it may 221 | assign a probability to each update, which we can convert (by taking the 222 | log) into a score that can be added across the whole sequence. 223 | 224 | Args: 225 | state: The current state, i.e. a list of scored updates. 226 | update: The latest scored update. 227 | 228 | Returns: 229 | The sum of the scores of all the updates in the list and the new update. 230 | """ 231 | updated_state = state + update 232 | score = 0.0 233 | for s in updated_state.updates: 234 | score += s.score 235 | return score 236 | 237 | -------------------------------------------------------------------------------- /onetwo/backends/onetwo_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Google OneTwo API connector. 16 | 17 | A Backend that connects to a OneTwo model server and exposes its 18 | functionalities. 19 | """ 20 | 21 | import collections 22 | from collections.abc import Mapping, Sequence 23 | import dataclasses 24 | import json 25 | from typing import Any 26 | from onetwo.backends import backends_base 27 | from onetwo.builtins import llm 28 | from onetwo.core import batching 29 | from onetwo.core import caching 30 | from onetwo.core import content as content_lib 31 | from onetwo.core import tracing 32 | from onetwo.core import utils 33 | import requests 34 | 35 | 36 | 37 | 38 | 39 | 40 | @batching.add_batching # Methods of this class are batched. 41 | @dataclasses.dataclass 42 | class OneTwoAPI( 43 | caching.FileCacheEnabled, # Methods of this class are cached. 44 | backends_base.Backend, 45 | ): 46 | """Google OneTwo API. 47 | 48 | Attributes: 49 | disable_caching: Whether caching is enabled for this object (inherited from 50 | CacheEnabled). 51 | cache_filename: Name of the file (full path) where the cache is stored 52 | (inherited from FileCacheEnabled) 53 | endpoint: The address to connect to (typically some http endpoint). 54 | batch_size: Number of requests (generate_text or chat or generate_embedding) 55 | that is grouped together when sending them to OneTwo API. OneTwo API does 56 | not explicitly support batching (i.e. multiple requests can't be passed 57 | via arguments). Instead we send multiple requests from separate threads. 58 | enable_streaming: Whether to enable streaming replies from generate_text. 59 | max_qps: Maximum queries per second for the backend (if None, no rate 60 | limiting is applied). 61 | temperature: Temperature parameter (float) for LLM generation (can be set as 62 | a default and can be overridden per request). 63 | max_tokens: Maximum number of tokens to generate (can be set as a default 64 | and can be overridden per request). 65 | stop: Stop sequences (as a list of strings) for LLM text generation (can be 66 | set as a default and can be overridden per request). 67 | top_p: Top-p parameter (float) for LLM text generation (can be set as a 68 | default and can be overridden per request). 69 | top_k: Top-k parameter (int) for LLM text generation (can be set as a 70 | default and can be overridden per request). 71 | """ 72 | endpoint: str = dataclasses.field(init=True, default_factory=str) 73 | batch_size: int = 1 74 | enable_streaming: bool = False 75 | max_qps: float | None = None 76 | 77 | # Generation parameters 78 | temperature: float | None = None 79 | max_tokens: int | None = None 80 | stop: Sequence[str] | None = None 81 | top_p: float | None = None 82 | top_k: int | None = None 83 | 84 | _counters: collections.Counter[str] = dataclasses.field( 85 | init=False, default_factory=collections.Counter 86 | ) 87 | 88 | def register(self, name: str | None = None) -> None: 89 | """See parent class.""" 90 | del name 91 | # Reset all the defaults in case some other backend was already registered. 92 | # Indeed, we rely on certain builtins configured with OneTwo defaults. 93 | llm.reset_defaults() 94 | llm.generate_text.configure( 95 | self.generate_text, 96 | temperature=self.temperature, 97 | max_tokens=self.max_tokens, 98 | stop=self.stop, 99 | top_p=self.top_p, 100 | top_k=self.top_k, 101 | ) 102 | llm.count_tokens.configure(self.count_tokens) 103 | llm.tokenize.configure(self.tokenize) 104 | 105 | def __post_init__(self) -> None: 106 | # Create cache. 107 | self._cache_handler = caching.SimpleFunctionCache( 108 | cache_filename=self.cache_filename, 109 | ) 110 | # Check the health status of the endpoint. 111 | try: 112 | response = requests.get(self.endpoint + '/health') 113 | if response.status_code != requests.codes.ok: 114 | raise ValueError(f'OneTwoAPI endpoint unhealthy: {response.text}') 115 | except Exception as err: 116 | raise ValueError(f'OneTwoAPI connection failed: {err}') from err 117 | 118 | @tracing.trace(name='OneTwoAPI.generate_text') 119 | @caching.cache_method( # Cache this method. 120 | name='generate_text', 121 | is_sampled=True, # Two calls with same args may return different replies. 122 | cache_key_maker=lambda: caching.CacheKeyMaker(hashed=['prompt']), 123 | ) 124 | @batching.batch_method_with_threadpool( 125 | batch_size=utils.FromInstance('batch_size'), 126 | wrapper=batching.add_logging, 127 | ) 128 | def generate_text( 129 | self, 130 | prompt: str | content_lib.ChunkList, 131 | *, 132 | temperature: float | None = None, 133 | max_tokens: int | None = None, 134 | stop: Sequence[str] | None = None, 135 | top_k: int | None = None, 136 | top_p: float | None = None, 137 | include_details: bool = False, 138 | **kwargs, # Optional server-specific arguments. 139 | ) -> str | tuple[str, Mapping[str, Any]]: 140 | """See builtins.llm.generate_text.""" 141 | self._counters['generate_text'] += 1 142 | 143 | if isinstance(prompt, content_lib.ChunkList): 144 | prompt = str(prompt) 145 | 146 | args = { 147 | 'prompt': prompt, 148 | 'temperature': temperature, 149 | 'max_tokens': max_tokens, 150 | 'stop': stop, 151 | 'top_k': top_k, 152 | 'top_p': top_p, 153 | 'include_details': include_details, 154 | } 155 | args.update(kwargs) 156 | # TODO: Trace this external API call. 157 | response = requests.post( 158 | self.endpoint + '/generate_text', 159 | headers={'Content-Type': 'application/json'}, 160 | data=json.dumps(args), 161 | ) 162 | if response.status_code != requests.codes.ok: 163 | raise ValueError(f'OneTwoAPI /generate_text failed: {response.text}') 164 | response = json.loads(response.text) 165 | return (response if include_details else response[0]) 166 | 167 | @caching.cache_method( # Cache this method. 168 | name='tokenize', 169 | is_sampled=False, 170 | cache_key_maker=lambda: caching.CacheKeyMaker(hashed=['prompt']), 171 | ) 172 | @batching.batch_method_with_threadpool( 173 | batch_size=utils.FromInstance('batch_size'), 174 | wrapper=batching.add_logging, 175 | ) 176 | def tokenize( 177 | self, 178 | content: str | content_lib.ChunkList, 179 | ) -> list[int]: 180 | """See builtins.llm.tokenize.""" 181 | self._counters['tokenize'] += 1 182 | 183 | if isinstance(content, content_lib.ChunkList): 184 | content = str(content) 185 | 186 | # TODO: Trace this external API call. 187 | response = requests.post( 188 | self.endpoint + '/tokenize', 189 | json={ 190 | 'content': content, 191 | }, 192 | ) 193 | if response.status_code != requests.codes.ok: 194 | raise ValueError(f'OneTwoAPI /tokenize failed: {response.text}') 195 | response = json.loads(response.text) 196 | return response['result'] 197 | 198 | @caching.cache_method( # Cache this method. 199 | name='count_tokens', 200 | is_sampled=False, 201 | cache_key_maker=lambda: caching.CacheKeyMaker(hashed=['prompt']), 202 | ) 203 | @batching.batch_method_with_threadpool( 204 | batch_size=utils.FromInstance('batch_size'), 205 | wrapper=batching.add_logging, 206 | ) 207 | def count_tokens( 208 | self, 209 | content: str | content_lib.ChunkList, 210 | ) -> int: 211 | """See builtins.llm.tokenize.""" 212 | self._counters['count_tokens'] += 1 213 | 214 | if isinstance(content, content_lib.ChunkList): 215 | content = str(content) 216 | 217 | # TODO: Trace this external API call. 218 | response = requests.post( 219 | url=self.endpoint + '/count_tokens', 220 | json={'content': content}, 221 | ) 222 | if response.status_code != requests.codes.ok: 223 | raise ValueError( 224 | f'OneTwoAPI /count_tokens failed: {response.text}' 225 | ) 226 | response = json.loads(response.text) 227 | return response['result'] 228 | --------------------------------------------------------------------------------