├── .editorconfig ├── .gitattributes ├── .github └── workflows │ ├── docs.yml │ ├── lints.yml │ ├── publish.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── benches.yml ├── benches ├── README.md ├── __init__.py ├── evaluate.py ├── extractors.py ├── run.py └── shared.py ├── codecov.yml ├── cogitator ├── __init__.py ├── clustering.py ├── embedding.py ├── model │ ├── __init__.py │ ├── base.py │ ├── ollama.py │ ├── openai.py │ └── py.typed ├── py.typed ├── schemas.py ├── strategies │ ├── __init__.py │ ├── auto_cot.py │ ├── cdw_cot.py │ ├── graph_of_thoughts.py │ ├── least_to_most.py │ ├── py.typed │ ├── sc_cot.py │ └── tree_of_thoughts.py └── utils.py ├── docs ├── api │ ├── clustering.md │ ├── embedding.md │ ├── model │ │ ├── base.md │ │ ├── ollama.md │ │ └── openai.md │ ├── schemas.md │ ├── strategies │ │ ├── auto_cot.md │ │ ├── cdw_cot.md │ │ ├── graph_of_thoughts.md │ │ ├── least_to_most.md │ │ ├── sc_cot.md │ │ └── tree_of_thoughts.md │ └── utils.md ├── assets │ └── images │ │ ├── cogitator_v1.dot │ │ ├── cogitator_v1.svg │ │ ├── cogitator_v2.dot │ │ ├── cogitator_v2.svg │ │ └── make_figures.sh ├── benchmarking.md ├── contributing.md └── index.md ├── examples ├── README.md ├── __init__.py ├── run_auto_cot.py ├── run_cdw_cot.py ├── run_graph_of_thoughts.py ├── run_least_to_most.py ├── run_sc_cot.py ├── run_simple_example.py ├── run_tree_of_thoughts.py └── shared.py ├── logo.svg ├── mkdocs.yml ├── poetry.toml ├── pyproject.toml └── tests ├── conftest.py ├── extra_tests ├── test_benches.py └── test_examples.py ├── test_auto_cot.py ├── test_cdw_cot.py ├── test_graph_of_thoughts.py ├── test_least_to_most.py ├── test_model.py ├── test_sc_cot.py └── test_tree_of_thoughts.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: https://EditorConfig.org 2 | 3 | # Top-most EditorConfig file 4 | root = true 5 | 6 | # Global settings (applicable to all files unless overridden) 7 | [*] 8 | charset = utf-8 # Default character encoding 9 | end_of_line = lf # Use LF for line endings (Unix-style) 10 | indent_style = space # Use spaces for indentation 11 | indent_size = 4 # Default indentation size 12 | insert_final_newline = true # Make sure files end with a newline 13 | trim_trailing_whitespace = true # Remove trailing whitespace 14 | 15 | [*.py] 16 | max_line_length = 100 17 | 18 | # Markdown files 19 | [*.md] 20 | max_line_length = 130 21 | trim_trailing_whitespace = false 22 | 23 | # Bash scripts 24 | [*.sh] 25 | indent_size = 2 26 | 27 | # YAML files 28 | [*.{yml,yaml}] 29 | indent_size = 2 30 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Common document and text file formats 2 | *.docx filter=lfs diff=lfs merge=lfs -text 3 | *.doc filter=lfs diff=lfs merge=lfs -text 4 | *.pdf filter=lfs diff=lfs merge=lfs -text 5 | *.xls filter=lfs diff=lfs merge=lfs -text 6 | *.xlsx filter=lfs diff=lfs merge=lfs -text 7 | *.ppt filter=lfs diff=lfs merge=lfs -text 8 | *.pptx filter=lfs diff=lfs merge=lfs -text 9 | 10 | # Common image formats 11 | *.jpg filter=lfs diff=lfs merge=lfs -text 12 | *.jpeg filter=lfs diff=lfs merge=lfs -text 13 | *.png filter=lfs diff=lfs merge=lfs -text 14 | *.gif filter=lfs diff=lfs merge=lfs -text 15 | *.bmp filter=lfs diff=lfs merge=lfs -text 16 | *.tiff filter=lfs diff=lfs 
merge=lfs -text 17 | *.tif filter=lfs diff=lfs merge=lfs -text 18 | 19 | # Common compressed file formats 20 | *.zip filter=lfs diff=lfs merge=lfs -text 21 | *.gz filter=lfs diff=lfs merge=lfs -text 22 | *.tar filter=lfs diff=lfs merge=lfs -text 23 | *.tgz filter=lfs diff=lfs merge=lfs -text 24 | *.bz2 filter=lfs diff=lfs merge=lfs -text 25 | *.7z filter=lfs diff=lfs merge=lfs -text 26 | *.rar filter=lfs diff=lfs merge=lfs -text 27 | 28 | # Common file formats in machine learning projects 29 | *.bin filter=lfs diff=lfs merge=lfs -text 30 | *.model filter=lfs diff=lfs merge=lfs -text 31 | *.h5 filter=lfs diff=lfs merge=lfs -text 32 | *.tfrecord filter=lfs diff=lfs merge=lfs -text 33 | *.hdf5 filter=lfs diff=lfs merge=lfs -text 34 | *.keras filter=lfs diff=lfs merge=lfs -text 35 | *.pth filter=lfs diff=lfs merge=lfs -text 36 | *.pt filter=lfs diff=lfs merge=lfs -text 37 | *.joblib filter=lfs diff=lfs merge=lfs -text 38 | *.pkl filter=lfs diff=lfs merge=lfs -text 39 | *.pickle filter=lfs diff=lfs merge=lfs -text 40 | *.npy filter=lfs diff=lfs merge=lfs -text 41 | 42 | # Common audio and video formats 43 | *.mp3 filter=lfs diff=lfs merge=lfs -text 44 | *.mp4 filter=lfs diff=lfs merge=lfs -text 45 | *.wav filter=lfs diff=lfs merge=lfs -text 46 | *.avi filter=lfs diff=lfs merge=lfs -text 47 | *.mov filter=lfs diff=lfs merge=lfs -text 48 | *.flac filter=lfs diff=lfs merge=lfs -text 49 | *.mkv filter=lfs diff=lfs merge=lfs -text 50 | *.webm filter=lfs diff=lfs merge=lfs -text 51 | *.ogg filter=lfs diff=lfs merge=lfs -text 52 | *.ogv filter=lfs diff=lfs merge=lfs -text 53 | 54 | # Common data transfer formats 55 | #*.csv filter=lfs diff=lfs merge=lfs -text 56 | #*.tsv filter=lfs diff=lfs merge=lfs -text 57 | #*.json filter=lfs diff=lfs merge=lfs -text 58 | #*.xml filter=lfs diff=lfs merge=lfs -text 59 | *.parquet filter=lfs diff=lfs merge=lfs -text 60 | *.feather filter=lfs diff=lfs merge=lfs -text 61 | *.msgpack filter=lfs diff=lfs merge=lfs -text 62 | *.avro filter=lfs diff=lfs merge=lfs -text 63 | *.arrow filter=lfs diff=lfs merge=lfs -text 64 | *.orc filter=lfs diff=lfs merge=lfs -text 65 | 66 | # Exclude files from language stats (GitHub Linguist) 67 | *.ipynb linguist-vendored 68 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Docs 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | contents: read 8 | pages: write 9 | id-token: write 10 | 11 | # Only allow one deployment at a time running 12 | concurrency: 13 | group: "pages" 14 | cancel-in-progress: false 15 | 16 | jobs: 17 | deploy: 18 | environment: 19 | name: github-pages 20 | url: ${{ steps.deployment.outputs.page_url }} 21 | runs-on: ubuntu-latest 22 | steps: 23 | - name: Checkout Repository 24 | uses: actions/checkout@v4 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.11' 30 | 31 | - name: Install Poetry 32 | uses: snok/install-poetry@v1 33 | with: 34 | virtualenvs-create: true 35 | virtualenvs-in-project: true 36 | 37 | - name: Load cached venv 38 | id: cached-poetry-dependencies 39 | uses: actions/cache@v4 40 | with: 41 | path: .venv 42 | key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 43 | 44 | - name: Install Dependencies 45 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 46 | run: make install 47 | 48 | - name: Build 
Documentation 49 | run: make docs 50 | 51 | - name: Setup Pages 52 | uses: actions/configure-pages@v5 53 | 54 | - name: Upload Documentation as Artifact 55 | uses: actions/upload-pages-artifact@v3 56 | with: 57 | path: './site' 58 | 59 | - name: Deploy to GitHub Pages 60 | id: deployment 61 | uses: actions/deploy-pages@v4 62 | -------------------------------------------------------------------------------- /.github/workflows/lints.yml: -------------------------------------------------------------------------------- 1 | name: Run Linters 2 | 3 | on: 4 | workflow_dispatch: 5 | workflow_call: 6 | push: 7 | tags: 8 | - 'v*' 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | 17 | strategy: 18 | matrix: 19 | python-version: [ "3.10", "3.11", "3.12", "3.13" ] 20 | 21 | steps: 22 | - name: Checkout Repository 23 | uses: actions/checkout@v4 24 | 25 | - name: Set Up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install Dependencies 31 | run: | 32 | make setup 33 | make install 34 | 35 | - name: Run Tests with Coverage 36 | run: | 37 | make lint 38 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | tags: 7 | - 'v*' 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | 14 | call_tests: 15 | uses: ./.github/workflows/tests.yml 16 | 17 | publish_to_pypi: 18 | runs-on: ubuntu-latest 19 | needs: call_tests 20 | 21 | steps: 22 | - name: Checkout Repository 23 | uses: actions/checkout@v4 24 | 25 | - name: Set Up Python 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: "3.10" 29 | 30 | - name: Install Dependencies 31 | run: | 32 | make setup 33 | make install 34 | 35 | - name: Build and Publish to PyPI 36 | run: | 37 | PYPI_TOKEN=${{ secrets.PYPI_API_TOKEN }} make publish 38 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | workflow_dispatch: 5 | workflow_call: 6 | pull_request: 7 | branches: 8 | - main 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | 17 | strategy: 18 | matrix: 19 | python-version: [ "3.10", "3.11", "3.12", "3.13" ] 20 | 21 | steps: 22 | - name: Checkout Repository 23 | uses: actions/checkout@v4 24 | 25 | - name: Set Up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install Dependencies 31 | run: | 32 | make setup 33 | make install 34 | 35 | - name: Run Tests with Coverage 36 | run: | 37 | make test 38 | 39 | - name: Upload coverage reports to Codecov 40 | uses: codecov/codecov-action@v5 41 | with: 42 | token: ${{ secrets.CODECOV_TOKEN }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python specific 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Virtual environments 7 | .env/ 8 | env/ 9 | .venv/ 10 | venv/ 11 | 12 | # Packaging and distribution files 13 | .Python 14 | build/ 15 | dist/ 16 | *.egg-info/ 17 | *.egg 18 | 
MANIFEST 19 | 20 | # Dependency directories 21 | develop-eggs/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | .installed.cfg 32 | 33 | # Test and coverage reports 34 | htmlcov/ 35 | .tox/ 36 | .coverage 37 | .coverage.* 38 | .cache 39 | nosetests.xml 40 | coverage.xml 41 | *.cover 42 | .hypothesis/ 43 | .pytest_cache/ 44 | .benchmarks/ 45 | 46 | # IDE specific files and directories 47 | .idea/ 48 | *.iml 49 | .vscode/ 50 | 51 | # Jupyter Notebook files 52 | .ipynb_checkpoints 53 | 54 | # Temporary files created by editors and the system and folders to ignore 55 | *.swp 56 | *~ 57 | *.bak 58 | *.tmp 59 | temp/ 60 | output/ 61 | tmp/ 62 | tmp2/ 63 | out/ 64 | out2/ 65 | 66 | # Database files (SQLite, DuckDB, etc.) 67 | *.duckdb 68 | *.db 69 | *.wal 70 | *.sqlite 71 | 72 | # Dependency lock files (uncomment to ignore) 73 | poetry.lock 74 | 75 | # Miscellaneous files and directories to ignore 76 | # Add any additional file patterns a directory names that should be ignored down here 77 | *_output.txt 78 | .env 79 | benches/data/ 80 | coverage_*_report 81 | *.jsonl 82 | site/ 83 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-toml 9 | - id: check-added-large-files 10 | - id: check-merge-conflict 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | We adhere to the [Python Software Foundation Code of Conduct](https://policies.python.org/python.org/code-of-conduct). 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guidelines 2 | 3 | Thank you for considering contributing to this project! 4 | Contributions are always welcome and appreciated. 5 | 6 | ## How to Contribute 7 | 8 | Please check the [issue tracker](https://github.com/habedi/cogitator/issues) to see if there is an issue you 9 | would like to work on or if it has already been resolved. 10 | 11 | ### Reporting Bugs 12 | 13 | 1. Open an issue on the [issue tracker](https://github.com/habedi/cogitator/issues). 14 | 2. Include information such as steps to reproduce, expected/actual behavior, and relevant logs or screenshots. 15 | 16 | ### Suggesting Features 17 | 18 | 1. Open an issue on the [issue tracker](https://github.com/habedi/cogitator/issues). 19 | 2. Provide details about the feature, its purpose, and potential implementation ideas. 20 | 21 | ## Submitting Pull Requests 22 | 23 | - Ensure all tests pass before submitting a pull request. 24 | - Write a clear description of the changes you made and the reasons behind them. 25 | 26 | > [!IMPORTANT] 27 | > It's assumed that by submitting a pull request, you agree to license your contributions under the project's license. 28 | 29 | ## Development Workflow 30 | 31 | ### Prerequisites 32 | 33 | Install GNU Make on your system if it's not already installed. 34 | 35 | ```shell 36 | # For Debian-based systems like Debian, Ubuntu, etc. 
37 | sudo apt-get install make 38 | ``` 39 | 40 | - Use the `make setup` command to install the development dependencies. 41 | 42 | ### Code Style 43 | 44 | - Use the `make format` command to format the code. 45 | 46 | ### Running Tests 47 | 48 | - Use the `make test` command to run the tests. 49 | 50 | ### Running Linter Checks 51 | 52 | - Use the `make lint` command to run the linter checks. 53 | 54 | ### See Available Commands 55 | 56 | - Run `make help` to see all available commands for managing different tasks. 57 | 58 | ## Code of Conduct 59 | 60 | We adhere to the [Python Software Foundation Code of Conduct](https://policies.python.org/python.org/code-of-conduct). 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Hassan Abedi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Variables 2 | PYTHON = python 3 | PIP = pip 4 | POETRY = poetry 5 | SHELL = /bin/bash 6 | BENCH_DIR = benches 7 | EXAMPLE_DIR = examples 8 | OLLAMA_MODEL ?= gemma3:12b 9 | OPENAI_MODE ?= gpt-4o-mini 10 | 11 | # Default target 12 | .DEFAULT_GOAL := help 13 | 14 | # Help target 15 | .PHONY: help 16 | help: ## Show help messages for all available targets 17 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; \ 18 | {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 19 | 20 | # Setup and installation 21 | .PHONY: setup 22 | setup: ## Install system dependencies 23 | sudo apt-get update 24 | sudo apt-get install -y python3-pip 25 | $(PIP) install poetry 26 | 27 | .PHONY: install 28 | install: ## Install Python dependencies 29 | $(POETRY) install --with dev 30 | 31 | # Testing and linting 32 | .PHONY: test 33 | test: ## Run the tests 34 | CUDA_VISIBLE_DEVICES="" $(POETRY) run pytest 35 | 36 | .PHONY: lint 37 | lint: ## Run the linter checks 38 | $(POETRY) run ruff check --fix 39 | 40 | .PHONY: format 41 | format: ## Format the Python files 42 | $(POETRY) run ruff format . 43 | 44 | .PHONY: typecheck 45 | typecheck: ## Typecheck the Python files 46 | $(POETRY) run mypy . 47 | 48 | # Cleaning 49 | .PHONY: clean 50 | clean: ## Remove build artifacts, caches, and temporary files 51 | find . 
-type f -name '*.pyc' -delete 52 | find . -type d -name '__pycache__' -exec rm -r {} + 53 | rm -rf site dist 54 | rm -rf .mypy_cache .pytest_cache .ruff_cache .coverage htmlcov coverage.xml junit 55 | 56 | # Build and publish 57 | .PHONY: build 58 | build: ## Build the wheel and source distribution 59 | $(POETRY) build 60 | 61 | .PHONY: publish 62 | publish: ## Publish the library to PyPI (requires PYPI_TOKEN to be set) 63 | $(POETRY) config pypi-token.pypi $(PYPI_TOKEN) 64 | $(POETRY) publish --build 65 | 66 | .PHONY: example-openai 67 | example-openai: ## Run the examples using OpenAI (needs OPENAI_API_KEY to be set) 68 | @for script in $(EXAMPLE_DIR)/run_*.py; do \ 69 | echo "Running $$script --provider openai --openai-key ******** --model-name $(OPENAI_MODE) --use-async"; \ 70 | $(POETRY) run python $$script --provider openai --openai-key $(OPENAI_API_KEY) --model-name $(OPENAI_MODE) --use-async; \ 71 | done 72 | 73 | example-ollama: ## Run the examples using Ollama 74 | @echo "Running examples with Ollama provider (Model: $(OLLAMA_MODEL))" 75 | @for script in $(EXAMPLE_DIR)/run_*.py; do \ 76 | echo "Running $$script --provider ollama --model-name $(OLLAMA_MODEL)"; \ 77 | $(POETRY) run python $$script --provider ollama --model-name $(OLLAMA_MODEL); \ 78 | done 79 | 80 | .PHONY: docs 81 | docs: ## Generate the project documentation 82 | $(POETRY) run mkdocs build 83 | 84 | # All-in-one target 85 | .PHONY: all 86 | all: install check build ## Install Python dependencies, run lint, typecheck, tests, and build the library 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Cogitator Logo 4 | 5 |
6 | 7 |

Cogitator

8 | 9 | [![Tests](https://img.shields.io/github/actions/workflow/status/habedi/cogitator/tests.yml?label=tests&style=flat&labelColor=333333&logo=github&logoColor=white)](https://github.com/habedi/cogitator/actions/workflows/tests.yml) 10 | [![Code Coverage](https://img.shields.io/codecov/c/github/habedi/cogitator?style=flat&label=coverage&labelColor=333333&logo=codecov&logoColor=white)](https://codecov.io/gh/habedi/cogitator) 11 | [![Code Quality](https://img.shields.io/codefactor/grade/github/habedi/cogitator?style=flat&label=code%20quality&labelColor=333333&logo=codefactor&logoColor=white)](https://www.codefactor.io/repository/github/habedi/cogitator) 12 | [![Python Version](https://img.shields.io/badge/python-%3E=3.10-3776ab?style=flat&labelColor=333333&logo=python&logoColor=white)](https://github.com/habedi/cogitator) 13 | [![PyPI Version](https://img.shields.io/pypi/v/cogitator.svg?style=flat&label=pypi&labelColor=333333&logo=pypi&logoColor=white&color=3775a9)](https://pypi.org/project/cogitator) 14 | [![Downloads](https://img.shields.io/pypi/dm/cogitator.svg?style=flat&label=downloads&labelColor=333333&logo=pypi&logoColor=white&color=cc8400)](https://github.com/habedi/cogitator) 15 |
16 | [![License](https://img.shields.io/badge/license-MIT-00acc1?style=flat&labelColor=333333&logo=open-source-initiative&logoColor=white)](https://github.com/habedi/cogitator/blob/main/LICENSE) 17 | [![Docs](https://img.shields.io/badge/docs-latest-8ca0d7?style=flat&labelColor=333333&logo=readthedocs&logoColor=white)](https://habedi.github.io/cogitator) 18 | [![DOI](https://img.shields.io/badge/doi-10.5281/zenodo.15331821-6f42c1.svg?style=flat&labelColor=333333&logo=zenodo&logoColor=white)](https://doi.org/10.5281/zenodo.15331821) 19 | 20 | A Python Toolkit for Chain-of-Thought Prompting 21 | 22 |
23 | 24 | --- 25 | 26 | Cogitator is a Python toolkit for experimenting and working with 27 | [chain-of-thought (CoT) prompting](https://arxiv.org/abs/2201.11903) 28 | methods in large language models (LLMs). 29 | CoT prompting improves LLM performance on complex tasks (like question-answering, reasoning, and problem-solving) 30 | by guiding the models to generate intermediate reasoning steps before arriving at the final answer. 31 | Additionally, it can be used to improve the interpretability of LLMs by providing insight into the model's reasoning process. 32 | The toolkit aims to make it easier to use popular CoT strategies and frameworks for research or to integrate them into AI 33 | applications. 34 | 35 | ### Features 36 | 37 | * Provides unified sync/async API for CoT strategies 38 | * Supports using OpenAI and Ollama as LLM providers 39 | * Supports structured model output with Pydantic validation 40 | * Includes a customizable benchmarking framework (see [benches](benches)) 41 | * Includes implementations of popular CoT strategies and frameworks like 42 | - [Self-Consistency CoT (ICLR 2023)](https://arxiv.org/abs/2203.11171) 43 | - [Automatic CoT (ICLR 2023)](https://arxiv.org/abs/2210.03493) 44 | - [Least-to-Most Prompting (ICLR 2023)](https://arxiv.org/abs/2205.10625) 45 | - [Tree of Thoughts (NeurIPS 2023)](https://arxiv.org/abs/2305.10601) 46 | - [Graph of Thoughts (AAAI 2024)](https://arxiv.org/abs/2308.09687) 47 | - [Clustered Distance-Weighted CoT (AAAI 2025)](https://arxiv.org/abs/2501.12226) 48 | 49 | --- 50 | 51 | ### Getting Started 52 | 53 | You can install Cogitator with 54 | 55 | ```bash 56 | pip install cogitator 57 | ``` 58 | 59 | Or, if you want to install the latest version from source with examples and benchmarks included 60 | 61 | ```bash 62 | git clone https://github.com/habedi/cogitator && cd cogitator 63 | 64 | # Set up Python environment 65 | pip install poetry 66 | poetry install --with dev 67 | 68 | # Run the tests to make sure everything is working (optional) 69 | poetry run pytest 70 | ``` 71 | 72 | #### Examples 73 | 74 | Below is a simple example of using the Self-Consistency CoT with Ollama. 75 | 76 | ```python 77 | import logging 78 | from cogitator import SelfConsistency, OllamaLLM 79 | 80 | # Step 1: Configure logging (optional, but helpful) 81 | logging.basicConfig(level=logging.INFO) 82 | logging.getLogger("httpx").setLevel(logging.WARNING) # Suppress HTTPX logs 83 | 84 | # Step 2: Initialize the LLM (using Ollama) 85 | # Needs Ollama running locally with the model pulled (e.g., `ollama pull gemma3:4b`) 86 | try: 87 | llm = OllamaLLM(model="gemma3:4b") 88 | except Exception as e: 89 | print(f"Error initializing Ollama LLM: {e}") 90 | print("Please make sure Ollama is running and the model is pulled.") 91 | exit(1) 92 | 93 | # Step 3: Choose a CoT strategy (Self-Consistency in this case) 94 | # Self-Consistency generates multiple reasoning paths and finds the most common answer 95 | sc_strategy = SelfConsistency( 96 | llm, 97 | n_samples=5, # Number of reasoning paths to generate 98 | temperature=0.7 # Higher temperature can lead to more diverse answers 99 | ) 100 | 101 | # Step 4: Define the prompt (with a basic CoT trigger) 102 | question = "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?" 103 | prompt = f"Q: {question}\nA: Let's think step by step."
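# Note: "Let's think step by step." is the standard zero-shot CoT trigger phrase;
# Self-Consistency samples n_samples completions of this prompt and keeps the most common extracted answer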
104 | 105 | # Step 5: Run the CoT prompting sc_strategy 106 | print(f"\nQuestion: {question}") 107 | print("Running Self-Consistency CoT...") 108 | final_answer = sc_strategy.run(prompt) # Returns the most consistent (repeated) answer 109 | 110 | # Expected output: $0.05 or 0.05 (may vary slightly based on model and temperature) 111 | print(f"\nCogitator's Answer (Self-Consistency): {final_answer}") 112 | ``` 113 | 114 | Check out the [examples](examples) directory for more examples. 115 | 116 | --- 117 | 118 | ### Documentation 119 | 120 | Cogitator documentation is available [here](https://habedi.github.io/cogitator). 121 | 122 | --- 123 | 124 | ### Benchmarking Framework 125 | 126 | This project includes a customizable and extensible benchmarking framework to evaluate the performance of different 127 | CoT strategies on various datasets like [GSM8K](https://arxiv.org/abs/2110.14168) and 128 | [StrategyQA](https://arxiv.org/abs/2101.02235). 129 | 130 | Check out the [benches](benches) directory for more details about the framework and how it could be used. 131 | 132 | --- 133 | 134 | ### Contributing 135 | 136 | See [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to make a contribution. 137 | 138 | ### Citations 139 | 140 | If you find this project useful, please give it a star! 141 | If you have any questions or feedback, please use the discussion section of the repository or open an issue. 142 | If you use this project in your research, please consider citing using the following information: 143 | 144 | ```bibtex 145 | @software{abedi_cogitator_2025, 146 | author = {Abedi Firouzjaei, Hassan}, 147 | title = {{Cogitator: A Python Toolkit for Chain-of-Thought Prompting}}, 148 | year = {2025--}, 149 | publisher = {Zenodo}, 150 | doi = {10.5281/zenodo.15331821}, 151 | url = {https://github.com/habedi/cogitator} 152 | } 153 | ``` 154 | 155 | ### Logo 156 | 157 | The logo is named "Cognition" and was originally created by [vectordoodle](https://www.svgrepo.com/author/vectordoodle). 158 | 159 | ### License 160 | 161 | Cogitator is licensed under the [MIT License](LICENSE). 162 | -------------------------------------------------------------------------------- /benches.yml: -------------------------------------------------------------------------------- 1 | common: 2 | debug: false 3 | openai_key_env_var: "OPENAI_API_KEY" 4 | 5 | generation: 6 | dataset: "gsm8k" 7 | cutoff: 50 8 | provider: "ollama" 9 | model_name: "gemma3:4b" 10 | ollama_host: null 11 | use_async: true 12 | concurrency: 3 13 | use_json_strategies: false 14 | output_file: "benchmark_results.jsonl" 15 | llm_params: 16 | max_tokens: 2048 17 | seed: 33 18 | temperature: 0.7 19 | 20 | evaluation: 21 | input_file: "benchmark_results.jsonl" 22 | extractor: 23 | type: "heuristic" 24 | provider: "ollama" 25 | model_name: "gemma3:12b" 26 | ollama_host: null 27 | llm_params: 28 | max_tokens: 64 29 | seed: 42 30 | temperature: 0.1 31 | show_details: false 32 | concurrency: 3 33 | 34 | strategies: 35 | ZeroShotCoT: 36 | enabled: true 37 | 38 | AutoCoT: 39 | enabled: true 40 | n_demos: 5 41 | max_q_tokens: 100 42 | max_steps: 8 43 | max_retries: 3 44 | prompt_template: "Let's think step-by-step." 
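  # Note: including a strategy section enables it by default; add "enabled: false" to turn it off.
  # Per-strategy values set here (e.g. temperature) override the global generation.llm_params.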
45 | 46 | CDWCoT: 47 | enabled: true 48 | pool_size: 10 49 | n_clusters: 5 50 | lr: 0.1 51 | sample_size: 10 52 | max_grad_norm: 1.0 53 | init_pool_retries: 1 54 | train_params: 55 | epochs: 5 56 | patience: 2 57 | val_split: 0.2 58 | 59 | SelfConsistency: 60 | enabled: true 61 | n_samples: 10 62 | temperature: 0.8 63 | stop: null 64 | internal_extraction_format: "heuristic" 65 | 66 | LeastToMost: 67 | enabled: true 68 | max_subqs: 10 69 | intermediate_output_format: "text" 70 | 71 | TreeOfThoughts: 72 | enabled: true 73 | max_depth: 3 74 | num_branches: 3 75 | sims: 10 76 | c_puct: 1.0 77 | 78 | GraphOfThoughts: 79 | enabled: true 80 | graph_of_operations: 81 | - [ "Generate", { k: 3, target_set: "frontier", output_set: "generated_1", prompt_key: "expand" } ] 82 | - [ "Score", { target_set: "generated_1", prompt_key: "evaluate" } ] 83 | - [ "KeepBest", { N: 1, target_set: "generated_1", output_set: "frontier" } ] 84 | - [ "Generate", { k: 2, target_set: "frontier", output_set: "generated_2", prompt_key: "expand" } ] 85 | - [ "Score", { target_set: "generated_2", prompt_key: "evaluate" } ] 86 | - [ "KeepBest", { N: 1, target_set: "generated_2", output_set: "frontier" } ] 87 | final_answer_format: "text" 88 | -------------------------------------------------------------------------------- /benches/README.md: -------------------------------------------------------------------------------- 1 | ## Benchmarks 2 | 3 | Benchmarks are primarily run using the `benches/run.py` script, which handles the generation phase. 4 | A separate script, `benches/evaluate.py`, is used afterward to calculate accuracy from the generated results. 5 | 6 | Run the generation script from the project root directory: 7 | 8 | ```bash 9 | poetry run python benches/run.py [OPTIONS] 10 | ``` 11 | 12 | Available Options for `run.py`: 13 | 14 | * `--dataset `: Dataset to use (default: `gsm8k`). Available options: `gsm8k`, `multiarith`, `aqua`, `csqa`, 15 | `strategyqa`, `coin`, and `letter`. 16 | * `--cutoff `: Number of samples to load from the dataset (-1 for all; default: `50`). These samples are used 17 | for both setup (if needed) and generation and testing. 18 | * `--provider `: LLM provider (`ollama` or `openai`; default: `ollama`). 19 | * `--model-name `: Model name for the provider (default: `gemma2:9b` for ollama, `gpt-4o-mini` for openai). Verify model 20 | availability. 21 | * `--openai-key `: OpenAI API key (needed for `--provider openai`, can use `OPENAI_API_KEY` environment variable if 22 | it is set). 23 | * `--use-async`: Use asynchronous execution for LLM calls (default: sync). Highly recommended for speed. 24 | * `--concurrency `: Max concurrent LLM requests when using `--use-async` (default: `3`). 25 | * `--use-json-strategies`: Use JSON mode within strategies where applicable (LtM, GoT, SC) (default: disabled). 26 | * `--output-file `: File to save raw generation results in JSONL format (default: `benchmark_results.jsonl`). 27 | * `--debug`: Enable debug logging for more verbose output. 28 | 29 | Check out OpenAI's [API documentation](https://platform.openai.com/docs/api-reference) for more details on the models 30 | and their capabilities. Use `ollama list` to see the available models for the `ollama` provider. 31 | 32 | ### Benchmark Workflow (`run.py`) 33 | 34 | The `run.py` script executes the following steps: 35 | 36 | 1. **Configuration:** Loads settings from `benches.yml` and merges them with command-line arguments (CLI > YAML > Defaults). 37 | 2. 
**Dataset Loading:** Loads the specified dataset subset based on the final configuration. 38 | 3. **Model & CoT Strategies:** Sets up the language model and instances of enabled CoT strategies, configured according to 39 | `benches.yml`. 40 | 4. **Setup Phase (One-Time Cost per Run):** Before generating answers for the test questions, strategies that need fitting 41 | or training (like AutoCoT and CDWCoT) perform this step *once* using the loaded dataset samples and configured 42 | parameters. 43 | 5. **Generation Phase (Per Question):** The script iterates through each loaded question: 44 | * For each question, it executes all *enabled* CoT strategies using their configured parameters. 45 | * If run synchronously, strategies execute one after another. If async, calls run concurrently. 46 | * The raw text output from the LLM and the execution time are recorded. 47 | 6. **Output:** Results are saved line-by-line in JSONL format to the specified output file. 48 | 49 | See `poetry run python benches/run.py --help` to see all available options. 50 | 51 | ### Evaluation (`evaluate.py`) 52 | 53 | After `run.py` generates the result file, use `evaluate.py` to calculate metrics: 54 | 55 | ```bash 56 | poetry run python benches/evaluate.py --input-file [EVAL_OPTIONS] 57 | ``` 58 | 59 | This script reads the JSONL file, loads its configuration from `benches.yml` and merges with CLI options, extracts the 60 | final answer from the raw model output (using the configured extractor type: heuristic or LLM), compares it to the 61 | correct answer, and calculates the accuracy for each CoT strategy present in the result file. It then shows a summary 62 | table. See `poetry run python benches/evaluate.py --help` for evaluation-specific options. 63 | 64 | ### Configuration (`benches.yml`) 65 | 66 | Benchmark runs are configured using `benches.yml` in the project root, combined with command-line arguments. 67 | **Configuration Precedence:** 68 | 69 | 1. **Command-Line Arguments:** Highest priority (e.g., `--dataset`, `--provider`). 70 | 2. **`benches.yml`:** Values from this file are used if not specified via CLI. 71 | 3. **Code Defaults:** Lowest priority, used if not set in CLI or YAML. 72 | 73 | **YAML Structure:** 74 | 75 | * **`common`**: Shared settings. 76 | * `debug`: `true` or `false` for verbose logging. 77 | * `openai_key_env_var`: Name of the environment variable holding the OpenAI key. 78 | * **`generation`**: Settings for the generation script (`run.py`). 79 | * `dataset`: Name of the dataset (e.g., `gsm8k`). 80 | * `cutoff`: Max number of samples to use (-1 for all). 81 | * `provider`: `ollama` or `openai`. 82 | * `model_name`: Specific model for the provider. 83 | * `ollama_host`: (**Optional**) Specify the host address for the Ollama server (e.g., `http://192.168.1.100:11434`). If `null` 84 | or omitted, uses `OLLAMA_HOST` env var or defaults to `http://localhost:11434`. 85 | * `use_async`: `true` to run LLM calls concurrently. 86 | * `concurrency`: Max parallel requests for async runs. 87 | * `use_json_strategies`: `true` or `false`. Default for strategies supporting JSON output (can be overridden per strategy). 88 | * `output_file`: Path to save raw results (JSONL). 89 | * `llm_params`: Global LLM settings (`max_tokens`, `seed`, `temperature`) applied unless overridden per strategy. 90 | * **`evaluation`**: Settings for the evaluation script (`evaluate.py`). 91 | * `input_file`: Path to the results JSONL file (defaults to `generation.output_file`). 
92 | * `extractor`: Configures how final answers are extracted. 93 | * `type`: `heuristic` or `llm`. 94 | * `provider`, `model_name`: Settings for the LLM extractor if `type` is `llm`. 95 | * `ollama_host`: Specify the Ollama host for the extractor LLM, if using `type: llm` and 96 | `provider: ollama`. Defaults apply if null/omitted. 97 | * `llm_params`: Settings for the LLM extractor if `type` is `llm`. 98 | * `show_details`: `true` to print per-question evaluation details. 99 | * `concurrency`: Max parallel requests for the LLM extractor. 100 | * **`strategies`**: Configure individual CoT strategies. 101 | * Each key is the strategy's class name (e.g., `AutoCoT`). 102 | * Including a section enables the strategy by default. Add `enabled: false` to disable. 103 | * Set strategy-specific parameters (e.g., `n_demos`, `pool_size`, `n_samples`, `max_depth`). 104 | * Strategy-specific LLM parameters (like `temperature` for `SelfConsistency`) or format choices (`internal_extraction_format`, 105 | `intermediate_output_format`, `final_answer_format`) set here override global settings from the `generation` section for 106 | that specific strategy. 107 | 108 | See the example `benches.yml` in the repository for detailed options. 109 | 110 | > [!NOTE] 111 | > For parameters like OpenAI keys, it is recommended to specify the *environment variable name* in `benches.yml` (e.g., 112 | `openai_key_env_var: "MY_API_KEY"`) rather than pasting the key directly into the file. 113 | > The scripts will then read the key from the specified environment variable. 114 | > You can still override this by passing `--openai-key` on the command line. 115 | 116 | ### Example Usages (`run.py`) 117 | 118 | * Run using configuration defined in `benches.yml`: 119 | ```bash 120 | # Assumes benches.yml is configured as desired 121 | poetry run python benches/run.py 122 | ``` 123 | 124 | * Run using `benches.yml` but override the dataset via command line: 125 | ```bash 126 | poetry run python benches/run.py --dataset aqua 127 | ``` 128 | 129 | ### Dependencies 130 | 131 | To run the benchmarks, you might want to install the development dependencies along with Cogitator itself. 132 | 133 | ```bash 134 | poetry install --with dev 135 | ``` 136 | 137 | Additionally, any model used in the benchmarks must be available. 138 | Make sure the Ollama server is **running** and pull desired models using `ollama pull `. 139 | Make sure the OpenAI key is set correctly if using the OpenAI models. 140 | 141 | ### More Examples 142 | 143 | ```bash 144 | # Run using benches.yml (assuming it's configured for Ollama, gemma2:9b, aqua, async, etc.) 
145 | poetry run python benches/run.py --output-file my_ollama_results.jsonl 146 | 147 | # Evaluate the results using heuristic (as configured in benches.yml or default) 148 | poetry run python benches/evaluate.py --input-file my_ollama_results.jsonl --show-details 149 | 150 | # Evaluate the results using LLM extractor (override benches.yml extractor setting) 151 | poetry run python benches/evaluate.py --extractor-type llm --provider ollama --model-name llama3 --input-file my_ollama_results.jsonl 152 | 153 | # Run specifically with OpenAI, overriding YAML if necessary 154 | poetry run python benches/run.py --provider openai --model-name gpt-4o-mini --dataset csqa --cutoff 10 --use-async --output-file my_openai_results.jsonl 155 | 156 | # Evaluate the OpenAI results using an OpenAI extractor model 157 | poetry run python benches/evaluate.py --input-file my_openai_results.jsonl --extractor-type llm --provider openai --model-name gpt-4o-mini 158 | ``` 159 | 160 | ## Performance Metric 161 | 162 | Accuracy is the primary metric reported by the `evaluate.py` script. 163 | It is defined as the percentage of correctly answered questions out of the total number of successfully extracted answers for a 164 | given CoT strategy. 165 | 166 | > [!NOTE] 167 | > This definition means accuracy reflects performance *only* on runs where the final answer could be successfully extracted. 168 | > Runs resulting in extraction errors (e.g., the extractor fails to find an answer in the raw output) are excluded from the 169 | > accuracy calculation, which is important when comparing strategies with different extraction success rates. 170 | 171 | ## Datasets 172 | 173 | The following datasets can be used for values in the `--dataset` argument of the `run.py` script. 174 | 175 | | Dataset Name | Source Link | Category Tags | Description | 176 | |:-------------|:-------------------------------------------------------------------------------------|:--------------------------|:--------------------------------------------------| 177 | | `gsm8k` | [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) | `math` | Grade school math word problems | 178 | | `multiarith` | [ChilleD/MultiArith](https://huggingface.co/datasets/ChilleD/MultiArith) | `math` | Multi-step arithmetic problems | 179 | | `aqua` | [deepmind/aqua_rat](https://huggingface.co/datasets/deepmind/aqua_rat) | `math` | Algebraic word problems with rationales | 180 | | `csqa` | [tau/commonsense_qa](https://huggingface.co/datasets/tau/commonsense_qa) | `commonsense` | Multiple-choice commonsense questions | 181 | | `strategyqa` | [ChilleD/StrategyQA](https://huggingface.co/datasets/ChilleD/StrategyQA) | `commonsense`, `symbolic` | Yes and no questions requiring implicit reasoning | 182 | | `coin` | [skrishna/coin_flip](https://huggingface.co/datasets/skrishna/coin_flip) | `symbolic` | Symbolic tasks involving state tracking | 183 | | `letter` | [ChilleD/LastLetterConcat](https://huggingface.co/datasets/ChilleD/LastLetterConcat) | `symbolic`, `text` | Extract and concatenate last letters from words | 184 | -------------------------------------------------------------------------------- /benches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/benches/__init__.py -------------------------------------------------------------------------------- /benches/extractors.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Optional 4 | 5 | logger_extractors = logging.getLogger("benchmark_extractors") 6 | 7 | 8 | def _extract_label_heuristic(text: str) -> Optional[str]: 9 | text = str(text).strip() 10 | lines = text.strip().splitlines() 11 | 12 | mcq_patterns = [ 13 | r'(?:final answer|answer|choice|option) is\s*:?\s*\(([A-Ea-e])\)', 14 | r'(?:final answer|answer|choice|option) is\s*:?\s*([A-Ea-e])\b', 15 | r'\(([A-Ea-e])\)\s*is the correct answer', 16 | r'\b([A-Ea-e])\)\s*$', 17 | r'^([A-Ea-e])\)$', 18 | r'correct label \(character before \'\)\'\)\n([A-Ea-e])', 19 | r'\b(?:Answer|Choice|Option):\s*([A-Ea-e])\b', 20 | ] 21 | 22 | if lines: 23 | last_line = lines[-1].strip() 24 | for pattern in mcq_patterns: 25 | match = re.search(pattern, last_line, re.IGNORECASE) 26 | if match: 27 | ans = match.group(1).upper() 28 | logger_extractors.debug(f"Extracted MCQ label '{ans}' from last line") 29 | return ans 30 | if re.fullmatch(r'[A-Ea-e]', last_line): 31 | ans = last_line.upper() 32 | logger_extractors.debug(f"Extracted MCQ label '{ans}' as last line content") 33 | return ans 34 | 35 | for pattern in mcq_patterns: 36 | match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE) 37 | if match: 38 | ans = match.group(1).upper() 39 | logger_extractors.debug(f"Extracted MCQ label '{ans}' from full text") 40 | return ans 41 | 42 | return None 43 | 44 | 45 | def _extract_numerical_heuristic(text: str) -> Optional[str]: 46 | text = str(text).strip() 47 | lines = text.strip().splitlines() 48 | 49 | gsm_pattern = r'####\s*([+-]?\d+(?:,\d+)*\.?\d*)' 50 | boxed_pattern = r'\\boxed\{([+-]?\d+(?:,\d+)*\.?\d*)\}' 51 | 52 | match = re.search(gsm_pattern, text) 53 | if match: 54 | extracted_num = match.group(1).replace(",", "") 55 | logger_extractors.debug(f"Extracted numerical answer '{extracted_num}' using GSM pattern") 56 | return extracted_num 57 | 58 | match = re.search(boxed_pattern, text) 59 | if match: 60 | extracted_num = match.group(1).replace(",", "") 61 | logger_extractors.debug(f"Extracted numerical answer '{extracted_num}' using boxed pattern") 62 | return extracted_num 63 | 64 | num_pattern_loose = r'[+-]?\d+(?:,\d+)*(?:\.\d+)?' 
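    # num_pattern_loose matches an optionally signed number with optional comma-separated
    # digit groups and an optional decimal part (e.g. "-1,234.56"); the anchored patterns
    # below reuse it after phrases such as "the final answer is" or "equals".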
65 | numerical_patterns = [ 66 | r'(?:final answer is|the final answer is|final answer:|answer:|the answer is)\s*:?\s*(' + num_pattern_loose + r')\b', 67 | r'(?:is|equals|result is|=)\s*(' + num_pattern_loose + r')\s*\.?\s*$', 68 | ] 69 | 70 | if lines: 71 | last_line = lines[-1].strip() 72 | for pattern in numerical_patterns: 73 | match = re.search(pattern, last_line, re.IGNORECASE) 74 | if match: 75 | extracted_num = match.group(1).replace(",", "") 76 | logger_extractors.debug( 77 | f"Extracted numerical answer '{extracted_num}' from last line pattern") 78 | return extracted_num 79 | if re.fullmatch(num_pattern_loose, last_line.replace("$", "")): 80 | extracted_num = last_line.replace(",", "").replace("$", "") 81 | logger_extractors.debug( 82 | f"Extracted numerical answer '{extracted_num}' as last line content") 83 | return extracted_num 84 | 85 | for pattern in numerical_patterns: 86 | match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE) 87 | if match: 88 | extracted_num = match.group(1).replace(",", "") 89 | logger_extractors.debug( 90 | f"Extracted numerical answer '{extracted_num}' from full text pattern") 91 | return extracted_num 92 | 93 | numbers = re.findall(num_pattern_loose, text) 94 | if numbers: 95 | last_num_str = numbers[-1].replace(",", "") 96 | logger_extractors.debug(f"Extracted numerical answer '{last_num_str}' as last number found") 97 | return last_num_str 98 | 99 | return None 100 | 101 | 102 | def _extract_boolean_heuristic(text: str) -> Optional[str]: 103 | text = str(text).strip().lower() 104 | lines = text.strip().splitlines() 105 | 106 | if lines: 107 | last_line = lines[-1].strip().lower().rstrip(".") 108 | if last_line == "yes": return "yes" 109 | if last_line == "no": return "no" 110 | if last_line == "true": return "yes" 111 | if last_line == "false": return "no" 112 | match = re.search(r'\b(?:answer|result) is\s+(yes|no|true|false)\b', last_line) 113 | if match: 114 | ans = match.group(1) 115 | logger_extractors.debug(f"Extracted boolean '{ans}' from last line pattern") 116 | return "yes" if ans == "true" else "no" if ans == "false" else ans 117 | 118 | match = re.search(r'\b(?:final answer|answer|result) is\s+(yes|no|true|false)\b', text) 119 | if match: 120 | ans = match.group(1) 121 | logger_extractors.debug(f"Extracted boolean '{ans}' from full text pattern") 122 | return "yes" if ans == "true" else "no" if ans == "false" else ans 123 | 124 | if "yes" in text.split()[-5:]: return "yes" 125 | if "no" in text.split()[-5:]: return "no" 126 | 127 | return None 128 | 129 | 130 | def _extract_letter_concat_heuristic(text: str) -> Optional[str]: 131 | text = str(text).strip() 132 | lines = text.strip().splitlines() 133 | 134 | pattern = r'(?:final answer|answer|result) is\s*:?\s*([a-zA-Z]+)\b' 135 | 136 | if lines: 137 | last_line = lines[-1].strip().rstrip(".") 138 | match = re.search(pattern, last_line, re.IGNORECASE) 139 | if match: 140 | ans = match.group(1) 141 | logger_extractors.debug(f"Extracted letter concat '{ans}' from last line pattern") 142 | return ans 143 | if re.fullmatch(r'[a-zA-Z]+', last_line): 144 | logger_extractors.debug(f"Extracted letter concat '{last_line}' as last line content") 145 | return last_line 146 | 147 | match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE) 148 | if match: 149 | ans = match.group(1) 150 | logger_extractors.debug(f"Extracted letter concat '{ans}' from full text pattern") 151 | return ans 152 | 153 | if lines and re.fullmatch(r'[a-zA-Z]{2,}', lines[-1].strip()): 154 | 
logger_extractors.debug( 155 | f"Extracted letter concat '{lines[-1].strip()}' as fallback last line") 156 | return lines[-1].strip() 157 | 158 | return None 159 | 160 | 161 | def extract_answer_heuristic_custom(raw_output: str, dataset_name: str) -> str: 162 | if not raw_output or str(raw_output).strip() == "[ERROR]" or str(raw_output).strip().startswith( 163 | "[ERROR:"): 164 | return "[ERROR]" 165 | 166 | text = str(raw_output).strip() 167 | extracted: Optional[str] = None 168 | 169 | if dataset_name in ["aqua", "csqa"]: 170 | extracted = _extract_label_heuristic(text) 171 | elif dataset_name in ["gsm8k", "multiarith"]: 172 | extracted = _extract_numerical_heuristic(text) 173 | elif dataset_name in ["strategyqa", "coin"]: 174 | extracted = _extract_boolean_heuristic(text) 175 | elif dataset_name == "letter": 176 | extracted = _extract_letter_concat_heuristic(text) 177 | else: 178 | logger_extractors.warning( 179 | f"No specific heuristic defined for dataset '{dataset_name}'. Using generic fallback.") 180 | extracted = _extract_numerical_heuristic(text) 181 | if extracted is None: 182 | extracted = _extract_label_heuristic(text) 183 | 184 | if extracted is None: 185 | logger_extractors.warning( 186 | f"Could not extract answer for dataset '{dataset_name}' using custom heuristic from: '{text[:150]}...'") 187 | return "[EXTRACTION_HEURISTIC_FAILURE]" 188 | 189 | return extracted.strip() 190 | 191 | 192 | MCQ_EXTRACTION_PROMPT = """ 193 | Original Question: 194 | {question} 195 | 196 | LLM Reasoning and Output: 197 | \"\"\" 198 | {raw_output} 199 | \"\"\" 200 | 201 | Analyze the LLM Reasoning and Output based on the Original Question. 202 | The original question is multiple choice with options typically labeled A, B, C, D, E. 203 | Extract ONLY the single capital letter corresponding to the final choice identified in the reasoning or output. 204 | If the reasoning calculates a value, ensure you extract the letter label associated with that value in the options. 205 | If no specific choice label is clearly identified as the final answer, return null. 206 | Return the result as a JSON object with a single key "final_answer" containing the final answer label (A, B, C, D, or E) as a string, or null if not found. 207 | 208 | JSON Output: 209 | """ 210 | 211 | NUMERICAL_EXTRACTION_PROMPT = """ 212 | Original Question: 213 | {question} 214 | 215 | LLM Reasoning and Output: 216 | \"\"\" 217 | {raw_output} 218 | \"\"\" 219 | 220 | Analyze the LLM Reasoning and Output based on the Original Question. 221 | Extract ONLY the final numerical answer stated in the text. Ignore intermediate calculations. 222 | Look for patterns like "Final Answer: [number]", "#### [number]", or the last number mentioned if it seems conclusive. 223 | Return the result as a JSON object with a single key "final_answer" containing the final numerical answer as a string (e.g., "15", "36.5", "77.5"), or null if no definitive numerical answer is found. 224 | 225 | JSON Output: 226 | """ 227 | 228 | BOOLEAN_EXTRACTION_PROMPT = """ 229 | Original Question: 230 | {question} 231 | 232 | LLM Reasoning and Output: 233 | \"\"\" 234 | {raw_output} 235 | \"\"\" 236 | 237 | Analyze the LLM Reasoning and Output based on the Original Question. 238 | Extract ONLY the final boolean answer (yes or no) stated in the text. 239 | Look for explicit statements like "Answer: yes", "Final Answer: no", or the concluding yes/no statement. 
240 | Return the result as a JSON object with a single key "final_answer" containing either the string "yes" or "no", or null if no definitive boolean answer is found. 241 | 242 | JSON Output: 243 | """ 244 | 245 | TEXT_EXTRACTION_PROMPT = """ 246 | Original Question: 247 | {question} 248 | 249 | LLM Reasoning and Output: 250 | \"\"\" 251 | {raw_output} 252 | \"\"\" 253 | 254 | Analyze the LLM Reasoning and Output based on the Original Question. 255 | Extract ONLY the final short text answer stated in the text (e.g., a concatenated string of letters for the 'letter' dataset). 256 | Look for patterns like "Answer: [text]" or the concluding text segment if it seems to be the final answer. 257 | Return the result as a JSON object with a single key "final_answer" containing the final text answer as a string, or null if no definitive text answer is found. 258 | 259 | JSON Output: 260 | """ 261 | 262 | 263 | def get_llm_extraction_prompt(question: str, raw_output: str, dataset_name: str) -> str: 264 | template: str 265 | if dataset_name in ["aqua", "csqa"]: 266 | template = MCQ_EXTRACTION_PROMPT 267 | elif dataset_name in ["gsm8k", "multiarith"]: 268 | template = NUMERICAL_EXTRACTION_PROMPT 269 | elif dataset_name in ["strategyqa", "coin"]: 270 | template = BOOLEAN_EXTRACTION_PROMPT 271 | elif dataset_name == "letter": 272 | template = TEXT_EXTRACTION_PROMPT 273 | else: 274 | logger_extractors.warning( 275 | f"No specific LLM prompt template defined for dataset '{dataset_name}'. Using generic numerical fallback.") 276 | template = NUMERICAL_EXTRACTION_PROMPT 277 | 278 | return template.format(question=question, raw_output=raw_output) 279 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - "examples/*" 3 | - "benches/*" 4 | -------------------------------------------------------------------------------- /cogitator/__init__.py: -------------------------------------------------------------------------------- 1 | """Cogitator: A Python Toolkit for Chain-of-Thought Prompting. 2 | 3 | This package provides implementations of various chain-of-thought (CoT) prompting 4 | strategies and frameworks, along with supporting utilities like LLM provider interfaces, 5 | embedding models, clustering algorithms, and data validation schemas. 6 | It aims to make it easier to try and integrate CoT methods into AI applications. 7 | """ 8 | 9 | import importlib 10 | import logging 11 | 12 | from .clustering import BaseClusterer, KMeansClusterer 13 | from .embedding import BaseEmbedder, SentenceTransformerEmbedder 14 | from .model import BaseLLM, OllamaLLM, OpenAILLM 15 | from .schemas import ( 16 | EvaluationResult, 17 | ExtractedAnswer, 18 | LTMDecomposition, 19 | ThoughtExpansion, 20 | ) 21 | from .strategies import ( 22 | AutoCoT, 23 | CDWCoT, 24 | GraphOfThoughts, 25 | LeastToMost, 26 | SelfConsistency, 27 | TreeOfThoughts, 28 | ) 29 | from .utils import accuracy, approx_token_length, count_steps, exact_match 30 | 31 | _logger = logging.getLogger(__name__) 32 | try: 33 | __version__ = importlib.metadata.version("cogitator") 34 | except importlib.metadata.PackageNotFoundError: 35 | __version__ = "0.0.0-unknown" 36 | _logger.warning( 37 | "Could not determine package version using importlib.metadata. " 38 | "Is the library installed correctly?" 
39 | ) 40 | 41 | __all__ = [ 42 | "AutoCoT", 43 | "BaseClusterer", 44 | "BaseEmbedder", 45 | "BaseLLM", 46 | "CDWCoT", 47 | "EvaluationResult", 48 | "ExtractedAnswer", 49 | "GraphOfThoughts", 50 | "KMeansClusterer", 51 | "LTMDecomposition", 52 | "LeastToMost", 53 | "OllamaLLM", 54 | "OpenAILLM", 55 | "SelfConsistency", 56 | "SentenceTransformerEmbedder", 57 | "ThoughtExpansion", 58 | "TreeOfThoughts", 59 | "accuracy", 60 | "approx_token_length", 61 | "count_steps", 62 | "exact_match", 63 | ] 64 | -------------------------------------------------------------------------------- /cogitator/clustering.py: -------------------------------------------------------------------------------- 1 | """Provides abstractions and implementations for clustering algorithms.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Tuple 5 | 6 | import numpy as np 7 | from sklearn.cluster import KMeans 8 | 9 | 10 | class BaseClusterer(ABC): 11 | """Abstract base class for clustering algorithms.""" 12 | 13 | @abstractmethod 14 | def cluster( 15 | self, embeddings: np.ndarray, n_clusters: int, **kwargs: Any 16 | ) -> Tuple[np.ndarray, np.ndarray]: 17 | """Clusters the given embeddings into a specified number of clusters. 18 | 19 | Args: 20 | embeddings: A NumPy array where each row is an embedding vector. 21 | n_clusters: The desired number of clusters. 22 | **kwargs: Additional keyword arguments specific to the clustering implementation. 23 | 24 | Returns: 25 | A tuple containing: 26 | - A NumPy array of cluster labels assigned to each embedding. 27 | - A NumPy array of cluster centers. 28 | """ 29 | ... 30 | 31 | 32 | class KMeansClusterer(BaseClusterer): 33 | """A clustering implementation using the K-Means algorithm from scikit-learn.""" 34 | 35 | def cluster( 36 | self, embeddings: np.ndarray, n_clusters: int, **kwargs: Any 37 | ) -> Tuple[np.ndarray, np.ndarray]: 38 | """Clusters embeddings using K-Means. 39 | 40 | Args: 41 | embeddings: The embeddings to cluster (shape: [n_samples, n_features]). 42 | n_clusters: The number of clusters to form. 43 | **kwargs: Additional arguments for `sklearn.cluster.KMeans`. 44 | Supported args include `random_seed` (or `seed`) and `n_init`. 45 | 46 | Returns: 47 | A tuple containing: 48 | - labels (np.ndarray): Integer labels array (shape: [n_samples,]). 49 | - centers (np.ndarray): Coordinates of cluster centers (shape: [n_clusters, n_features]). 50 | 51 | Raises: 52 | ValueError: If `n_clusters` is invalid or embeddings are incompatible. 53 | """ 54 | random_seed = kwargs.get("random_seed") or kwargs.get("seed") 55 | n_init = kwargs.get("n_init", "auto") 56 | kmeans = KMeans( 57 | n_clusters=n_clusters, 58 | random_state=random_seed, 59 | n_init=n_init, 60 | init="k-means++", 61 | ) 62 | labels = kmeans.fit_predict(embeddings) 63 | return labels, kmeans.cluster_centers_ 64 | -------------------------------------------------------------------------------- /cogitator/embedding.py: -------------------------------------------------------------------------------- 1 | """Provides abstractions and implementations for text embedding models.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List, Optional 5 | 6 | import numpy as np 7 | from sentence_transformers import SentenceTransformer 8 | 9 | 10 | class BaseEmbedder(ABC): 11 | """Abstract base class for text embedding models.""" 12 | 13 | @abstractmethod 14 | def encode(self, texts: List[str]) -> List[np.ndarray]: 15 | """Encodes a list of texts into embedding vectors. 
16 | 17 | Args: 18 | texts: A list of strings to encode. 19 | 20 | Returns: 21 | A list of NumPy arrays, where each array is the embedding vector for 22 | the corresponding text. 23 | """ 24 | ... 25 | 26 | 27 | class SentenceTransformerEmbedder(BaseEmbedder): 28 | """An embedder implementation using the sentence-transformers library. 29 | 30 | This class uses a singleton pattern to avoid reloading the model multiple times. 31 | """ 32 | 33 | _instance: Optional["SentenceTransformerEmbedder"] = None 34 | _model: Optional[SentenceTransformer] = None 35 | 36 | def __new__(cls, model_name: str = "all-MiniLM-L6-v2") -> "SentenceTransformerEmbedder": 37 | """Creates or returns the singleton instance of the embedder. 38 | 39 | Args: 40 | model_name: The name of the sentence-transformer model to load. 41 | This argument is only used during the first instantiation. 42 | 43 | Returns: 44 | The singleton instance of SentenceTransformerEmbedder. 45 | """ 46 | if cls._instance is None: 47 | cls._instance = super(SentenceTransformerEmbedder, cls).__new__(cls) 48 | cls._model = SentenceTransformer(model_name) 49 | return cls._instance 50 | 51 | def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None: 52 | """Initializes the SentenceTransformerEmbedder instance. 53 | 54 | Note: Due to the singleton pattern implemented in `__new__`, the 55 | `model_name` argument here is effectively ignored after the first 56 | instantiation. The model loaded is determined by the `model_name` 57 | passed during the first call to `__new__` or `__init__`. 58 | 59 | Args: 60 | model_name: The name of the sentence-transformer model. Defaults to 61 | "all-MiniLM-L6-v2". 62 | """ 63 | pass 64 | 65 | def encode(self, texts: List[str]) -> List[np.ndarray]: 66 | """Encodes a list of texts using the loaded sentence-transformer model. 67 | 68 | Args: 69 | texts: The list of strings to encode. 70 | 71 | Returns: 72 | A list of NumPy ndarray embeddings. 73 | 74 | Raises: 75 | RuntimeError: If the embedding model has not been initialized correctly. 76 | """ 77 | if self._model is None: 78 | raise RuntimeError("Embedder model not initialized.") 79 | embeddings: List[np.ndarray] = self._model.encode( 80 | texts, convert_to_numpy=True, show_progress_bar=False 81 | ) 82 | return embeddings 83 | -------------------------------------------------------------------------------- /cogitator/model/__init__.py: -------------------------------------------------------------------------------- 1 | """Provides interfaces and implementations for LLM providers. 2 | 3 | This subpackage defines the abstract base class `BaseLLM` and implementations for interacting with 4 | different LLM services like OpenAI and Ollama. 
5 | """ 6 | 7 | from .base import BaseLLM 8 | from .ollama import OllamaLLM 9 | from .openai import OpenAILLM 10 | 11 | __all__ = [ 12 | "BaseLLM", 13 | "OllamaLLM", 14 | "OpenAILLM", 15 | ] 16 | -------------------------------------------------------------------------------- /cogitator/model/base.py: -------------------------------------------------------------------------------- 1 | """Defines the abstract base class for LLM providers.""" 2 | 3 | import asyncio 4 | import json 5 | import logging 6 | import re 7 | import time 8 | from abc import ABC, abstractmethod 9 | from typing import Any, AsyncIterator, Iterator, Optional, Tuple, Type 10 | 11 | from pydantic import BaseModel, ValidationError 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class BaseLLM(ABC): 17 | """Abstract base class defining the interface for LLM providers.""" 18 | 19 | def __init__(self) -> None: 20 | """Initializes token count storage.""" 21 | self._last_prompt_tokens: Optional[int] = None 22 | self._last_completion_tokens: Optional[int] = None 23 | 24 | def get_last_prompt_tokens(self) -> Optional[int]: 25 | """Returns the token count for the last prompt, if available.""" 26 | return self._last_prompt_tokens 27 | 28 | def get_last_completion_tokens(self) -> Optional[int]: 29 | """Returns the token count for the last completion, if available.""" 30 | return self._last_completion_tokens 31 | 32 | def _reset_token_counts(self) -> None: 33 | """Resets the stored token counts.""" 34 | self._last_prompt_tokens = None 35 | self._last_completion_tokens = None 36 | 37 | @abstractmethod 38 | def generate(self, prompt: str, **kwargs: Any) -> str: 39 | """Generates a single text completion for the given prompt. 40 | 41 | Args: 42 | prompt: The input text prompt. 43 | **kwargs: Additional provider-specific parameters (e.g., temperature, 44 | max_tokens, stop sequences, seed). 45 | 46 | Returns: 47 | The generated text completion as a string. 48 | 49 | Raises: 50 | RuntimeError: If the generation fails after retries or due to API errors. 51 | """ 52 | ... 53 | 54 | @abstractmethod 55 | async def generate_async(self, prompt: str, **kwargs: Any) -> str: 56 | """Asynchronously generates a single text completion for the given prompt. 57 | 58 | Args: 59 | prompt: The input text prompt. 60 | **kwargs: Additional provider-specific parameters. 61 | 62 | Returns: 63 | The generated text completion as a string. 64 | 65 | Raises: 66 | RuntimeError: If the asynchronous generation fails. 67 | """ 68 | ... 69 | 70 | @abstractmethod 71 | def generate_stream(self, prompt: str, **kwargs: Any) -> Iterator[str]: 72 | """Generates a stream of text chunks for the given prompt. 73 | 74 | Args: 75 | prompt: The input text prompt. 76 | **kwargs: Additional provider-specific parameters. 77 | 78 | Yields: 79 | Strings representing chunks of the generated text. 80 | 81 | Raises: 82 | RuntimeError: If starting the stream generation fails. 83 | """ 84 | ... 85 | 86 | @abstractmethod 87 | async def generate_stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]: 88 | """Asynchronously generates a stream of text chunks for the given prompt. 89 | 90 | Args: 91 | prompt: The input text prompt. 92 | **kwargs: Additional provider-specific parameters. 93 | 94 | Yields: 95 | Strings representing chunks of the generated text asynchronously. 96 | 97 | Raises: 98 | RuntimeError: If starting the asynchronous stream generation fails. 99 | """ 100 | ... 
101 | 102 | @abstractmethod 103 | def _generate_json_internal( 104 | self, prompt: str, response_model: Type[BaseModel], **kwargs: Any 105 | ) -> Tuple[str, Optional[str]]: 106 | """Internal method to generate raw JSON output string from the LLM. 107 | 108 | This method handles the actual API call for JSON generation, potentially 109 | using provider-specific features like JSON mode or schema enforcement. 110 | It should also handle updating the internal token counts. 111 | 112 | Args: 113 | prompt: The input prompt, potentially instructing JSON format. 114 | response_model: The Pydantic model class for the expected response structure. 115 | **kwargs: Additional provider-specific parameters. 116 | 117 | Returns: 118 | A tuple containing: 119 | - The raw string response from the LLM, expected to be JSON. 120 | - An optional string indicating the JSON generation mode used (e.g., 121 | 'json_schema', 'json_object', 'heuristic'), or None if extraction 122 | is needed. 123 | 124 | Raises: 125 | RuntimeError: If the underlying LLM call fails. 126 | """ 127 | ... 128 | 129 | @abstractmethod 130 | async def _generate_json_internal_async( 131 | self, prompt: str, response_model: Type[BaseModel], **kwargs: Any 132 | ) -> Tuple[str, Optional[str]]: 133 | """Asynchronous internal method to generate raw JSON output string from the LLM. 134 | 135 | It should also handle updating the internal token counts. 136 | 137 | Args: 138 | prompt: The input prompt, potentially instructing JSON format. 139 | response_model: The Pydantic model class for the expected response structure. 140 | **kwargs: Additional provider-specific parameters. 141 | 142 | Returns: 143 | A tuple containing: 144 | - The raw string response from the LLM, expected to be JSON. 145 | - An optional string indicating the JSON generation mode used. 146 | 147 | Raises: 148 | RuntimeError: If the underlying asynchronous LLM call fails. 149 | """ 150 | ... 151 | 152 | def _extract_json_block(self, text: str) -> str: 153 | """Extracts the first JSON object or array from a string. 154 | 155 | Handles JSON enclosed in markdown code fences (```json ... ``` or ``` ... ```) 156 | or finds the first substring starting with '{' and ending with '}' or 157 | starting with '[' and ending with ']'. 158 | 159 | Args: 160 | text: The string possibly containing a JSON block. 161 | 162 | Returns: 163 | The extracted JSON string, or the original text if no block is found. 164 | """ 165 | fence_match = re.search( 166 | r"```(?:json)?\s*(\{.*\}|\[.*\])\s*```", text, re.DOTALL | re.IGNORECASE 167 | ) 168 | if fence_match: 169 | return fence_match.group(1) 170 | 171 | # Find the first standalone JSON object or array 172 | first_obj_start = text.find("{") 173 | first_arr_start = text.find("[") 174 | 175 | if first_obj_start == -1 and first_arr_start == -1: 176 | return text # No JSON start found 177 | 178 | start_index = -1 179 | if first_obj_start != -1 and first_arr_start != -1: 180 | start_index = min(first_obj_start, first_arr_start) 181 | elif first_obj_start != -1: 182 | start_index = first_obj_start 183 | else: # first_arr_start != -1 184 | start_index = first_arr_start 185 | 186 | # Attempt to find the matching end brace/bracket 187 | # This is a simplified approach and might fail for complex nested structures 188 | # if they appear outside the main intended JSON block. 
189 | json_str = text[start_index:] 190 | try: 191 | # Try parsing to find the end implicitly 192 | parsed_obj, end_index = json.JSONDecoder().raw_decode(json_str) 193 | return json_str[:end_index] 194 | except json.JSONDecodeError: 195 | # Fallback: Search for the last brace/bracket if raw_decode fails 196 | # This is less reliable. 197 | last_brace = text.rfind("}") 198 | last_bracket = text.rfind("]") 199 | end_index = max(last_brace, last_bracket) 200 | if end_index > start_index: 201 | potential_json = text[start_index : end_index + 1] 202 | # Final check if this substring is valid JSON 203 | try: 204 | json.loads(potential_json) 205 | return potential_json 206 | except json.JSONDecodeError: 207 | pass # Fall through if this substring isn't valid 208 | 209 | # If parsing/fallback fails, return the original text 210 | return text 211 | 212 | def generate_json( 213 | self, prompt: str, response_model: Type[BaseModel], retries: int = 2, **kwargs: Any 214 | ) -> BaseModel: 215 | """Generates a response and parses it into a Pydantic model instance. 216 | 217 | Uses `_generate_json_internal` and attempts to parse the result. 218 | Retries on validation or decoding errors. Also updates internal token counts. 219 | 220 | Args: 221 | prompt: The input prompt, often instructing the LLM to respond in JSON. 222 | response_model: The Pydantic model class to validate the response against. 223 | retries: The number of times to retry on parsing/validation failure. 224 | **kwargs: Additional provider-specific parameters for generation. 225 | 226 | Returns: 227 | An instance of the `response_model` populated with data from the LLM response. 228 | 229 | Raises: 230 | RuntimeError: If parsing fails after all retries. 231 | ValidationError: If the final response does not match the `response_model`. 232 | json.JSONDecodeError: If the final response is not valid JSON. 
233 | """ 234 | last_error = None 235 | temp = kwargs.pop("temperature", 0.1) 236 | json_kwargs = {**kwargs, "temperature": temp} 237 | self._reset_token_counts() # Reset before attempts 238 | 239 | for attempt in range(retries + 1): 240 | raw = "" 241 | block = "" 242 | mode_used = None 243 | try: 244 | # _generate_json_internal is responsible for updating token counts 245 | raw, mode_used = self._generate_json_internal(prompt, response_model, **json_kwargs) 246 | 247 | if mode_used in ["json_schema", "json_object", "ollama_schema_format"]: 248 | # Assume the provider handled JSON enforcement 249 | block = raw 250 | else: 251 | # Fallback to extracting JSON block heuristically 252 | block = self._extract_json_block(raw) 253 | 254 | validated_model = response_model.model_validate_json(block.strip()) 255 | # Token counts should have been set by _generate_json_internal 256 | return validated_model 257 | except (json.JSONDecodeError, ValidationError) as ve: 258 | last_error = ve 259 | logger.warning( 260 | "JSON validation/decode error %d/%d (mode: %s): %s\nBlock: %.200s\nRaw: %.200s", 261 | attempt + 1, 262 | retries + 1, 263 | mode_used, 264 | ve, 265 | block, 266 | raw, 267 | ) 268 | self._reset_token_counts() # Reset counts on error 269 | except Exception as e: 270 | last_error = e 271 | logger.error( 272 | "Error generating JSON %d/%d (mode: %s): %s", 273 | attempt + 1, 274 | retries + 1, 275 | mode_used, 276 | e, 277 | exc_info=True, 278 | ) 279 | self._reset_token_counts() # Reset counts on error 280 | 281 | if attempt < retries: 282 | sleep_time = 2**attempt 283 | logger.info(f"Retrying JSON generation in {sleep_time} seconds...") 284 | time.sleep(sleep_time) 285 | self._reset_token_counts() # Reset before retry 286 | 287 | # If loop finishes without success 288 | raise RuntimeError( 289 | f"generate_json failed after {retries + 1} attempts. Last error: {type(last_error).__name__}: {last_error}" 290 | ) 291 | 292 | async def generate_json_async( 293 | self, prompt: str, response_model: Type[BaseModel], retries: int = 2, **kwargs: Any 294 | ) -> BaseModel: 295 | """Asynchronously generates a response and parses it into a Pydantic model instance. 296 | 297 | Uses `_generate_json_internal_async` and attempts to parse the result. 298 | Retries on validation or decoding errors. Also updates internal token counts. 299 | 300 | Args: 301 | prompt: The input prompt, often instructing the LLM to respond in JSON. 302 | response_model: The Pydantic model class to validate the response against. 303 | retries: The number of times to retry on parsing/validation failure. 304 | **kwargs: Additional provider-specific parameters for generation. 305 | 306 | Returns: 307 | An instance of the `response_model` populated with data from the LLM response. 308 | 309 | Raises: 310 | RuntimeError: If parsing fails after all retries. 311 | ValidationError: If the final response does not match the `response_model`. 312 | json.JSONDecodeError: If the final response is not valid JSON. 
313 | """ 314 | last_error = None 315 | temp = kwargs.pop("temperature", 0.1) 316 | json_kwargs = {**kwargs, "temperature": temp} 317 | self._reset_token_counts() # Reset before attempts 318 | 319 | for attempt in range(retries + 1): 320 | raw = "" 321 | block = "" 322 | mode_used = None 323 | try: 324 | # _generate_json_internal_async is responsible for updating token counts 325 | raw, mode_used = await self._generate_json_internal_async( 326 | prompt, response_model, **json_kwargs 327 | ) 328 | 329 | if mode_used in ["json_schema", "json_object", "ollama_schema_format"]: 330 | block = raw 331 | else: 332 | block = self._extract_json_block(raw) 333 | 334 | validated_model = response_model.model_validate_json(block.strip()) 335 | # Token counts should have been set by _generate_json_internal_async 336 | return validated_model 337 | except (json.JSONDecodeError, ValidationError) as ve: 338 | last_error = ve 339 | logger.warning( 340 | "Async JSON validation/decode error %d/%d (mode: %s): %s\nBlock: %.200s\nRaw: %.200s", 341 | attempt + 1, 342 | retries + 1, 343 | mode_used, 344 | ve, 345 | block, 346 | raw, 347 | ) 348 | self._reset_token_counts() # Reset counts on error 349 | except Exception as e: 350 | last_error = e 351 | logger.error( 352 | "Error generating JSON async %d/%d (mode: %s): %s", 353 | attempt + 1, 354 | retries + 1, 355 | mode_used, 356 | e, 357 | exc_info=True, 358 | ) 359 | self._reset_token_counts() # Reset counts on error 360 | 361 | if attempt < retries: 362 | sleep_time = 2**attempt 363 | logger.info(f"Retrying async JSON generation in {sleep_time} seconds...") 364 | await asyncio.sleep(sleep_time) 365 | self._reset_token_counts() # Reset before retry 366 | 367 | # If loop finishes without success 368 | raise RuntimeError( 369 | f"generate_json_async failed after {retries + 1} attempts. 
Last error: {type(last_error).__name__}: {last_error}" 370 | ) 371 | -------------------------------------------------------------------------------- /cogitator/model/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/cogitator/model/py.typed -------------------------------------------------------------------------------- /cogitator/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/cogitator/py.typed -------------------------------------------------------------------------------- /cogitator/schemas.py: -------------------------------------------------------------------------------- 1 | """Defines Pydantic models for structured data exchange within Cogitator.""" 2 | 3 | from typing import List, Optional, Union 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class LTMDecomposition(BaseModel): 9 | """Schema for the output of the Least-to-Most decomposition step.""" 10 | 11 | subquestions: List[str] = Field(..., description="List of sequential subquestions") 12 | 13 | 14 | class ThoughtExpansion(BaseModel): 15 | """Schema for the output of a thought expansion step (e.g., in ToT).""" 16 | 17 | thoughts: List[str] = Field(..., description="List of distinct reasoning steps or thoughts") 18 | 19 | 20 | class EvaluationResult(BaseModel): 21 | """Schema for the output of an evaluation step (e.g., in ToT, GoT).""" 22 | 23 | score: int = Field(..., description="Quality score from 1 to 10") 24 | justification: str = Field(..., description="Brief justification for the score") 25 | 26 | 27 | class ExtractedAnswer(BaseModel): 28 | """Schema for the final extracted answer from a reasoning chain.""" 29 | 30 | final_answer: Optional[Union[str, int, float]] = Field( 31 | ..., description="The final extracted answer" 32 | ) 33 | -------------------------------------------------------------------------------- /cogitator/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | """Initializes the strategies subpackage. 2 | 3 | Exports the main CoT strategy classes implemented within this subpackage, making them 4 | available for direct import from `cogitator.strategies`. This includes various CoT and related reasoning frameworks. 
5 | """ 6 | 7 | from .auto_cot import AutoCoT 8 | from .cdw_cot import CDWCoT 9 | from .graph_of_thoughts import GraphOfThoughts 10 | from .least_to_most import LeastToMost 11 | from .sc_cot import SelfConsistency 12 | from .tree_of_thoughts import TreeOfThoughts 13 | 14 | __all__ = [ 15 | "AutoCoT", 16 | "CDWCoT", 17 | "GraphOfThoughts", 18 | "LeastToMost", 19 | "SelfConsistency", 20 | "TreeOfThoughts", 21 | ] 22 | -------------------------------------------------------------------------------- /cogitator/strategies/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/cogitator/strategies/py.typed -------------------------------------------------------------------------------- /cogitator/strategies/sc_cot.py: -------------------------------------------------------------------------------- 1 | """Implements the Self-Consistency Chain-of-Thought (SC-CoT) strategy.""" 2 | 3 | import asyncio 4 | import logging 5 | import re 6 | from collections import Counter 7 | from typing import Any, AsyncIterator, Iterator, List, Literal, Optional 8 | 9 | from ..model import BaseLLM 10 | from ..schemas import ExtractedAnswer 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class SelfConsistency: 16 | """Implements the Self-Consistency Chain-of-Thought (SC-CoT) strategy. 17 | 18 | Self-Consistency improves CoT prompting by generating multiple diverse 19 | reasoning paths (using sampling with temperature > 0) and then selecting 20 | the most consistent answer among the paths via majority voting. 21 | 22 | Reference: 23 | Wang et al. (v4; 2023) "Self-Consistency Improves Chain of Thought Reasoning in Language Models". 24 | https://arxiv.org/abs/2203.11171 25 | """ 26 | 27 | def __init__( 28 | self, 29 | llm: BaseLLM, 30 | n_samples: int = 10, 31 | temperature: float = 0.8, 32 | max_tokens: int = 256, 33 | stop: Optional[List[str]] = None, 34 | internal_extraction_format: Literal["heuristic", "json"] = "heuristic", 35 | answer_extraction_prompt: Optional[str] = None, 36 | seed: Optional[int] = None, 37 | **gen_kwargs: Any, 38 | ) -> None: 39 | """Initializes the SelfConsistency strategy handler. 40 | 41 | Args: 42 | llm: The language model instance. 43 | n_samples: The number of reasoning paths (samples) to generate. 44 | temperature: Sampling temperature for generating diverse paths. Should be > 0. 45 | max_tokens: Maximum tokens for each generated reasoning path. 46 | stop: Optional stop sequences for LLM generation. 47 | internal_extraction_format: Method for extracting the final answer from 48 | each CoT path ('heuristic' or 'json'). 49 | answer_extraction_prompt: Prompt template used only if `internal_extraction_format` 50 | is 'json'. Must include {cot}. Expects JSON output 51 | matching ExtractedAnswer schema. 52 | seed: Base random seed for LLM sampling (each sample may use seed + i). 53 | **gen_kwargs: Additional default keyword arguments for LLM generation calls. 54 | """ 55 | self.llm = llm 56 | self.n_samples = n_samples 57 | self.temperature = temperature 58 | self.max_tokens = max_tokens 59 | self.stop = stop 60 | self.internal_extraction_format = internal_extraction_format 61 | self.seed = seed 62 | self.gen_kwargs = gen_kwargs 63 | 64 | if self.internal_extraction_format == "json": 65 | self.answer_extraction_prompt = ( 66 | answer_extraction_prompt 67 | or "Analyze the following reasoning chain and extract the final numerical or short answer. 
" 68 | "Return the result as a JSON object with a single key 'final_answer' containing the answer as a string.\n\n" 69 | "Reasoning Chain:\n{cot}\n\nJSON Answer:" 70 | ) 71 | else: 72 | self.answer_extraction_prompt = None 73 | 74 | def _extract_answer_heuristic(self, cot: str) -> str: 75 | """Extracts the final answer from a CoT string using heuristics. 76 | 77 | Searches for common patterns like "answer is X", lines starting with "Answer:", 78 | numeric lines, etc., working from the end of the CoT string upwards. 79 | 80 | Args: 81 | cot: The Chain-of-Thought reasoning string. 82 | 83 | Returns: 84 | The extracted answer string, or the last line as a fallback. 85 | """ 86 | lines = cot.strip().splitlines() 87 | for line in reversed(lines): 88 | text = line.strip().rstrip(".") 89 | if "=" in text: 90 | parts = text.split("=", 1) 91 | if len(parts) > 1: 92 | answer = parts[1].strip().lstrip("$").strip() 93 | logger.debug(f"Heuristically extracted answer (equals): '{answer}'") 94 | return answer 95 | m0 = re.search(r"(?i)\bthe answer is\s+(\S+)", text) 96 | if m0: 97 | answer = m0.group(1).lstrip("$").strip() 98 | logger.debug(f"Heuristically extracted answer (the answer is): '{answer}'") 99 | return answer 100 | m1 = re.match(r"(?i)^(?:Answer|Final Answer|Ans)\b[: ]\s*(.+)$", text) 101 | if m1: 102 | answer = m1.group(1).strip() 103 | logger.debug(f"Heuristically extracted answer (Prefix): '{answer}'") 104 | return answer 105 | m2 = re.match(r"^#+\s*([+-]?\d+(?:\.\d+)?)$", text) 106 | if m2: 107 | answer = m2.group(1) 108 | logger.debug(f"Heuristically extracted answer (Header): '{answer}'") 109 | return answer 110 | m3 = re.match(r"^\*{1,2}A[: ]\s*(.+?)\*{0,2}$", text, re.IGNORECASE) 111 | if m3: 112 | answer = m3.group(1).strip() 113 | logger.debug(f"Heuristically extracted answer (Markdown A:): '{answer}'") 114 | return answer 115 | m4 = re.search(r":\s*([+-]?\d+(?:\.\d+)?|[A-Za-z]+)\s*$", text) 116 | if m4: 117 | answer = m4.group(1).strip() 118 | logger.debug(f"Heuristically extracted answer (Colon End): '{answer}'") 119 | return answer 120 | if re.fullmatch(r"\$?[+-]?\d+(?:\.\d+)?", text): 121 | answer = text.lstrip("$") 122 | logger.debug(f"Heuristically extracted answer (Numeric Line): '{answer}'") 123 | return answer 124 | fallback_answer = lines[-1].strip() if lines else "" 125 | logger.debug(f"Heuristically extracted answer (Fallback): '{fallback_answer}'") 126 | return fallback_answer 127 | 128 | def _extract_answer_json(self, cot: str, **kwargs: Any) -> str: 129 | """Extracts the final answer using an LLM call with JSON parsing. 130 | 131 | Uses the `answer_extraction_prompt` and expects a JSON response matching 132 | the `ExtractedAnswer` schema. Falls back to heuristic extraction on failure. 133 | 134 | Args: 135 | cot: The Chain-of-Thought reasoning string. 136 | **kwargs: Additional arguments passed to the LLM `generate_json` call. 137 | 138 | Returns: 139 | The extracted answer string. 
140 | """ 141 | if not self.answer_extraction_prompt: 142 | logger.warning("JSON extraction requested but prompt is not configured.") 143 | return self._extract_answer_heuristic(cot) 144 | 145 | prompt = self.answer_extraction_prompt.format(cot=cot) 146 | logger.debug("Attempting JSON extraction with prompt:\n%s", prompt) 147 | try: 148 | local_kwargs = kwargs.copy() 149 | result = self.llm.generate_json( 150 | prompt, 151 | response_model=ExtractedAnswer, 152 | max_tokens=local_kwargs.pop("max_tokens", self.max_tokens), 153 | seed=local_kwargs.pop("seed", self.seed), 154 | **local_kwargs, 155 | ) 156 | answer = str(result.final_answer).strip() 157 | logger.debug(f"JSON extracted answer: '{answer}'") 158 | return answer 159 | except Exception as e: 160 | logger.error("JSON extraction failed: %s", e, exc_info=True) 161 | logger.warning("JSON extraction failed, falling back to heuristic.") 162 | return self._extract_answer_heuristic(cot) 163 | 164 | async def _extract_answer_json_async(self, cot: str, **kwargs: Any) -> str: 165 | """Asynchronously extracts the final answer using an LLM call with JSON parsing. 166 | 167 | Similar to `_extract_answer_json` but uses async LLM calls. 168 | 169 | Args: 170 | cot: The Chain-of-Thought reasoning string. 171 | **kwargs: Additional arguments passed to the async LLM `generate_json_async` call. 172 | 173 | Returns: 174 | The extracted answer string. 175 | """ 176 | if not self.answer_extraction_prompt: 177 | logger.warning("Async JSON extraction requested but prompt is not configured.") 178 | return self._extract_answer_heuristic(cot) 179 | 180 | prompt = self.answer_extraction_prompt.format(cot=cot) 181 | logger.debug("Attempting async JSON extraction with prompt:\n%s", prompt) 182 | try: 183 | local_kwargs = kwargs.copy() 184 | result = await self.llm.generate_json_async( 185 | prompt, 186 | response_model=ExtractedAnswer, 187 | max_tokens=local_kwargs.pop("max_tokens", self.max_tokens), 188 | seed=local_kwargs.pop("seed", self.seed), 189 | **local_kwargs, 190 | ) 191 | answer = str(result.final_answer).strip() 192 | logger.debug(f"Async JSON extracted answer: '{answer}'") 193 | return answer 194 | except Exception as e: 195 | logger.error("Async JSON extraction failed: %s", e, exc_info=True) 196 | logger.warning("Async JSON extraction failed, falling back to heuristic.") 197 | return self._extract_answer_heuristic(cot) 198 | 199 | def extract_answer(self, cot: str, **kwargs: Any) -> str: 200 | """Extracts the final answer from a CoT string based on the configured method. 201 | 202 | Delegates to either `_extract_answer_heuristic` or `_extract_answer_json`. 203 | 204 | Args: 205 | cot: The Chain-of-Thought reasoning string. 206 | **kwargs: Arguments passed to the underlying extraction method (if JSON). 207 | 208 | Returns: 209 | The extracted answer string. 210 | """ 211 | if self.internal_extraction_format == "json": 212 | return self._extract_answer_json(cot, **kwargs) 213 | return self._extract_answer_heuristic(cot) 214 | 215 | async def extract_answer_async(self, cot: str, **kwargs: Any) -> str: 216 | """Asynchronously extracts the final answer based on the configured method. 217 | 218 | Delegates to `_extract_answer_heuristic` or `_extract_answer_json_async`. 219 | 220 | Args: 221 | cot: The Chain-of-Thought reasoning string. 222 | **kwargs: Arguments passed to the underlying async extraction method (if JSON). 223 | 224 | Returns: 225 | The extracted answer string. 
226 | """ 227 | if self.internal_extraction_format == "json": 228 | return await self._extract_answer_json_async(cot, **kwargs) 229 | return self._extract_answer_heuristic(cot) 230 | 231 | def run(self, prompt: str, **kwargs: Any) -> str: 232 | """Executes the Self-Consistency strategy. 233 | 234 | Generates `n_samples` reasoning paths using the LLM with the specified 235 | temperature. Extracts the final answer from each path and returns the 236 | most frequent answer (majority vote). 237 | 238 | Args: 239 | prompt: The input prompt for the LLM. 240 | **kwargs: Additional arguments passed to the LLM generation and 241 | answer extraction calls. 242 | 243 | Returns: 244 | The most consistent answer string among the generated paths, or an 245 | empty string if no valid answers are generated. 246 | """ 247 | answers: List[str] = [] 248 | combined_kwargs = {**self.gen_kwargs, **kwargs} 249 | 250 | for i in range(self.n_samples): 251 | try: 252 | iter_seed = (self.seed + i) if self.seed is not None else None 253 | current_gen_kwargs = combined_kwargs.copy() 254 | cot = self.llm.generate( 255 | prompt, 256 | temperature=current_gen_kwargs.pop("temperature", self.temperature), 257 | max_tokens=current_gen_kwargs.pop("max_tokens", self.max_tokens), 258 | stop=current_gen_kwargs.pop("stop", self.stop), 259 | seed=iter_seed, 260 | **current_gen_kwargs, 261 | ) 262 | logger.debug(f"Raw CoT sample {i}: {cot}") 263 | ans = self.extract_answer(cot, **kwargs) 264 | if ans: 265 | answers.append(ans) 266 | else: 267 | logger.debug(f"Sample {i} produced empty answer after extraction.") 268 | except Exception as e: 269 | logger.error(f"Error during SC sample {i}: {e}", exc_info=True) 270 | 271 | if not answers: 272 | logger.warning("SelfConsistency generated no valid answers.") 273 | return "" 274 | 275 | try: 276 | count = Counter(answers) 277 | top_answer, _ = count.most_common(1)[0] 278 | logger.debug(f"SelfConsistency vote counts: {count}") 279 | return top_answer 280 | except IndexError: 281 | logger.error("Could not determine most common answer despite having answers.") 282 | return "" 283 | 284 | async def run_async( 285 | self, prompt: str, semaphore: Optional[asyncio.Semaphore] = None, **kwargs: Any 286 | ) -> str: 287 | """Asynchronously executes the Self-Consistency strategy. 288 | 289 | Generates `n_samples` reasoning paths concurrently using async LLM calls. 290 | Extracts answers asynchronously and returns the majority vote answer. 291 | 292 | Args: 293 | prompt: The input prompt for the LLM. 294 | semaphore: Optional asyncio.Semaphore to limit concurrent LLM calls. 295 | **kwargs: Additional arguments passed to the async LLM generation and 296 | answer extraction calls. 297 | 298 | Returns: 299 | The most consistent answer string, or an empty string if none are generated. 
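        Example:
            A minimal sketch, assuming ``llm`` is a configured ``BaseLLM`` instance
            and that this coroutine is awaited inside a running event loop::

                sc = SelfConsistency(llm, n_samples=5, temperature=0.7)
                sem = asyncio.Semaphore(2)  # cap concurrent LLM calls
                answer = await sc.run_async(
                    "Q: A hen lays 2 eggs a day. How many eggs in 3 days? "
                    "Let's think step by step.",
                    semaphore=sem,
                )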
300 | """ 301 | combined_kwargs = {**self.gen_kwargs, **kwargs} 302 | 303 | async def sample(i: int) -> Optional[str]: 304 | sample_kwargs = combined_kwargs.copy() 305 | iter_seed = (self.seed + i) if self.seed is not None else None 306 | gen_args = { 307 | "temperature": sample_kwargs.pop("temperature", self.temperature), 308 | "max_tokens": sample_kwargs.pop("max_tokens", self.max_tokens), 309 | "stop": sample_kwargs.pop("stop", self.stop), 310 | "seed": iter_seed, 311 | **sample_kwargs, 312 | } 313 | extraction_kwargs = kwargs.copy() 314 | 315 | task_semaphore = semaphore 316 | if task_semaphore: 317 | await task_semaphore.acquire() 318 | try: 319 | cot = await self.llm.generate_async(prompt, **gen_args) 320 | logger.debug(f"Raw async CoT sample {i}: {cot}") 321 | ans = await self.extract_answer_async(cot, **extraction_kwargs) 322 | if not ans: 323 | logger.debug(f"Async sample {i} produced empty answer after extraction.") 324 | return ans 325 | except Exception as e: 326 | logger.error(f"Error during async SC sample {i}: {e}", exc_info=True) 327 | return None 328 | finally: 329 | if task_semaphore: 330 | task_semaphore.release() 331 | 332 | results = await asyncio.gather(*(sample(i) for i in range(self.n_samples))) 333 | answers = [a for a in results if a is not None and a != ""] 334 | if not answers: 335 | logger.warning("SelfConsistency (async) generated no valid answers.") 336 | return "" 337 | 338 | try: 339 | count = Counter(answers) 340 | top_answer, _ = count.most_common(1)[0] 341 | logger.debug(f"SelfConsistency async vote counts: {count}") 342 | return top_answer 343 | except IndexError: 344 | logger.error("Could not determine most common async answer despite having answers.") 345 | return "" 346 | 347 | def run_stream(self, prompt: str) -> Iterator[str]: 348 | """Streaming is not supported for Self-Consistency.""" 349 | raise NotImplementedError("Streaming not supported for SelfConsistency.") 350 | 351 | async def run_stream_async(self, prompt: str) -> AsyncIterator[str]: 352 | """Streaming is not supported for Self-Consistency.""" 353 | raise NotImplementedError("Streaming not supported for SelfConsistency.") 354 | -------------------------------------------------------------------------------- /cogitator/utils.py: -------------------------------------------------------------------------------- 1 | """Provides utility functions used across the Cogitator library.""" 2 | 3 | import re 4 | 5 | 6 | def count_steps(cot: str) -> int: 7 | """Counts the number of reasoning steps in a Chain-of-Thought string. 8 | 9 | Identifies steps based on lines starting with digits followed by a period/paren 10 | or lines starting with list markers like -, *, •. 11 | 12 | Args: 13 | cot: The string containing the Chain-of-Thought reasoning. 14 | 15 | Returns: 16 | The integer count of identified reasoning steps. 17 | """ 18 | return sum(1 for line in cot.splitlines() if re.match(r"^(\d+[\.\)]|[-*•])\s+", line.strip())) 19 | 20 | 21 | def approx_token_length(text: str) -> int: 22 | """Approximates the number of tokens in a string. 23 | 24 | Counts sequences of word characters and any non-whitespace, non-word characters 25 | as separate tokens. This provides a rough estimate, not a precise token count 26 | based on a specific tokenizer model. 27 | 28 | Args: 29 | text: The input string. 30 | 31 | Returns: 32 | An approximate integer count of tokens. 
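    Example:
        A quick illustration of the approximation::

            >>> approx_token_length("Hello, world!")
            4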
33 | """ 34 | return len(re.findall(r"\w+|[^\w\s]", text)) 35 | 36 | 37 | def exact_match(pred: str, gold: str) -> bool: 38 | """Performs case-insensitive exact matching between two strings. 39 | 40 | Strips leading/trailing whitespace and converts both strings to lowercase 41 | before comparison. 42 | 43 | Args: 44 | pred: The predicted string. 45 | gold: The ground truth (gold standard) string. 46 | 47 | Returns: 48 | True if the normalized strings are identical, False otherwise. 49 | """ 50 | return pred.strip().lower() == gold.strip().lower() 51 | 52 | 53 | def accuracy(preds: list[str], golds: list[str]) -> float: 54 | """Calculates the exact match accuracy between lists of predictions and golds. 55 | 56 | Uses the `exact_match` function for comparison. Handles potential differences 57 | in list lengths by iterating up to the length of the shorter list if `strict=False` 58 | (default in zip). If `golds` is empty, returns 0.0. 59 | 60 | Args: 61 | preds: A list of predicted strings. 62 | golds: A list of ground truth strings. 63 | 64 | Returns: 65 | The accuracy score as a float between 0.0 and 1.0. 66 | """ 67 | if not golds: 68 | return 0.0 69 | matches = sum(exact_match(p, g) for p, g in zip(preds, golds, strict=False)) 70 | return matches / len(golds) 71 | -------------------------------------------------------------------------------- /docs/api/clustering.md: -------------------------------------------------------------------------------- 1 | # Clustering Module 2 | 3 | This module defines abstractions and implementations for clustering algorithms used by a few of the strategies. 4 | 5 | ::: cogitator.clustering.BaseClusterer 6 | options: 7 | show_root_heading: false 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | 12 | ::: cogitator.clustering.KMeansClusterer 13 | options: 14 | show_root_heading: false 15 | show_source: true 16 | members_order: source 17 | heading_level: 2 18 | -------------------------------------------------------------------------------- /docs/api/embedding.md: -------------------------------------------------------------------------------- 1 | # Text Embedding Module 2 | 3 | This module provides abstractions and implementations for text embedding models. 4 | 5 | ::: cogitator.embedding.BaseEmbedder 6 | options: 7 | show_root_heading: false 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | 12 | ::: cogitator.embedding.SentenceTransformerEmbedder 13 | options: 14 | show_root_heading: false 15 | show_source: true 16 | members_order: source 17 | heading_level: 2 18 | -------------------------------------------------------------------------------- /docs/api/model/base.md: -------------------------------------------------------------------------------- 1 | # LLM Provider Interface 2 | 3 | This module defines the abstract base class for all LLM providers. 4 | 5 | ::: cogitator.model.base.BaseLLM 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/model/ollama.md: -------------------------------------------------------------------------------- 1 | # Ollama LLM Implementation 2 | 3 | Implementation of the BaseLLM interface for Ollama models. 
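A minimal usage sketch is shown below. It assumes an Ollama server is running locally with the named model already pulled; the constructor argument shown here is illustrative, so check the class reference below for the exact parameters.

```python
from cogitator import OllamaLLM

# Assumption: a local Ollama server is running and "gemma3:4b" has been pulled;
# the exact constructor parameters are documented in the class reference below.
llm = OllamaLLM(model="gemma3:4b")
print(llm.generate("What is 2 + 2? Answer briefly."))
```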
4 | 5 | ::: cogitator.model.ollama.OllamaLLM 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/model/openai.md: -------------------------------------------------------------------------------- 1 | # OpenAI LLM Implementation 2 | 3 | Implementation of the BaseLLM interface for OpenAI models. 4 | 5 | ::: cogitator.model.openai.OpenAILLM 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/schemas.md: -------------------------------------------------------------------------------- 1 | # Schemas for Structured Data 2 | 3 | This module defines Pydantic models used for structuring and validating intermediate and final outputs from LLMs. 4 | 5 | ::: cogitator.schemas.LTMDecomposition 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | 12 | ::: cogitator.schemas.ThoughtExpansion 13 | options: 14 | show_root_heading: true 15 | show_source: true 16 | members_order: source 17 | heading_level: 2 18 | 19 | ::: cogitator.schemas.EvaluationResult 20 | options: 21 | show_root_heading: true 22 | show_source: true 23 | members_order: source 24 | heading_level: 2 25 | 26 | ::: cogitator.schemas.ExtractedAnswer 27 | options: 28 | show_root_heading: true 29 | show_source: true 30 | members_order: source 31 | heading_level: 2 32 | -------------------------------------------------------------------------------- /docs/api/strategies/auto_cot.md: -------------------------------------------------------------------------------- 1 | # Automatic Chain-of-Thought 2 | 3 | An implementation of the automatic chain-of-thought (CoT) prompting strategy from [this paper](https://arxiv.org/abs/2210.03493). 4 | 5 | ::: cogitator.strategies.auto_cot.AutoCoT 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/strategies/cdw_cot.md: -------------------------------------------------------------------------------- 1 | # Clustered Distance-Weighted Chain-of-Thought 2 | 3 | An implementation of the clustered distance-weighted CoT framework from [this paper](https://arxiv.org/abs/2501.12226). 4 | 5 | ::: cogitator.strategies.cdw_cot.CDWCoT 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/strategies/graph_of_thoughts.md: -------------------------------------------------------------------------------- 1 | # Graph of Thoughts Framework 2 | 3 | An implementation of the Graph of Thoughts (GoT) reasoning framework from [this paper](https://arxiv.org/abs/2308.09687). 4 | 5 | The implementation represents the reasoning process as a graph where nodes are thoughts and edges represent transformations. 6 | The flow of reasoning is controlled by a **Graph of Operations (GoO)**. 7 | 8 | ## Defining the Graph of Operations (GoO) 9 | 10 | To use `GraphOfThoughts`, you must provide a `graph_of_operations` argument to the `run_async` method. 
11 | This argument is a list of tuples, where each tuple defines an operation step: 12 | 13 | `graph_of_operations: List[Tuple[str, Dict]]` 14 | 15 | * The first element of the tuple is the **name** of the operation (e.g., `'Generate'`, `'Score'`, `'KeepBest'`). 16 | * The second element is a **dictionary** containing the parameters specific to that operation. 17 | 18 | **Example GoO:** 19 | 20 | ```python 21 | from cogitator import ThoughtExpansion 22 | 23 | EXAMPLE_GOO = [ 24 | # Step 1: Generate 3 new thoughts from the initial question (in 'frontier' set) 25 | # Store results in the 'generated_thoughts' set. Use the 'expand' prompt. Expect ThoughtExpansion schema. 26 | ('Generate', {'k': 3, 'target_set': 'frontier', 'output_set': 'generated_thoughts', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}), 27 | 28 | # Step 2: Score the thoughts generated in the previous step. Use 'evaluate' prompt. 29 | ('Score', {'target_set': 'generated_thoughts', 'prompt_key': 'evaluate'}), 30 | 31 | # Step 3: Keep only the single best-scoring thought from the previous step. 32 | # Put the result back into the 'frontier' set for potential further steps or final answer generation. 33 | ('KeepBest', {'N': 1, 'target_set': 'generated_thoughts', 'output_set': 'frontier'}) 34 | ] 35 | ``` 36 | 37 | ## Main Class (`GraphOfThoughts`) 38 | 39 | ::: cogitator.strategies.graph_of_thoughts.GraphOfThoughts 40 | options: 41 | show_root_heading: true 42 | show_source: true 43 | members_order: source 44 | heading_level: 3 45 | 46 | ## Available Operations 47 | 48 | Here are the standard operations available. 49 | You can create custom operations by subclassing `GoTOperation`. 50 | 51 | ### Base Operation Class 52 | 53 | ::: cogitator.strategies.graph_of_thoughts.GoTOperation 54 | options: 55 | show_root_heading: true 56 | show_source: false 57 | members_order: source 58 | heading_level: 3 59 | 60 | ### Generate Operation 61 | 62 | ::: cogitator.strategies.graph_of_thoughts.GenerateOp 63 | options: 64 | show_root_heading: true 65 | show_source: false 66 | members_order: source 67 | heading_level: 3 68 | 69 | ### Score Operation 70 | 71 | ::: cogitator.strategies.graph_of_thoughts.ScoreOp 72 | options: 73 | show_root_heading: true 74 | show_source: false 75 | members_order: source 76 | heading_level: 3 77 | 78 | ### KeepBest Operation 79 | 80 | ::: cogitator.strategies.graph_of_thoughts.KeepBestOp 81 | options: 82 | show_root_heading: true 83 | show_source: false 84 | members_order: source 85 | heading_level: 3 86 | 87 | ### Aggregate Operation 88 | 89 | ::: cogitator.strategies.graph_of_thoughts.AggregateOp 90 | options: 91 | show_root_heading: true 92 | show_source: false 93 | members_order: source 94 | heading_level: 3 95 | 96 | ## Internal State (Advanced) 97 | 98 | These classes manage the internal graph structure. 
99 | 100 | ### GoTNode 101 | 102 | ::: cogitator.strategies.graph_of_thoughts.GoTNode 103 | options: 104 | show_root_heading: true 105 | show_source: false 106 | members: ["__init__"] 107 | heading_level: 3 108 | 109 | ### GraphReasoningState 110 | 111 | ::: cogitator.strategies.graph_of_thoughts.GraphReasoningState 112 | options: 113 | show_root_heading: true 114 | show_source: false 115 | heading_level: 3 116 | -------------------------------------------------------------------------------- /docs/api/strategies/least_to_most.md: -------------------------------------------------------------------------------- 1 | # Least-to-Most Prompting 2 | 3 | An implementation of the least-to-most prompting strategy from [this paper](https://arxiv.org/abs/2205.10625). 4 | 5 | ::: cogitator.strategies.least_to_most.LeastToMost 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/strategies/sc_cot.md: -------------------------------------------------------------------------------- 1 | # Self-Consistency Prompting 2 | 3 | An implementation of the self-consistency prompting strategy from [this paper](https://arxiv.org/abs/2203.11171). 4 | 5 | ::: cogitator.strategies.sc_cot.SelfConsistency 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/strategies/tree_of_thoughts.md: -------------------------------------------------------------------------------- 1 | # Tree of Thoughts 2 | 3 | An implementation of the tree of thoughts CoT framework from [this paper](https://arxiv.org/abs/2305.10601). 4 | 5 | ::: cogitator.strategies.tree_of_thoughts.TreeOfThoughts 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | -------------------------------------------------------------------------------- /docs/api/utils.md: -------------------------------------------------------------------------------- 1 | # Utility Functions 2 | 3 | This module provides various utility functions used throughout the Cogitator library, such as metrics calculation and text processing helpers.
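As a quick illustration, the evaluation helpers can be used on their own (a minimal sketch):

```python
from cogitator import accuracy, approx_token_length, count_steps, exact_match

cot = "1. Add 2 and 2.\n2. The answer is 4."
print(count_steps(cot))                  # 2 numbered reasoning steps
print(approx_token_length(cot))          # rough token estimate, not a tokenizer count
print(exact_match("4", " 4 "))           # True (whitespace- and case-insensitive)
print(accuracy(["4", "7"], ["4", "8"]))  # 0.5
```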
4 | 5 | ::: cogitator.utils.count_steps 6 | options: 7 | show_root_heading: true 8 | show_source: true 9 | members_order: source 10 | heading_level: 2 11 | 12 | ::: cogitator.utils.approx_token_length 13 | options: 14 | show_root_heading: true 15 | show_source: true 16 | members_order: source 17 | heading_level: 2 18 | 19 | ::: cogitator.utils.exact_match 20 | options: 21 | show_root_heading: true 22 | show_source: true 23 | members_order: source 24 | heading_level: 2 25 | 26 | ::: cogitator.utils.accuracy 27 | options: 28 | show_root_heading: true 29 | show_source: true 30 | members_order: source 31 | heading_level: 2 32 | -------------------------------------------------------------------------------- /docs/assets/images/cogitator_v1.dot: -------------------------------------------------------------------------------- 1 | digraph CogitatorWorkflow { 2 | fontname = "Helvetica,Arial,sans-serif" 3 | layout = dot 4 | rankdir = TB // Top-to-Bottom workflow layout 5 | node [ 6 | fontname = "Helvetica,Arial,sans-serif", 7 | shape = box, 8 | style = "filled,rounded", 9 | color = "grey", 10 | fillcolor = "white", 11 | penwidth = 2 12 | ] 13 | edge [ 14 | fontname = "Helvetica,Arial,sans-serif", 15 | color = "black" 16 | ] 17 | 18 | // Cluster: User Inputs 19 | subgraph cluster_input { 20 | label = "Inputs" 21 | style = "dashed" 22 | color = "lightgrey" 23 | question [label = "Question / Prompt", fillcolor = "lightyellow"] 24 | strategy_choice [label = "Strategy Choice\n(e.g., SC, LtM, GoT, AutoCoT)", fillcolor = "lightyellow"] 25 | llm_choice [label = "LLM Choice\n(Provider, Model Name)", fillcolor = "lightyellow"] 26 | training_data [label = "Training Data\n(Optional: Questions, Answers)", fillcolor = "lightyellow", shape = note] 27 | } 28 | 29 | // Cluster: Cogitator Library Core Components 30 | subgraph cluster_core { 31 | label = "Cogitator Library" 32 | style = "dashed" 33 | color = "lightgrey" 34 | strategy [label = "Selected CoT Strategy\n(e.g., AutoCoT instance)", fillcolor = "lightblue"] 35 | llm_interface [label = "LLM Interface\n(BaseLLM: OpenAI/Ollama)", fillcolor = "lightblue"] 36 | schemas [label = "Pydantic Schemas\n(Structured Output Validation)", fillcolor = "lightgrey", shape = component] 37 | embedding [label = "Embedding Model\n(Optional Usage)", fillcolor = "lightblue", shape = component] 38 | clustering [label = "Clustering Algorithm\n(Optional Usage)", fillcolor = "lightblue", shape = component] 39 | extraction [label = "Answer Extraction Logic\n(Heuristic / LLM-based)", fillcolor = "lightblue", shape = component] 40 | } 41 | 42 | // Cluster: External Dependencies / Services 43 | subgraph cluster_external { 44 | label = "External Services / Models" 45 | style = "dashed" 46 | color = "lightgrey" 47 | llm_backend [label = "LLM Backend\n(OpenAI API / Ollama Server)", fillcolor = "lightpink"] 48 | embedding_backend [label = "Embedding Backend\n(e.g., Sentence Transformers Lib)", fillcolor = "lightpink", shape = cylinder] // Representing the underlying model/lib 49 | } 50 | 51 | // Cluster: Final Output 52 | subgraph cluster_output { 53 | label = "Output" 54 | style = "dashed" 55 | color = "lightgrey" 56 | final_answer [label = "Final Answer / Result", fillcolor = "lightgreen"] 57 | } 58 | 59 | // --- Edges Defining the Flow --- 60 | 61 | // Inputs to Initialization 62 | question -> strategy [label = "is main input to"] 63 | strategy_choice -> strategy [label = "determines instance of"] 64 | llm_choice -> llm_interface [label = "configures"] 65 | training_data -> strategy [label 
= "used by some for fit/train\n(e.g., AutoCoT, CDWCoT)", style = dashed] 66 | 67 | // Strategy Orchestration 68 | strategy -> llm_interface [label = "makes calls via"] 69 | strategy -> schemas [label = "uses for JSON modes\n(LtM, ToT, GoT, SC, etc.)", style = dashed] 70 | strategy -> embedding [label = "uses sometimes\n(AutoCoT, CDWCoT, GoT)", style = dashed] 71 | strategy -> clustering [label = "uses sometimes\n(AutoCoT, CDWCoT)", style = dashed] 72 | strategy -> extraction [label = "uses sometimes\n(SC, LtM, GoT)", style = dashed] 73 | 74 | // LLM Interaction 75 | llm_interface -> llm_backend [label = "communicates with"] 76 | llm_backend -> llm_interface [label = "returns generation to"] 77 | llm_interface -> strategy [label = "provides results to"] 78 | 79 | // Embedding Interaction (Optional Path) 80 | embedding -> embedding_backend [label = "wraps / uses"] 81 | embedding_backend -> embedding [label = "provides embeddings"] 82 | 83 | // Extraction Interaction (Optional Path) 84 | extraction -> llm_interface [label = "can call LLM for extraction", style = dotted] 85 | 86 | // Final Output 87 | strategy -> final_answer [label = "produces"] 88 | 89 | // Optional: Ranking hints if needed (often not necessary with TB layout) 90 | // { rank=same; question; strategy_choice; llm_choice; training_data } 91 | // { rank=same; llm_backend; embedding_backend } 92 | } 93 | -------------------------------------------------------------------------------- /docs/assets/images/cogitator_v2.dot: -------------------------------------------------------------------------------- 1 | digraph SimplifiedCogitatorWorkflow { 2 | fontname = "Helvetica,Arial,sans-serif" 3 | layout = dot 4 | rankdir = LR 5 | ranksep = 0.9; 6 | nodesep = 0.7; 7 | splines = true; 8 | compound = true; 9 | 10 | node [ 11 | fontname = "Helvetica,Arial,sans-serif", 12 | shape = box, 13 | style = "filled,rounded", 14 | color = "grey", 15 | fillcolor = "white", 16 | penwidth = 1 17 | ] 18 | edge [ 19 | fontname = "Helvetica,Arial,sans-serif", 20 | color = "black", 21 | fontsize = 8, 22 | labeldistance = 2.0 23 | ] 24 | 25 | subgraph cluster_input { 26 | label = "Inputs" 27 | style = "dashed" 28 | color = "lightgrey" 29 | margin = 18 30 | question [label = "1. Question / Prompt", fillcolor = "oldlace"] 31 | config [label = "2. Configuration\n(Strategy Choice, LLM Choice)", fillcolor = "oldlace"] 32 | } 33 | 34 | subgraph cluster_core { 35 | label = "Cogitator" 36 | style = "dashed" 37 | color = "lightgrey" 38 | margin = 18 39 | strategy [label = < 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
3. Selected CoT Strategy
Orchestrates steps:
- Prompt Formatting
- LLM Calls
- Intermediate Processing
(Decomposition, Expansion,
Evaluation, Extraction,
Embedding, Clustering...)
>, fillcolor ="lightblue", shape = box] 49 | } 50 | 51 | subgraph cluster_external { 52 | label = "LLM Service" 53 | style = "dashed" 54 | color = "lightgrey" 55 | margin = 18 56 | llm [label = "4. Model Provider\n(e.g., OpenAI API / Ollama)", fillcolor ="oldlace"] 57 | } 58 | 59 | subgraph cluster_output { 60 | label = "Output" 61 | style = "dashed" 62 | color = "lightgrey" 63 | margin = 18 64 | final_answer [label = "5. Final Answer", fillcolor = "oldlace"] 65 | } 66 | 67 | question -> strategy [lhead = cluster_core] 68 | config -> strategy [lhead = cluster_core] 69 | config -> llm [lhead= cluster_external] 70 | 71 | strategy -> llm [minlen = 2] 72 | llm -> strategy [minlen = 2] 73 | 74 | strategy -> final_answer [lhead = cluster_output, ltail = cluster_core, minlen = 2] 75 | 76 | } 77 | -------------------------------------------------------------------------------- /docs/assets/images/cogitator_v2.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | SimplifiedCogitatorWorkflow 11 | 12 | 13 | cluster_input 14 | 15 | Inputs 16 | 17 | 18 | cluster_core 19 | 20 | Cogitator 21 | 22 | 23 | cluster_external 24 | 25 | LLM Service 26 | 27 | 28 | cluster_output 29 | 30 | Output 31 | 32 | 33 | 34 | question 35 | 36 | 1. Question / Prompt 37 | 38 | 39 | 40 | strategy 41 | 42 | 3. Selected CoT Strategy 43 | Orchestrates steps: 44 | - Prompt Formatting 45 | - LLM Calls 46 | - Intermediate Processing 47 |  (Decomposition, Expansion, 48 |   Evaluation, Extraction, 49 |   Embedding, Clustering...) 50 | 51 | 52 | 53 | question->strategy 54 | 55 | 56 | 57 | 58 | 59 | config 60 | 61 | 2. Configuration 62 | (Strategy Choice, LLM Choice) 63 | 64 | 65 | 66 | config->strategy 67 | 68 | 69 | 70 | 71 | 72 | llm 73 | 74 | 4. Model Provider 75 | (e.g., OpenAI API / Ollama) 76 | 77 | 78 | 79 | config->llm 80 | 81 | 82 | 83 | 84 | 85 | strategy->llm 86 | 87 | 88 | 89 | 90 | 91 | final_answer 92 | 93 | 5. Final Answer 94 | 95 | 96 | 97 | strategy->final_answer 98 | 99 | 100 | 101 | 102 | 103 | llm->strategy 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /docs/assets/images/make_figures.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # You need to have Graphviz installed to run this script 4 | # On Debian-based OSes, you can install it using: sudo apt-get install graphviz 5 | 6 | # Directory containing .dot files (with default value) 7 | ASSET_DIR=${1:-"."} 8 | 9 | # Make figures from .dot files 10 | for f in "${ASSET_DIR}"/*.dot; do 11 | dot -Tsvg "$f" -o "${f%.dot}.svg" 12 | done 13 | -------------------------------------------------------------------------------- /docs/benchmarking.md: -------------------------------------------------------------------------------- 1 | ## Benchmarking Framework 2 | 3 | Cogitator includes a framework for running benchmarks to compare the performance of different Chain-of-Thought methods 4 | on various datasets. 5 | 6 | ### Overview 7 | 8 | The framework consists of two main scripts: 9 | 10 | 1. **`benches/run.py`**: Handles the generation phase. It loads datasets, configures LLMs and CoT strategies based on 11 | `benches.yml` and CLI arguments, runs the strategies on dataset questions, and saves the raw LLM outputs and timings to 12 | a JSONL file. 13 | 2. **`benches/evaluate.py`**: Handles the evaluation phase. 
It reads the results file generated by `run.py`, extracts 14 | final answers from the raw outputs (using either heuristics or an LLM), compares them to gold answers, and calculates 15 | accuracy metrics for each method. 16 | 17 | ### Configuration 18 | 19 | Benchmarks are configured using the `benches.yml` file located in the project root, which can be overridden by 20 | command-line arguments passed to the scripts. 21 | 22 | ### How to Run 23 | 24 | Detailed instructions on how to configure and run the benchmarks are available in the benchmarking README 25 | ([`benches/README.md`](https://github.com/habedi/cogitator/blob/main/benches/README.md)). 26 | This includes: 27 | 28 | * Command-line options for `run.py` and `evaluate.py`. 29 | * Detailed explanation of the `benches.yml` configuration structure. 30 | * The benchmark workflow. 31 | * Example usage commands. 32 | * Available datasets and dependencies. 33 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | Thank you for considering contributing to Cogitator! 4 | We welcome contributions from the community, whether it is code, documentation, or feedback. 5 | Please refer to the main contribution guidelines document 6 | ([`CONTRIBUTING.md`](https://github.com/habedi/cogitator/blob/main/CONTRIBUTING.md)) 7 | in the repository's root directory for details on how to contribute, report bugs, suggest features, and the development workflow. 8 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Cogitator Documentation 2 | 3 | ## Overview 4 | Cogitator is a Python toolkit for experimenting and working with 5 | [chain-of-thought (CoT) prompting](https://arxiv.org/abs/2201.11903) 6 | methods in large language models (LLMs). 7 | CoT prompting improves LLM performance on complex tasks (like question-answering, reasoning, and problem-solving) 8 | by guiding the models to generate intermediate reasoning steps before arriving at the final answer. 9 | Additionally, it can be used to improve the interpretability of LLMs by providing insight into the model's reasoning process. 10 | The toolkit aims to make it easier to use popular CoT strategies and frameworks for research or to integrate them into AI 11 | applications. 12 | 13 | ### Features 14 | 15 | * Provides a unified sync/async API for CoT strategies 16 | * Supports using OpenAI and Ollama as LLM providers 17 | * Supports structured model output with Pydantic validation 18 | * Includes a customizable benchmarking framework (see [benches](https://github.com/habedi/cogitator/blob/main/benches)) 19 | * Includes implementations of popular CoT strategies and frameworks like 20 | - [Self-Consistency CoT (ICLR 2023)](https://arxiv.org/abs/2203.11171) 21 | - [Automatic CoT (ICLR 2023)](https://arxiv.org/abs/2210.03493) 22 | - [Least-to-Most Prompting (ICLR 2023)](https://arxiv.org/abs/2205.10625) 23 | - [Tree of Thoughts (NeurIPS 2023)](https://arxiv.org/abs/2305.10601) 24 | - [Graph of Thoughts (AAAI 2024)](https://arxiv.org/abs/2308.09687) 25 | - [Clustered Distance-Weighted CoT (AAAI 2025)](https://arxiv.org/abs/2501.12226) 26 | 27 | The diagram below shows a high-level overview of Cogitator's workflow.
28 | 29 | ![Cogitator Architecture](assets/images/cogitator_v2.svg) 30 | 31 | ## Installation 32 | 33 | Cogitator can be installed via pip using the following command: 34 | 35 | ```bash 36 | pip install cogitator 37 | ``` 38 | 39 | To run the unit tests and benchmarks, development dependencies are needed that can be installed with the following command: 40 | 41 | ```bash 42 | git clone https://github.com/habedi/cogitator && cd cogitator 43 | 44 | # Set up Python environment 45 | pip install poetry 46 | poetry install --with dev 47 | 48 | # Run the tests to make sure everything is working (optional) 49 | poetry run pytest 50 | ``` 51 | 52 | ## Examples 53 | 54 | Check the [examples](https://github.com/habedi/cogitator/blob/main/examples) for usage examples on how to use the library with 55 | different LLM providers and CoT strategies. 56 | 57 | ## API Reference 58 | 59 | The Cogitator library's functionality is organized into several modules: 60 | 61 | * **LLM Providers (`cogitator.model`)** 62 | * [`BaseLLM`](api/model/base.md): Base LLM provider class that defines a common interface for all providers. 63 | * [`OpenAILLM`](api/model/openai.md): LLM provider implementation for using OpenAI models (like gpt-4o-mini and gpt-4o). 64 | * [`OllamaLLM`](api/model/ollama.md): LLM provider implementation for using Ollama models (like Llama, Gemma, and Qwen). 65 | 66 | * **CoT Strategies (`cogitator.strategies`)** 67 | * [`AutoCoT`](api/strategies/auto_cot.md): An implementation of the automatic CoT prompting strategy. 68 | * [`CDWCoT`](api/strategies/cdw_cot.md): An implementation of the clustered distance-weighted CoT framework. 69 | * [`GraphOfThoughts`](api/strategies/graph_of_thoughts.md): An implementation of the graph of thoughts CoT framework. 70 | * [`LeastToMost`](api/strategies/least_to_most.md): An implementation of the least-to-most prompting strategy. 71 | * [`SelfConsistency`](api/strategies/sc_cot.md): An implementation of the self-consistency prompting strategy. 72 | * [`TreeOfThoughts`](api/strategies/tree_of_thoughts.md): An implementation of the tree of thoughts CoT framework. 73 | 74 | * **Data Formatting and Validation (`cogitator.schemas`)** 75 | * [`Schemas`](api/schemas.md): A set of Pydantic models that are used for validating structure of outputs from LLMs. 76 | 77 | * **Utilities** 78 | * [`Embedding`](api/embedding.md): A set of tools for embedding prompt text which is used by strategies like AutoCoT and CDWCoT. 79 | * [`Clustering`](api/clustering.md): Includes clustering algorithms for grouping similar embeddings that is used during the training phase of strategies like AutoCoT and CDWCoT. 80 | * [`Functions`](api/utils.md): A set of utility functions for working with the library. 81 | 82 | ## Extra Resources 83 | 84 | * **[Benchmarking](benchmarking.md):** Learn how to configure and run the performance evaluation framework. 85 | * **[Contributing](contributing.md):** Find guidelines for contributing to the Cogitator project. 
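As a quick illustration of how the modules listed above fit together, here is a minimal sketch that pairs the Ollama provider with the Least-to-Most strategy. It assumes a local Ollama server with the `gemma3:4b` model already pulled; the [examples](https://github.com/habedi/cogitator/blob/main/examples) directory contains complete, runnable scripts.

```python
from cogitator import LeastToMost, OllamaLLM

# Initialize an LLM provider (assumes Ollama is running locally and `gemma3:4b` has been pulled)
llm = OllamaLLM(model="gemma3:4b")

# Wrap the provider in a CoT strategy and run a question
ltm = LeastToMost(llm, intermediate_output_format="json")
answer = ltm.run("A box has 3 red balls and 5 blue balls. How many balls in total?")
print(answer)
```

Other provider/strategy combinations follow the same pattern, with `run_async` available for asynchronous use.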
86 | 87 | 88 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## Examples 2 | 3 | | File                                                 | Description                                                              | 4 | |------------------------------------------------------|--------------------------------------------------------------------------| 5 | | [run_simple_example.py](run_simple_example.py)       | A simple end-to-end example of using the Cogitator library               | 6 | | [run_least_to_most.py](run_least_to_most.py)         | Example of using the Least-to-Most prompting strategy                    | 7 | | [run_sc_cot.py](run_sc_cot.py)                       | Example of using the Self-Consistency prompting strategy                 | 8 | | [run_auto_cot.py](run_auto_cot.py)                   | Example of using the Automatic CoT prompting strategy                    | 9 | | [run_tree_of_thoughts.py](run_tree_of_thoughts.py)   | Example of using the Tree of Thoughts prompting framework                | 10 | | [run_graph_of_thoughts.py](run_graph_of_thoughts.py) | Example of using the Graph of Thoughts prompting framework               | 11 | | [run_cdw_cot.py](run_cdw_cot.py)                     | Example of using the Clustered Distance-Weighted CoT prompting framework | 12 | | [shared.py](shared.py)                               | Shared utilities and settings for the examples                           | 13 | 14 | ## Running Examples 15 | 16 | ```bash 17 | # Run the Least-to-Most example (OpenAI) 18 | python examples/run_least_to_most.py --provider openai --model-name gpt-4.1-nano 19 | ``` 20 | 21 | ```bash 22 | # Run the Self-Consistency example (Ollama) 23 | python examples/run_sc_cot.py --provider ollama --model-name gemma3:4b 24 | ``` 25 | 26 | ```bash 27 | # Run all examples (Ollama) 28 | make example-ollama 29 | ``` 30 | 31 | ```bash 32 | # Run all examples (OpenAI) 33 | make example-openai 34 | ``` 35 | 36 | Note that the examples should be run from the root directory of the repository. 37 | Additionally, to use `gemma3:4b` (or any other model like `gemma3:12b`) with Ollama, it must be pulled (or downloaded) first. 38 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/examples/__init__.py -------------------------------------------------------------------------------- /examples/run_auto_cot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import logging 5 | 6 | from cogitator import AutoCoT, BaseLLM 7 | from examples.shared import get_llm, run_main, setup_logging 8 | 9 | setup_logging() 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_auto_cot(llm: BaseLLM) -> AutoCoT: 14 | return AutoCoT( 15 | llm, 16 | n_demos=4, 17 | max_q_tokens=100, 18 | max_steps=8, 19 | max_retries=3, 20 | ) 21 | 22 | 23 | TRAIN_QUESTIONS = [ 24 | "A merchant had 10 apples. He sold 3. How many remain?", 25 | "There are 7 days in a week. How many days in 3 weeks?", 26 | "If you buy 4 pens at $2 each, what's the total cost?", 27 | "A car travels 60 km in 1 hour. How far in 2.5 hours?", 28 | "A rectangle is 3 by 5. What is its area?", 29 | "5 birds are on a wire. 2 fly away. How many left?", 30 | "A baker made 12 buns and packed 3 per box. How many boxes?", 31 | "You read 20 pages per day. How many pages in 5 days?", 32 | ] 33 | 34 | QUESTIONS = [ 35 | "John has 8 oranges and gives 3 away. How many does he have?", 36 | "You run 5 km per day. 
How far in 7 days?", 37 | ] 38 | 39 | 40 | async def main_async(args: argparse.Namespace): 41 | llm = get_llm(args.provider, args.model_name, args.openai_key) 42 | auto = setup_auto_cot(llm) 43 | semaphore = asyncio.Semaphore(5) 44 | 45 | logger.info("Fitting AutoCoT asynchronously...") 46 | await auto.fit_async(TRAIN_QUESTIONS, semaphore=semaphore) 47 | 48 | logger.info("Running test questions asynchronously...") 49 | tasks = [auto.run_async(q) for q in QUESTIONS] 50 | answers = await asyncio.gather(*tasks) 51 | 52 | for q, a in zip(QUESTIONS, answers): 53 | print(f"Q: {q}\nA: {a}\n") 54 | 55 | 56 | def main_sync(args: argparse.Namespace): 57 | llm = get_llm(args.provider, args.model_name, args.openai_key) 58 | auto = setup_auto_cot(llm) 59 | 60 | logger.info("Fitting AutoCoT synchronously...") 61 | auto.fit(TRAIN_QUESTIONS) 62 | 63 | logger.info("Running test questions synchronously...") 64 | for q in QUESTIONS: 65 | result = auto.run(q) 66 | print(f"Q: {q}\nA: {result}\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | run_main(main_sync, main_async, "Run Auto-CoT example") 71 | -------------------------------------------------------------------------------- /examples/run_cdw_cot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import logging 5 | 6 | from cogitator import BaseLLM, CDWCoT 7 | from examples.shared import get_llm, run_main, setup_logging 8 | 9 | setup_logging() 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_cdw_cot(llm: BaseLLM) -> CDWCoT: 14 | return CDWCoT(llm, pool_size=8, n_clusters=2, lr=0.1, sample_size=3) 15 | 16 | 17 | TRAIN_QUESTIONS = [ 18 | "A pot has 5 liters. You add 2 liters. How many liters now?", 19 | "If x+3=7, what is x?", 20 | "There are 4 pens in a box. How many pens in 3 boxes?", 21 | "You walk 2 km in 30 minutes. 
Distance in 1 hour?", 22 | "5 apples + 3 apples = ?", 23 | "Solve 2y = 10", 24 | "Area of 2x4 rectangle?", 25 | "Cost of 3 items at $5 each?", 26 | ] 27 | TRAIN_ANSWERS = ["7", "4", "12", "4", "8", "5", "8", "15"] 28 | 29 | QUESTIONS = ["If you have 3 boxes of 5 pens each, how many pens?", 30 | "Solve for y: y - 2 = 4"] 31 | 32 | 33 | async def main_async(args: argparse.Namespace): 34 | llm = get_llm(args.provider, args.model_name, args.openai_key) 35 | cdw = setup_cdw_cot(llm) 36 | semaphore = asyncio.Semaphore(5) 37 | 38 | logger.info("Initializing CDW-CoT pool asynchronously...") 39 | try: 40 | await cdw.init_pool_async(TRAIN_QUESTIONS, TRAIN_ANSWERS, semaphore=semaphore) 41 | logger.info("Training CDW-CoT asynchronously...") 42 | await cdw.train_async(val_split=0.4, epochs=5, patience=3, semaphore=semaphore) 43 | logger.info("Running test questions asynchronously...") 44 | tasks = [cdw.run_async(q, semaphore=semaphore) for q in QUESTIONS] 45 | answers = await asyncio.gather(*tasks) 46 | for q, a in zip(QUESTIONS, answers): 47 | print(f"Q: {q}\nA: {a}\n") 48 | except Exception as e: 49 | logger.error(f"CDW-CoT async example failed: {e}", exc_info=True) 50 | 51 | 52 | def main_sync(args: argparse.Namespace): 53 | llm = get_llm(args.provider, args.model_name, args.openai_key) 54 | cdw = setup_cdw_cot(llm) 55 | 56 | logger.info("Initializing CDW-CoT pool synchronously...") 57 | try: 58 | cdw.init_pool(TRAIN_QUESTIONS, TRAIN_ANSWERS) 59 | logger.info("Training CDW-CoT synchronously...") 60 | cdw.train(val_split=0.4, epochs=5, patience=3) 61 | logger.info("Running test questions synchronously...") 62 | for q in QUESTIONS: 63 | out = cdw.run(q) 64 | print(f"Q: {q}\nA: {out}\n") 65 | except Exception as e: 66 | logger.error(f"CDW-CoT sync example failed: {e}", exc_info=True) 67 | 68 | 69 | if __name__ == "__main__": 70 | run_main(main_sync, main_async, "Run CDW-CoT example") 71 | -------------------------------------------------------------------------------- /examples/run_graph_of_thoughts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import logging 5 | from typing import List, Dict, Tuple 6 | 7 | from cogitator import BaseLLM, GraphOfThoughts, ThoughtExpansion 8 | from examples.shared import get_llm, run_main, setup_logging 9 | 10 | setup_logging() 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def setup_got(llm: BaseLLM) -> GraphOfThoughts: 15 | return GraphOfThoughts( 16 | llm, 17 | final_answer_format="json", # "text" or "json" 18 | # embedder can be added here if needed by GoO 19 | ) 20 | 21 | 22 | # Define a sample Graph of Operations (GoO) 23 | # This defines the steps GoT will take. Adjust as needed for the problem. 24 | # Example: Generate thoughts, score them, keep the best one. 25 | EXAMPLE_GOO: List[Tuple[str, Dict]] = [ 26 | ('Generate', 27 | {'k': 3, 'target_set': 'frontier', 'output_set': 'generated_thoughts', 'prompt_key': 'expand', 28 | 'response_schema': ThoughtExpansion}), 29 | ('Score', {'target_set': 'generated_thoughts', 'prompt_key': 'evaluate'}), 30 | ('KeepBest', {'N': 1, 'target_set': 'generated_thoughts', 'output_set': 'frontier'}) 31 | # Keep the best node in the frontier for the final answer 32 | ] 33 | 34 | QUESTIONS = [ 35 | "A baker made 2 dozen cookies (24) and sold 8. 
How many left?", 36 | "If 7 times z equals 56, what is z?", 37 | ] 38 | 39 | 40 | async def main_async(args: argparse.Namespace): 41 | llm = get_llm(args.provider, args.model_name, args.openai_key) 42 | got = setup_got(llm) 43 | semaphore = asyncio.Semaphore(5) # Concurrency limit for LLM calls 44 | 45 | logger.info("Running GraphOfThoughts asynchronously...") 46 | tasks = [got.run_async(q, graph_of_operations=EXAMPLE_GOO, semaphore=semaphore) for q in 47 | QUESTIONS] 48 | answers = await asyncio.gather(*tasks) 49 | 50 | for q, a in zip(QUESTIONS, answers): 51 | print(f"Q: {q}\nA: {a}\n") 52 | 53 | 54 | def main_sync(args: argparse.Namespace): 55 | llm = get_llm(args.provider, args.model_name, args.openai_key) 56 | got = setup_got(llm) 57 | 58 | logger.info("Running GraphOfThoughts synchronously...") 59 | for q in QUESTIONS: 60 | try: 61 | a = got.run(q, graph_of_operations=EXAMPLE_GOO) 62 | print(f"Q: {q}\nA: {a}\n") 63 | except NotImplementedError as e: 64 | logger.debug(f"GraphOfThoughts sync run failed correctly: {e}", exc_info=True) 65 | print("GraphOfThoughts run failed correctly." 66 | " The implementation does not support synchronous execution.") 67 | 68 | 69 | if __name__ == "__main__": 70 | run_main(main_sync, main_async, "Run Graph-of-Thoughts example") 71 | -------------------------------------------------------------------------------- /examples/run_least_to_most.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import logging 5 | 6 | from cogitator import LeastToMost 7 | from examples.shared import get_llm, run_main, setup_logging 8 | 9 | setup_logging() 10 | logger = logging.getLogger(__name__) 11 | 12 | QUESTIONS = [ 13 | "A box has 3 red balls and 5 blue balls. 
How many balls in total?", 14 | "If x plus 7 equals 12, what is x?", 15 | ] 16 | 17 | 18 | async def main_async(args: argparse.Namespace): 19 | llm = get_llm(args.provider, args.model_name, args.openai_key) 20 | ltm = LeastToMost(llm, intermediate_output_format="json") 21 | semaphore = asyncio.Semaphore(5) 22 | 23 | logger.info("Running LeastToMost asynchronously...") 24 | tasks = [ltm.run_async(q, semaphore=semaphore) for q in QUESTIONS] 25 | answers = await asyncio.gather(*tasks) 26 | 27 | for q, a in zip(QUESTIONS, answers): 28 | print(f"Q: {q}\nA: {a}\n") 29 | 30 | 31 | def main_sync(args: argparse.Namespace): 32 | llm = get_llm(args.provider, args.model_name, args.openai_key) 33 | ltm = LeastToMost(llm, intermediate_output_format="json") 34 | 35 | logger.info("Running LeastToMost synchronously...") 36 | for q in QUESTIONS: 37 | a = ltm.run(q) 38 | print(f"Q: {q}\nA: {a}\n") 39 | 40 | 41 | if __name__ == "__main__": 42 | run_main(main_sync, main_async, "Run Least-to-Most example") 43 | -------------------------------------------------------------------------------- /examples/run_sc_cot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import logging 5 | 6 | from cogitator import BaseLLM, SelfConsistency 7 | from examples.shared import get_llm, run_main, setup_logging 8 | 9 | setup_logging() 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_sc(llm: BaseLLM) -> SelfConsistency: 14 | return SelfConsistency( 15 | llm, 16 | n_samples=10, 17 | temperature=0.8, # More randomness leads to more diverse answers 18 | max_tokens=200, 19 | internal_extraction_format="json", 20 | # "heuristic" or "json" (JSON option is more robust but might not be supported by all LLMs) 21 | ) 22 | 23 | 24 | QUESTIONS = [ 25 | "Q: A farmer had 17 sheep. All but 9 died. How many sheep are left?\nA: Let's think step by step.", 26 | "Q: If a train travels 60 miles in 1 hour, how far can it travel in 2.5 hours?\nA: Let's break it down." 
27 | ] 28 | 29 | 30 | async def main_async(args: argparse.Namespace): 31 | llm = get_llm(args.provider, args.model_name, args.openai_key) 32 | sc = setup_sc(llm) 33 | semaphore = asyncio.Semaphore(5) 34 | 35 | logger.info("Running SelfConsistency concurrently for multiple questions...") 36 | 37 | async def run_single_question(prompt: str): 38 | logger.info(f"Processing async: {prompt[:50]}...") 39 | answer = await sc.run_async(prompt, semaphore=semaphore) 40 | print(f"\nPrompt: {prompt}") 41 | print(f"Final Answer (async self-consistency): {answer}") 42 | return answer 43 | 44 | tasks = [run_single_question(q) for q in QUESTIONS] 45 | await asyncio.gather(*tasks) 46 | logger.info("Async processing complete.") 47 | 48 | 49 | def main_sync(args: argparse.Namespace): 50 | llm = get_llm(args.provider, args.model_name, args.openai_key) 51 | sc = setup_sc(llm) 52 | 53 | logger.info("Running SelfConsistency sequentially for multiple questions...") 54 | for i, prompt in enumerate(QUESTIONS): 55 | logger.info(f"Processing sync Q{i + 1}: {prompt[:50]}...") 56 | answer = sc.run(prompt) 57 | print(f"\nPrompt: {prompt}") 58 | print(f"Final Answer (sync self-consistency): {answer}") 59 | logger.info("Sync processing complete.") 60 | 61 | 62 | if __name__ == "__main__": 63 | run_main(main_sync, main_async, "Run Self-Consistency example") 64 | -------------------------------------------------------------------------------- /examples/run_simple_example.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from cogitator import SelfConsistency, OllamaLLM 4 | 5 | # Step 1: Configure logging (optional, but helpful) 6 | logging.basicConfig(level=logging.INFO) 7 | logging.getLogger("httpx").setLevel(logging.WARNING) # Suppress HTTPX logs 8 | 9 | # Step 2: Initialize the LLM (using Ollama) 10 | # Needs Ollama running locally with the model pulled (e.g., `ollama pull gemma3:4b`) 11 | try: 12 | llm = OllamaLLM(model="gemma3:4b") 13 | except Exception as e: 14 | print(f"Error initializing Ollama LLM: {e}") 15 | print("Please make sure Ollama is running and the model is pulled.") 16 | exit(1) 17 | 18 | # Step 3: Choose a CoT strategies (Self-Consistency in this case) 19 | # Self-Consistency generates multiple reasoning paths and finds the most common answer 20 | sc_strategy = SelfConsistency( 21 | llm, 22 | n_samples=5, # Number of reasoning paths to generate 23 | temperature=0.7 # Higher temperature can lead to more diverse answers 24 | ) 25 | 26 | # Step 4: Define the prompt (with a basic CoT trigger) 27 | question = "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?" 28 | prompt = f"Q: {question}\nA: Let's think step by step." 
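# The "Let's think step by step." suffix is the basic CoT trigger mentioned in Step 4; from this prompt,
# SelfConsistency samples several reasoning paths (n_samples=5 above) and keeps the most common answer.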
29 | 30 | # Step 5: Run the CoT prompting sc_strategy 31 | print(f"\nQuestion: {question}") 32 | print("Running Self-Consistency CoT...") 33 | final_answer = sc_strategy.run(prompt) # Returns the most consistent (repeated) answer 34 | 35 | # Expected output: $0.05 or 0.05 (may vary slightly based on model and temperature) 36 | print(f"\nCogitator's Answer (Self-Consistency): {final_answer}") 37 | -------------------------------------------------------------------------------- /examples/run_tree_of_thoughts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import logging 5 | 6 | from cogitator import BaseLLM, TreeOfThoughts 7 | from examples.shared import get_llm, run_main, setup_logging 8 | 9 | setup_logging() 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_tot(llm: BaseLLM) -> TreeOfThoughts: 14 | return TreeOfThoughts(llm, max_depth=2, num_branches=2, sims=4, c_puct=1.0) 15 | 16 | 17 | QUESTIONS = [ 18 | "A garden has 4 rows of 6 plants each. How many plants total?", 19 | "If y minus 5 equals 10, what is y?", 20 | ] 21 | 22 | 23 | async def main_async(args: argparse.Namespace): 24 | llm = get_llm(args.provider, args.model_name, args.openai_key) 25 | tot = setup_tot(llm) 26 | semaphore = asyncio.Semaphore(5) 27 | 28 | logger.info("Running TreeOfThoughts asynchronously...") 29 | tasks = [tot.run_async(q, semaphore=semaphore) for q in QUESTIONS] 30 | answers = await asyncio.gather(*tasks) 31 | 32 | for q, a in zip(QUESTIONS, answers): 33 | print(f"Q: {q}\nA: {a}\n") 34 | 35 | 36 | def main_sync(args: argparse.Namespace): 37 | llm = get_llm(args.provider, args.model_name, args.openai_key) 38 | tot = setup_tot(llm) 39 | 40 | logger.info("Running TreeOfThoughts synchronously...") 41 | for q in QUESTIONS: 42 | a = tot.run(q) 43 | print(f"Q: {q}\nA: {a}\n") 44 | 45 | 46 | if __name__ == "__main__": 47 | run_main(main_sync, main_async, "Run Tree-of-Thoughts example") 48 | -------------------------------------------------------------------------------- /examples/shared.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import logging 4 | import os 5 | from typing import Any, Callable, Coroutine, Optional 6 | 7 | from cogitator import BaseLLM, OllamaLLM, OpenAILLM 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def setup_logging() -> None: 13 | logging.basicConfig( 14 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" 15 | ) 16 | logging.getLogger("httpx").setLevel(logging.WARNING) 17 | logging.getLogger("sentence_transformers.SentenceTransformer").setLevel(logging.WARNING) 18 | 19 | 20 | def get_llm(provider: str, model_name: str, openai_key: Optional[str] = None) -> BaseLLM: 21 | logger.info(f"Initializing LLM for examples: provider={provider}, model={model_name}") 22 | if provider == "openai": 23 | key = openai_key or os.getenv("OPENAI_API_KEY") 24 | if not key: 25 | raise ValueError( 26 | "OpenAI API key must be provided via --openai-key or " 27 | "OPENAI_API_KEY environment variable." 
28 | ) 29 | return OpenAILLM(api_key=key, model=model_name) 30 | elif provider == "ollama": 31 | return OllamaLLM(model=model_name) 32 | else: 33 | raise ValueError(f"Unsupported provider: {provider}") 34 | 35 | 36 | def parse_common_args(description: str) -> argparse.Namespace: 37 | parser = argparse.ArgumentParser(description=description) 38 | parser.add_argument( 39 | "--provider", 40 | choices=["openai", "ollama"], 41 | default="ollama", 42 | help="LLM provider to use (default: ollama)", 43 | ) 44 | parser.add_argument( 45 | "--model-name", 46 | default=None, 47 | help="Name of the model (default: 'gemma3:4b' for ollama, 'gpt-4.1-nano' for openai)", 48 | ) 49 | parser.add_argument( 50 | "--openai-key", 51 | default=None, 52 | help="OpenAI API key (reads OPENAI_API_KEY env var if not set)", 53 | ) 54 | parser.add_argument( 55 | "--use-async", action="store_true", help="Run the asynchronous version of the example" 56 | ) 57 | 58 | args = parser.parse_args() 59 | 60 | if not args.model_name: 61 | args.model_name = "gpt-4.1-nano" if args.provider == "openai" else "gemma3:4b" 62 | logger.info( 63 | f"Model name not specified, using default for {args.provider}: {args.model_name}" 64 | ) 65 | 66 | return args 67 | 68 | 69 | def run_main( 70 | main_sync_func: Callable[[argparse.Namespace], None], 71 | main_async_func: Callable[[argparse.Namespace], Coroutine[Any, Any, None]], 72 | description: str, 73 | ) -> None: 74 | args = parse_common_args(description) 75 | 76 | if args.use_async: 77 | asyncio.run(main_async_func(args)) 78 | else: 79 | main_sync_func(args) 80 | -------------------------------------------------------------------------------- /logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 17 | 19 | 38 | 46 | 54 | 62 | 71 | 79 | 87 | 95 | 100 | 101 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Cogitator Documentation 2 | site_description: Documentation for the Cogitator toolkit. 
3 | repo_url: https://github.com/habedi/cogitator 4 | repo_name: habedi/cogitator 5 | 6 | theme: 7 | name: material 8 | palette: 9 | - media: "(prefers-color-scheme: light)" 10 | scheme: default 11 | toggle: 12 | icon: material/brightness-7 13 | name: Switch to dark mode 14 | - media: "(prefers-color-scheme: dark)" 15 | scheme: slate 16 | toggle: 17 | icon: material/brightness-4 18 | name: Switch to light mode 19 | features: 20 | - content.code.copy 21 | - navigation.tabs 22 | - navigation.top 23 | - navigation.indexes 24 | - navigation.expand 25 | - content.code.select 26 | - content.code.annotate 27 | 28 | plugins: 29 | - search 30 | - mkdocstrings: 31 | handlers: 32 | python: 33 | options: 34 | show_root_heading: true 35 | show_source: true 36 | nav: 37 | - Home: index.md 38 | - API Reference: 39 | - LLM Providers: 40 | - Base: api/model/base.md 41 | - OpenAI: api/model/openai.md 42 | - Ollama: api/model/ollama.md 43 | - Schemas: api/schemas.md 44 | - Embedding: api/embedding.md 45 | - Clustering: api/clustering.md 46 | - Utilities: api/utils.md 47 | - Strategies: 48 | - AutoCoT: api/strategies/auto_cot.md 49 | - CDWCoT: api/strategies/cdw_cot.md 50 | - GraphOfThoughts: api/strategies/graph_of_thoughts.md 51 | - LeastToMost: api/strategies/least_to_most.md 52 | - SelfConsistency: api/strategies/sc_cot.md 53 | - TreeOfThoughts: api/strategies/tree_of_thoughts.md 54 | - Benchmarking: benchmarking.md 55 | - Contributing: contributing.md 56 | 57 | markdown_extensions: 58 | - pymdownx.highlight: 59 | anchor_linenums: true 60 | - pymdownx.inlinehilite 61 | - pymdownx.snippets 62 | - pymdownx.superfences 63 | - admonition 64 | - toc: 65 | permalink: true 66 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "cogitator" 3 | version = "0.1.0b1" 4 | description = "A Python toolkit for chain-of-thought prompting" 5 | authors = ["Hassan Abedi "] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [{ include = "cogitator" }] 9 | repository = "https://github.com/habedi/cogitator" 10 | documentation = "https://habedi.github.io/cogitator" 11 | classifiers = [ 12 | "Development Status :: 4 - Beta", 13 | "Intended Audience :: Developers", 14 | "License :: OSI Approved :: MIT License", 15 | "Programming Language :: Python :: 3", 16 | "Programming Language :: Python :: 3.10", 17 | "Programming Language :: Python :: 3.11", 18 | "Programming Language :: Python :: 3.12", 19 | "Programming Language :: Python :: 3.13", 20 | "Topic :: Software Development :: Libraries", 21 | "Topic :: Software Development :: Libraries :: Python Modules", 22 | ] 23 | keywords = ["prompt engineering", "chain-of-thought", "artificial intelligence", "machine learning"] 24 | 25 | [tool.poetry.dependencies] 26 | python = "^3.10" 27 | scikit-learn = "^1.6.1" 28 | openai = "^1.76.2" 29 | ollama = "^0.4.8" 30 | pydantic = "^2.11.4" 31 | sentence-transformers = "^4.1.0" 32 | tiktoken = "^0.9.0" 33 | 34 | [tool.poetry.group.dev.dependencies] 35 | pytest = "^8.0.1" 36 | pytest-cov = "^6.0.0" 37 | pytest-mock = "^3.14.0" 38 | pytest-asyncio = "^0.26.0" 39 | mypy = "^1.11.1" 40 | ruff = "^0.11.7" 41 | griffe = "^1.7.3" 42 | polars = "^1.28.1" 43 | datasets = 
"^3.5.1" 44 | pyyaml = "^6.0.2" 45 | pre-commit = "^4.2.0" 46 | mkdocs = "^1.6.1" 47 | mkdocs-material = "^9.6.12" 48 | mkdocstrings-python = "^1.16.10" 49 | 50 | [build-system] 51 | requires = ["poetry-core"] 52 | build-backend = "poetry.core.masonry.api" 53 | 54 | [tool.pytest.ini_options] 55 | pythonpath = ["cogitator"] 56 | testpaths = ["tests"] 57 | addopts = [ 58 | "--tb=short", 59 | #"--disable-warnings", 60 | "--cov=cogitator", 61 | "--cov-branch", 62 | "--cov-report=term", 63 | "--cov-report=xml", 64 | "-rs", 65 | ] 66 | asyncio_mode = "auto" 67 | asyncio_default_fixture_loop_scope = "function" 68 | asyncio_default_test_loop_scope = "function" 69 | 70 | [tool.coverage.run] 71 | branch = true 72 | parallel = true 73 | source = ["cogitator"] 74 | omit = [ 75 | "tests/*", 76 | "benches/*", 77 | "examples/*", 78 | ] 79 | 80 | [tool.coverage.report] 81 | show_missing = false 82 | skip_empty = true 83 | precision = 2 84 | 85 | [tool.mypy] 86 | python_version = "3.10" 87 | ignore_missing_imports = true 88 | disallow_untyped_defs = true 89 | disallow_untyped_calls = true 90 | disallow_incomplete_defs = true 91 | check_untyped_defs = true 92 | warn_return_any = true 93 | strict_optional = true 94 | warn_redundant_casts = true 95 | exclude = "^(benches/|examples/|tests/)" 96 | 97 | [tool.ruff] 98 | exclude = [ 99 | ".bzr", 100 | ".direnv", 101 | ".eggs", 102 | ".git", 103 | ".git-rewrite", 104 | ".hg", 105 | ".mypy_cache", 106 | ".nox", 107 | ".pants.d", 108 | ".pytype", 109 | ".ruff_cache", 110 | ".svn", 111 | ".tox", 112 | ".venv", 113 | "__pypackages__", 114 | "_build", 115 | "buck-out", 116 | "build", 117 | "dist", 118 | "node_modules", 119 | "venv", 120 | # Additional directories to exclude 121 | "tests", 122 | "benches", 123 | "examples", 124 | ] 125 | line-length = 100 126 | indent-width = 4 127 | src = ["cogitator", "examples"] 128 | target-version = "py310" 129 | unsafe-fixes = true 130 | 131 | [tool.ruff.lint] 132 | select = ["ANN", "E", "F", "I", "W", "B", "RUF", "SIM", "C90"] 133 | ignore = [ 134 | # Ignore docstring errors 135 | "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", 136 | ] 137 | fixable = ["ALL"] 138 | unfixable = [] 139 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 140 | 141 | [tool.ruff.format] 142 | quote-style = "double" 143 | indent-style = "space" 144 | skip-magic-trailing-comma = false 145 | line-ending = "auto" 146 | 147 | [tool.ruff.lint.pydocstyle] 148 | convention = "google" 149 | 150 | [tool.ruff.lint.per-file-ignores] 151 | "tests/**/*.py" = [] 152 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple, Type 5 | 6 | import numpy as np 7 | import pytest 8 | from pydantic import BaseModel 9 | 10 | from cogitator import BaseClusterer 11 | from cogitator import BaseEmbedder 12 | from cogitator import BaseLLM 13 | from cogitator import ( 14 | EvaluationResult, 15 | ExtractedAnswer, 16 | LTMDecomposition, 17 | ThoughtExpansion, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | DEFAULT_SYNC_RESPONSE = "SYNC_RESPONSE" 23 | DEFAULT_ASYNC_RESPONSE = "ASYNC_RESPONSE" 24 | DEFAULT_JSON_RESPONSE = ExtractedAnswer(final_answer="json_sync_default") 25 | DEFAULT_ASYNC_JSON_RESPONSE = ExtractedAnswer(final_answer="json_async_default") 26 | 
DEFAULT_FINAL_ANSWER = "FINAL_ANSWER_DEFAULT" 27 | DEFAULT_SUBANSWER = "SUBANSWER_DEFAULT" 28 | DEFAULT_JSON_STEPS = ThoughtExpansion(thoughts=["step1_default_sync"]) 29 | DEFAULT_ASYNC_JSON_STEPS = ThoughtExpansion(thoughts=["step1_default_async"]) 30 | DEFAULT_JSON_SUBQUESTIONS = LTMDecomposition(subquestions=["subq1_default_sync"]) 31 | DEFAULT_ASYNC_JSON_SUBQUESTIONS = LTMDecomposition(subquestions=["subq1_default_async"]) 32 | DEFAULT_JSON_EVAL = EvaluationResult(score=7, justification="Default Eval Sync") 33 | DEFAULT_ASYNC_JSON_EVAL = EvaluationResult(score=8, justification="Default Eval Async") 34 | 35 | 36 | class ConfigurableFakeLLM(BaseLLM): 37 | 38 | def __init__(self, config: Optional[Dict[str, Any]] = None): 39 | self._config = config if config is not None else {} 40 | self.sync_calls: List[Dict[str, Any]] = [] 41 | self.async_calls: List[Dict[str, Any]] = [] 42 | self.responses_map: Dict[str, Any] = self._config.get("responses_map", {}) 43 | self._call_counts = { 44 | "generate": 0, 45 | "generate_async": 0, 46 | "_generate_json_internal": 0, 47 | "_generate_json_internal_async": 0, 48 | "stream": 0, 49 | "async_stream": 0 50 | } 51 | 52 | def _get_next_response(self, key: str, config_lookup_key: str, default: Any) -> Any: 53 | if key not in self._call_counts: 54 | raise KeyError(f"Internal error: key '{key}' not initialized in _call_counts.") 55 | 56 | response_config = self._config.get(config_lookup_key) 57 | 58 | if isinstance(response_config, list): 59 | if not response_config: 60 | self._call_counts[key] += 1 61 | return default 62 | idx = self._call_counts[key] % len(response_config) 63 | self._call_counts[key] += 1 64 | return response_config[idx] 65 | elif response_config is not None: 66 | self._call_counts[key] += 1 67 | return response_config 68 | else: 69 | self._call_counts[key] += 1 70 | return default 71 | 72 | def _get_response_for_prompt(self, prompt: str, method_type: str) -> Any: 73 | 74 | longest_match_key = None 75 | for key_fragment in sorted(self.responses_map.keys(), key=len, reverse=True): 76 | if key_fragment in prompt: 77 | longest_match_key = key_fragment 78 | break 79 | 80 | if longest_match_key is not None: 81 | return self.responses_map[longest_match_key] 82 | 83 | is_json_method = "json" in method_type 84 | if "JSON Output:" in prompt and "thoughts" in prompt: 85 | key = "json_steps" 86 | default = DEFAULT_ASYNC_JSON_STEPS if 'async' in method_type else DEFAULT_JSON_STEPS 87 | elif "JSON list of strings" in prompt: 88 | key = "generate_async" if 'async' in method_type else "generate_sync" 89 | default = json.dumps( 90 | DEFAULT_ASYNC_JSON_STEPS.model_dump()) if 'async' in method_type else json.dumps( 91 | DEFAULT_JSON_STEPS.model_dump()) 92 | elif "JSON Output:" in prompt and "subquestions" in prompt: 93 | key = "json_subquestions" 94 | default = DEFAULT_ASYNC_JSON_SUBQUESTIONS if 'async' in method_type else DEFAULT_JSON_SUBQUESTIONS 95 | elif "JSON Evaluation:" in prompt: 96 | key = "json_eval" 97 | default = DEFAULT_ASYNC_JSON_EVAL if 'async' in method_type else DEFAULT_JSON_EVAL 98 | elif "JSON Answer:" in prompt: 99 | key = "json_answer" 100 | 101 | return self._get_next_response(method_type, key, 102 | DEFAULT_ASYNC_JSON_RESPONSE if 'async' in method_type else DEFAULT_JSON_RESPONSE) 103 | elif "Current Subquestion:" in prompt: 104 | key = "sub_answer" 105 | default = DEFAULT_SUBANSWER + ("_async" if 'async' in method_type else "") 106 | elif "Given reasoning steps" in prompt \ 107 | or prompt.startswith("Answer the question:") \ 108 | or 
prompt.startswith( 109 | "Based on the following sequential subquestions"): 110 | key = "final_answer" 111 | default = DEFAULT_FINAL_ANSWER + ("_async" if 'async' in method_type else "") 112 | else: 113 | if method_type == "generate": 114 | key, default = "generate_sync", DEFAULT_SYNC_RESPONSE 115 | elif method_type == "generate_async": 116 | key, default = "generate_async", DEFAULT_ASYNC_RESPONSE 117 | elif method_type == "_generate_json_internal": 118 | key, default = "generate_json", DEFAULT_JSON_RESPONSE 119 | elif method_type == "_generate_json_internal_async": 120 | key, default = "generate_json_async", DEFAULT_ASYNC_JSON_RESPONSE 121 | elif method_type == "stream": 122 | key, default = "generate_sync", DEFAULT_SYNC_RESPONSE 123 | elif method_type == "async_stream": 124 | key, default = "generate_async", DEFAULT_ASYNC_RESPONSE 125 | else: 126 | key, default = "unhandled", "UNHANDLED_FAKE_RESPONSE" 127 | 128 | return self._get_next_response(method_type, key, default) 129 | 130 | def generate(self, prompt: str, **kwargs: Any) -> str: 131 | self.sync_calls.append({"type": "generate", "prompt": prompt, "kwargs": kwargs}) 132 | response = self._get_response_for_prompt(prompt, "generate") 133 | 134 | if not isinstance(response, str): 135 | try: 136 | 137 | if isinstance(response, BaseModel): return response.model_dump_json() 138 | if isinstance(response, (dict, list)): return json.dumps(response) 139 | except Exception: 140 | pass 141 | return str(response) 142 | return response 143 | 144 | async def generate_async(self, prompt: str, **kwargs: Any) -> str: 145 | self.async_calls.append({"type": "generate_async", "prompt": prompt, "kwargs": kwargs}) 146 | await asyncio.sleep(0.001) 147 | response = self._get_response_for_prompt(prompt, "generate_async") 148 | 149 | if not isinstance(response, str): 150 | try: 151 | if isinstance(response, BaseModel): return response.model_dump_json() 152 | if isinstance(response, (dict, list)): return json.dumps(response) 153 | except Exception: 154 | pass 155 | return str(response) 156 | return response 157 | 158 | def generate_stream(self, prompt: str, **kwargs: Any) -> Iterator[str]: 159 | self.sync_calls.append({"type": "stream", "prompt": prompt, "kwargs": kwargs}) 160 | response = self._get_response_for_prompt(prompt, "stream") 161 | yield str(response) + "_stream" 162 | 163 | async def generate_stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]: 164 | self.async_calls.append({"type": "async_stream", "prompt": prompt, "kwargs": kwargs}) 165 | await asyncio.sleep(0.001) 166 | response = self._get_response_for_prompt(prompt, "async_stream") 167 | yield str(response) + "_async_stream" 168 | 169 | def _generate_json_internal(self, prompt: str, response_model: Type[BaseModel], 170 | **kwargs: Any) -> Tuple[str, Optional[str]]: 171 | self.sync_calls.append({ 172 | "type": "_generate_json_internal", 173 | "prompt": prompt, 174 | "response_model": response_model.__name__, 175 | "kwargs": kwargs 176 | }) 177 | response_obj = self._get_response_for_prompt(prompt, "_generate_json_internal") 178 | json_string = "" 179 | if isinstance(response_obj, BaseModel): 180 | json_string = response_obj.model_dump_json() 181 | elif isinstance(response_obj, dict): 182 | json_string = json.dumps(response_obj) 183 | elif isinstance(response_obj, str): 184 | try: 185 | json.loads(response_obj) 186 | json_string = response_obj 187 | except json.JSONDecodeError: 188 | logger.warning( 189 | f"Mock Configured string response for JSON prompt is not valid JSON: 
{response_obj}") 190 | json_string = "{}" 191 | else: 192 | try: 193 | json_string = json.dumps(response_obj) 194 | except TypeError: 195 | logger.warning( 196 | f"Mock cannot dump configured response to JSON: {type(response_obj)}") 197 | json_string = "{}" 198 | 199 | mode_used = "mock_json_mode" 200 | return json_string, mode_used 201 | 202 | async def _generate_json_internal_async(self, prompt: str, response_model: Type[BaseModel], 203 | **kwargs: Any) -> Tuple[ 204 | str, Optional[str]]: 205 | self.async_calls.append({ 206 | "type": "_generate_json_internal_async", 207 | "prompt": prompt, 208 | "response_model": response_model.__name__, 209 | "kwargs": kwargs 210 | }) 211 | await asyncio.sleep(0.001) 212 | response_obj = self._get_response_for_prompt(prompt, "_generate_json_internal_async") 213 | json_string = "" 214 | if isinstance(response_obj, BaseModel): 215 | json_string = response_obj.model_dump_json() 216 | elif isinstance(response_obj, dict): 217 | json_string = json.dumps(response_obj) 218 | elif isinstance(response_obj, str): 219 | try: 220 | json.loads(response_obj) 221 | json_string = response_obj 222 | except json.JSONDecodeError: 223 | logger.warning( 224 | f"Mock Configured string response for async JSON prompt is not valid JSON: {response_obj}") 225 | json_string = "{}" 226 | else: 227 | try: 228 | json_string = json.dumps(response_obj) 229 | except TypeError: 230 | logger.warning( 231 | f"Mock cannot dump configured async response to JSON: {type(response_obj)}") 232 | json_string = "{}" 233 | 234 | mode_used = "mock_json_mode_async" 235 | return json_string, mode_used 236 | 237 | 238 | @pytest.fixture 239 | def fake_llm_factory() -> Callable[[Optional[Dict[str, Any]]], ConfigurableFakeLLM]: 240 | def _create_llm(config: Optional[Dict[str, Any]] = None) -> ConfigurableFakeLLM: 241 | return ConfigurableFakeLLM(config) 242 | 243 | return _create_llm 244 | 245 | 246 | class MockEmbedder(BaseEmbedder): 247 | def encode(self, texts: List[str]) -> List[np.ndarray]: 248 | logger.debug(f"Mock encoding texts: {texts}") 249 | return [np.array([float(i), float(i + 1)], dtype=float) for i in range(len(texts))] 250 | 251 | 252 | class MockClusterer(BaseClusterer): 253 | def cluster(self, embeddings: np.ndarray, n_clusters: int, **kwargs) -> Tuple[ 254 | np.ndarray, np.ndarray]: 255 | logger.debug( 256 | f"Mock clustering embeddings (shape {embeddings.shape}) into {n_clusters} clusters") 257 | if embeddings.shape[0] == 0 or n_clusters <= 0: 258 | output_dim = embeddings.shape[1] if len( 259 | embeddings.shape) > 1 and embeddings.shape[1] > 0 else 1 260 | labels = np.array([], dtype=int) 261 | centers = np.array([], dtype=float).reshape(0, output_dim) 262 | else: 263 | output_dim = embeddings.shape[1] 264 | n_clusters = min(n_clusters, embeddings.shape[0]) 265 | labels = (embeddings[:, 0] % n_clusters).astype(int) 266 | centers = np.array( 267 | [embeddings[labels == i].mean(axis=0) if np.any(labels == i) else np.zeros( 268 | output_dim) for i in range(n_clusters)]) 269 | if centers.ndim == 1 and output_dim > 0: 270 | centers = centers.reshape(-1, output_dim) 271 | elif centers.ndim == 0 and output_dim == 0: 272 | centers = centers.reshape(n_clusters, 1) 273 | logger.debug(f"Generated labels: {labels}") 274 | logger.debug(f"Generated centers shape: {centers.shape}") 275 | return labels, centers 276 | 277 | 278 | @pytest.fixture 279 | def patch_embedding_clustering(monkeypatch): 280 | logger.debug("Patching embedding and clustering classes") 281 | 
monkeypatch.setattr("cogitator.strategies.auto_cot.SentenceTransformerEmbedder", MockEmbedder) 282 | monkeypatch.setattr("cogitator.strategies.auto_cot.KMeansClusterer", MockClusterer) 283 | monkeypatch.setattr("cogitator.strategies.cdw_cot.SentenceTransformerEmbedder", MockEmbedder) 284 | monkeypatch.setattr("cogitator.strategies.cdw_cot.KMeansClusterer", MockClusterer) 285 | monkeypatch.setattr("cogitator.strategies.graph_of_thoughts.SentenceTransformerEmbedder", 286 | MockEmbedder) 287 | -------------------------------------------------------------------------------- /tests/extra_tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import logging 4 | import sys 5 | from pathlib import Path 6 | from unittest.mock import MagicMock 7 | 8 | import pytest 9 | 10 | # --- Test Setup --- 11 | # Add the examples directory to the path for module imports 12 | examples_dir = Path(__file__).parent.parent / "examples" 13 | sys.path.insert(0, str(examples_dir.parent)) 14 | 15 | # List of example script filenames 16 | example_scripts = [f.stem for f in examples_dir.glob("run_*.py")] 17 | 18 | 19 | @pytest.fixture 20 | def mock_get_llm_examples(mocker): 21 | """ Mocks get_llm specifically for the examples """ 22 | mock_llm_instance = MagicMock() 23 | mock_llm_instance.generate.return_value = "Example Mock LLM Response" 24 | mock_llm_instance.generate_async.return_value = "Example Mock Async LLM Response" 25 | # Add mocks for other methods used by examples if necessary 26 | # e.g., fit, fit_async, run, run_async for specific strategies 27 | mock_llm_instance.fit = MagicMock() 28 | mock_llm_instance.fit_async = MagicMock(return_value=None) # Needs to be awaitable 29 | mock_llm_instance.init_pool = MagicMock() 30 | mock_llm_instance.init_pool_async = MagicMock(return_value=None) 31 | mock_llm_instance.train = MagicMock() 32 | mock_llm_instance.train_async = MagicMock(return_value=None) 33 | mock_llm_instance.run = MagicMock(return_value="Example Mock Run") 34 | mock_llm_instance.run_async = MagicMock( 35 | return_value="Example Mock Async Run") # Needs to be awaitable 36 | mock_llm_instance.generate_json = MagicMock(return_value={"final_answer": "Mock JSON"}) 37 | mock_llm_instance.generate_json_async = MagicMock( 38 | return_value={"final_answer": "Mock Async JSON"}) # Needs to be awaitable 39 | 40 | # Mock the get_llm function within the examples.shared module 41 | return mocker.patch("examples.shared.get_llm", return_value=mock_llm_instance) 42 | 43 | 44 | @pytest.fixture 45 | def mock_embedding_clustering_examples(mocker): 46 | """ Mocks embedding and clustering for examples that might use them """ 47 | mocker.patch("cogitator.embedding.SentenceTransformerEmbedder.encode", 48 | return_value=[MagicMock()]) 49 | mocker.patch("cogitator.clustering.KMeansClusterer.cluster", 50 | return_value=(MagicMock(), MagicMock())) 51 | 52 | 53 | # --- Test Runner --- 54 | 55 | @pytest.mark.parametrize("script_name", example_scripts) 56 | def test_example_script_sync(script_name, mock_get_llm_examples, mock_embedding_clustering_examples, 57 | mocker, capsys): 58 | """ Tests if example scripts run synchronously without errors """ 59 | logging.info(f"Testing example (sync): {script_name}") 60 | # Mock sys.argv to prevent interference and set provider 61 | mocker.patch("sys.argv", ["examples/" + script_name + ".py", "--provider", 62 | "ollama"]) # Use ollama as it requires no key 63 | 64 | try: 65 | # Import the module 66 | 
module = importlib.import_module(f"examples.{script_name}") 67 | # Check if main_sync exists and run it 68 | if hasattr(module, "main_sync"): 69 | module.main_sync( 70 | argparse.Namespace(provider="ollama", model_name="mock_model", openai_key=None, 71 | use_async=False)) 72 | # Capture stdout to check for basic execution (optional) 73 | captured = capsys.readouterr() 74 | assert "Error" not in captured.err # Basic check for errors in output 75 | logging.info(f"Sync run for {script_name} completed.") 76 | else: 77 | pytest.skip(f"No main_sync function found in {script_name}") 78 | except Exception as e: 79 | pytest.fail(f"Example script {script_name} (sync) failed: {e}") 80 | 81 | 82 | @pytest.mark.asyncio 83 | @pytest.mark.parametrize("script_name", example_scripts) 84 | async def test_example_script_async(script_name, mock_get_llm_examples, 85 | mock_embedding_clustering_examples, mocker, capsys): 86 | """ Tests if example scripts run asynchronously without errors """ 87 | logging.info(f"Testing example (async): {script_name}") 88 | # Mock sys.argv to prevent interference and set provider + async flag 89 | mocker.patch("sys.argv", 90 | ["examples/" + script_name + ".py", "--provider", "ollama", "--use-async"]) 91 | 92 | try: 93 | module = importlib.import_module(f"examples.{script_name}") 94 | if hasattr(module, "main_async"): 95 | # Mock asyncio.gather if necessary, especially for complex async flows 96 | # mocker.patch('asyncio.gather', return_value=["Mock Async Result"]) 97 | 98 | # Mock the specific CoT methods used in async examples if get_llm mock isn't enough 99 | # Example: mocker.patch('cogitator.AutoCoT.fit_async', return_value=None) 100 | # Example: mocker.patch('cogitator.AutoCoT.run_async', return_value="Mock Async Output") 101 | # (Covered by mock_get_llm_examples for now, but might need refinement) 102 | 103 | await module.main_async( 104 | argparse.Namespace(provider="ollama", model_name="mock_model", openai_key=None, 105 | use_async=True)) 106 | captured = capsys.readouterr() 107 | assert "Error" not in captured.err 108 | logging.info(f"Async run for {script_name} completed.") 109 | else: 110 | pytest.skip(f"No main_async function found in {script_name}") 111 | except Exception as e: 112 | pytest.fail(f"Example script {script_name} (async) failed: {e}") 113 | -------------------------------------------------------------------------------- /tests/test_auto_cot.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cogitator import AutoCoT 4 | 5 | 6 | def test_fit_builds_expected_number_of_demos(fake_llm_factory, patch_embedding_clustering): 7 | questions = [f"q{i}" for i in range(8)] 8 | llm = fake_llm_factory({ 9 | "generate_sync": "Step 1\nStep 2" 10 | }) 11 | ac = AutoCoT(llm, n_demos=2, max_q_tokens=100, max_steps=5) 12 | ac.fit(questions) 13 | assert ac.demos is not None 14 | assert len(ac.demos) == 2 15 | for demo in ac.demos: 16 | assert demo.startswith("Q: ") 17 | assert "Step 1" in demo 18 | 19 | 20 | @pytest.mark.asyncio 21 | async def test_fit_async_builds_expected_number_of_demos(fake_llm_factory, 22 | patch_embedding_clustering): 23 | questions = [f"q{i}" for i in range(8)] 24 | llm = fake_llm_factory({ 25 | "generate_async": "Async Step 1\nAsync Step 2" 26 | }) 27 | ac = AutoCoT(llm, n_demos=2, max_q_tokens=100, max_steps=5) 28 | await ac.fit_async(questions) 29 | assert ac.demos is not None 30 | assert len(ac.demos) == 2 31 | for demo in ac.demos: 32 | assert demo.startswith("Q: ") 33 | assert 
"Async Step 1" in demo 34 | 35 | 36 | def test_run_uses_cached_demos_and_constructs_payload(fake_llm_factory, 37 | patch_embedding_clustering): 38 | questions = [f"q{i}" for i in range(8)] 39 | llm = fake_llm_factory({ 40 | "generate_sync": "Sync Final Answer" 41 | }) 42 | ac = AutoCoT(llm, n_demos=2) 43 | ac.fit(questions) 44 | assert ac.demos is not None 45 | out = ac.run("test question") 46 | assert out == "Sync Final Answer" 47 | assert "test question" in llm.sync_calls[-1]["prompt"] 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def test_run_async_uses_cached_demos(fake_llm_factory, patch_embedding_clustering): 52 | questions = [f"q{i}" for i in range(8)] 53 | llm = fake_llm_factory({ 54 | "generate_async": "Async Final Answer" 55 | }) 56 | ac = AutoCoT(llm, n_demos=2) 57 | await ac.fit_async(questions) 58 | assert ac.demos is not None 59 | out = await ac.run_async("test question async") 60 | assert out == "Async Final Answer" 61 | assert "test question async" in llm.async_calls[-1]["prompt"] 62 | 63 | 64 | def test_fit_raises_with_insufficient_questions(fake_llm_factory): 65 | llm = fake_llm_factory() 66 | ac = AutoCoT(llm, n_demos=3) 67 | with pytest.raises(ValueError): 68 | ac.fit(["only one"]) 69 | 70 | 71 | @pytest.mark.asyncio 72 | async def test_fit_async_raises_with_insufficient_questions(fake_llm_factory): 73 | llm = fake_llm_factory() 74 | ac = AutoCoT(llm, n_demos=3) 75 | with pytest.raises(ValueError): 76 | await ac.fit_async(["only one"]) 77 | -------------------------------------------------------------------------------- /tests/test_cdw_cot.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cogitator import CDWCoT 4 | 5 | 6 | def test_init_pool_builds_candidate_pool(fake_llm_factory, patch_embedding_clustering): 7 | questions = [f"q{i}" for i in range(8)] 8 | answers = [f"a{i}" for i in range(8)] 9 | llm = fake_llm_factory({"generate_sync": "Pool CoT Sync"}) 10 | cdw = CDWCoT(llm, pool_size=4, n_clusters=2, sample_size=2) 11 | cdw.init_pool(questions, answers) 12 | assert len(cdw.PC) <= 4 13 | assert cdw.cluster_centers is not None 14 | assert len(cdw.p_cluster) == cdw.cluster_centers.shape[0] 15 | if cdw.PC: 16 | assert "Pool CoT Sync" in cdw.PC[0] 17 | for p in cdw.p_cluster: 18 | assert pytest.approx(1.0) == p.sum() 19 | 20 | 21 | @pytest.mark.asyncio 22 | async def test_init_pool_async_builds_candidate_pool(fake_llm_factory, patch_embedding_clustering): 23 | questions = [f"q{i}" for i in range(8)] 24 | answers = [f"a{i}" for i in range(8)] 25 | llm = fake_llm_factory({"generate_async": "Pool CoT Async"}) 26 | cdw = CDWCoT(llm, pool_size=4, n_clusters=2, sample_size=2) 27 | await cdw.init_pool_async(questions, answers) 28 | assert len(cdw.PC) <= 4 29 | assert cdw.cluster_centers is not None 30 | assert len(cdw.p_cluster) == cdw.cluster_centers.shape[0] 31 | if cdw.PC: 32 | assert "Pool CoT Async" in cdw.PC[0] 33 | for p in cdw.p_cluster: 34 | assert pytest.approx(1.0) == p.sum() 35 | 36 | 37 | def test_train_and_run_flow_runs_without_error(fake_llm_factory, patch_embedding_clustering): 38 | questions = [f"q{i}{i}" for i in range(10)] 39 | answers = [f"a{i}" for i in range(10)] 40 | llm = fake_llm_factory({"generate_sync": "Train/Run Sync"}) 41 | # Removed temp=1.0 42 | cdw = CDWCoT(llm, pool_size=5, n_clusters=3, sample_size=2, lr=0.1) 43 | cdw.init_pool(questions, answers) 44 | cdw.train(val_split=0.3, epochs=1, patience=1) 45 | if not cdw.PC or cdw.cluster_centers is None: pytest.skip("Pool/Clusters 
empty") 46 | out = cdw.run("some sync test") # Changed from answer to run 47 | assert out == "Train/Run Sync" 48 | assert "some sync test" in llm.sync_calls[-1]["prompt"] 49 | 50 | 51 | @pytest.mark.asyncio 52 | async def test_train_async_and_run_async_flow(fake_llm_factory, patch_embedding_clustering): 53 | questions = [f"q{i}{i}" for i in range(10)] 54 | answers = [f"a{i}" for i in range(10)] 55 | llm = fake_llm_factory({"generate_async": "Train/Run Async"}) 56 | # Removed temp=1.0 57 | cdw = CDWCoT(llm, pool_size=5, n_clusters=3, sample_size=2, lr=0.1, seed=42) 58 | await cdw.init_pool_async(questions, answers) 59 | await cdw.train_async(val_split=0.3, epochs=1, patience=1) 60 | if not cdw.PC or cdw.cluster_centers is None: pytest.skip("Pool/Clusters empty") 61 | out = await cdw.run_async("some async test") # Changed from answer_async to run_async 62 | assert out == "Train/Run Async" 63 | assert "some async test" in llm.async_calls[-1]["prompt"] 64 | 65 | 66 | def test_init_pool_raises_on_length_mismatch(fake_llm_factory): 67 | llm = fake_llm_factory() 68 | cdw = CDWCoT(llm) 69 | with pytest.raises(ValueError): 70 | cdw.init_pool(["q1"], ["a1", "a2"]) 71 | 72 | 73 | @pytest.mark.asyncio 74 | async def test_init_pool_async_raises_on_length_mismatch(fake_llm_factory): 75 | llm = fake_llm_factory() 76 | cdw = CDWCoT(llm) 77 | with pytest.raises(ValueError): 78 | await cdw.init_pool_async(["q1"], ["a1", "a2"]) 79 | -------------------------------------------------------------------------------- /tests/test_graph_of_thoughts.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | 3 | import pytest 4 | 5 | from cogitator import EvaluationResult, ExtractedAnswer, ThoughtExpansion 6 | from cogitator import GraphOfThoughts 7 | 8 | EXAMPLE_BASIC_GOO: List[Tuple[str, Dict]] = [ 9 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated_step1', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}), 10 | ('Score', {'target_set': 'generated_step1', 'prompt_key': 'evaluate'}), 11 | ('KeepBest', {'N': 1, 'target_set': 'generated_step1', 'output_set': 'best_final_node'}) 12 | ] 13 | 14 | @pytest.mark.asyncio 15 | async def test_run_async_returns_result_and_calls_prompts_text_format( 16 | fake_llm_factory, patch_embedding_clustering): 17 | fake_expansion_config_async = ThoughtExpansion(thoughts=["stepA_async"]) 18 | fake_eval_config_async = EvaluationResult(score=9, justification="Good_async") 19 | llm = fake_llm_factory({ 20 | "json_steps": fake_expansion_config_async, 21 | "json_eval": fake_eval_config_async, 22 | "responses_map": { 23 | "Based on the final reasoning": "RESULT_async_text" 24 | } 25 | }) 26 | 27 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="text") 28 | 29 | test_goo: List[Tuple[str, Dict]] = [ 30 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}), 31 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}), 32 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'}) 33 | ] 34 | 35 | out = await got_instance.run_async("start_async?", graph_of_operations=test_goo) 36 | assert out == "RESULT_async_text" 37 | 38 | gen_op_call = next((c for c in llm.async_calls if 39 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "ThoughtExpansion"), 40 | None) 41 | score_op_call = next((c for c in llm.async_calls if 42 | c["type"] == 
"_generate_json_internal_async" and c["response_model"] == "EvaluationResult"), 43 | None) 44 | final_answer_call = next((c for c in llm.async_calls if 45 | c["type"] == "generate_async" and c["prompt"].startswith("Based on the final reasoning")), 46 | None) 47 | 48 | assert gen_op_call is not None, "Async GenerateOp LLM call not found" 49 | # Corrected Assertion: Check for the actual end of the modified prompt 50 | assert "JSON Output:" in gen_op_call["prompt"], "GenerateOp did not use expected prompt content" 51 | assert score_op_call is not None, "Async ScoreOp LLM call not found" 52 | assert "JSON Evaluation:" in score_op_call["prompt"], "ScoreOp did not use expected prompt content" 53 | assert final_answer_call is not None, "Async final answer generation call (text) not found" 54 | 55 | @pytest.mark.asyncio 56 | async def test_run_async_returns_result_and_calls_prompts_json_format( 57 | fake_llm_factory, patch_embedding_clustering): 58 | fake_expansion_config_async = ThoughtExpansion(thoughts=["stepA_async_json"]) 59 | fake_eval_config_async = EvaluationResult(score=9, justification="Good_async_json") 60 | fake_final_answer_obj_async = ExtractedAnswer(final_answer="RESULT_async_json") 61 | llm = fake_llm_factory({ 62 | "json_steps": fake_expansion_config_async, 63 | "json_eval": fake_eval_config_async, 64 | "json_answer": fake_final_answer_obj_async 65 | }) 66 | 67 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="json") 68 | 69 | test_goo: List[Tuple[str, Dict]] = [ 70 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}), 71 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}), 72 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'}) 73 | ] 74 | 75 | out = await got_instance.run_async("start_async_json?", graph_of_operations=test_goo) 76 | assert out == "RESULT_async_json" 77 | 78 | gen_op_call = next((c for c in llm.async_calls if 79 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "ThoughtExpansion"), 80 | None) 81 | score_op_call = next((c for c in llm.async_calls if 82 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "EvaluationResult"), 83 | None) 84 | final_json_call = next((c for c in llm.async_calls if 85 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "ExtractedAnswer"), 86 | None) 87 | 88 | assert gen_op_call is not None, "Async GenerateOp LLM call not found" 89 | assert score_op_call is not None, "Async ScoreOp LLM call not found" 90 | assert final_json_call is not None, "Async final answer generation call (JSON) not found" 91 | assert "Based on the final reasoning" in final_json_call["prompt"], "Final prompt content mismatch" 92 | 93 | def test_run_returns_result_and_calls_prompts_text_format( 94 | fake_llm_factory, patch_embedding_clustering): 95 | fake_expansion_config = ThoughtExpansion(thoughts=["stepA_sync"]) 96 | fake_eval_config = EvaluationResult(score=9, justification="Good_sync") 97 | llm = fake_llm_factory({ 98 | "json_steps": fake_expansion_config, 99 | "json_eval": fake_eval_config, 100 | "responses_map": { 101 | "Based on the final reasoning": "RESULT_sync_text" 102 | } 103 | }) 104 | 105 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="text") 106 | 107 | test_goo: List[Tuple[str, Dict]] = [ 108 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}), 
109 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}), 110 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'}) 111 | ] 112 | 113 | try: 114 | # This test might still fail or be unreliable due to asyncio.run issues 115 | # Mark as skipped or handle potential errors robustly if run is just a wrapper 116 | # pytest.skip("Skipping sync test for GoT due to asyncio.run wrapper issues.") 117 | out = got_instance.run("start?", graph_of_operations=test_goo) 118 | assert out == "RESULT_sync_text" 119 | assert len(llm.async_calls) > 0, "Expected async calls even in sync test due to wrapper" 120 | 121 | except NotImplementedError: 122 | pytest.skip("Synchronous 'run' not implemented for GraphOfThoughts.") 123 | except RuntimeError as e: 124 | if "event loop" in str(e).lower(): 125 | pytest.skip(f"Skipping sync test due to asyncio event loop issue: {e}") 126 | else: 127 | raise 128 | 129 | 130 | def test_run_returns_result_and_calls_prompts_json_format( 131 | fake_llm_factory, patch_embedding_clustering): 132 | fake_expansion_config = ThoughtExpansion(thoughts=["stepA_sync_json"]) 133 | fake_eval_config = EvaluationResult(score=9, justification="Good_sync_json") 134 | fake_final_answer_obj = ExtractedAnswer(final_answer="RESULT_sync_json") 135 | llm = fake_llm_factory({ 136 | "json_steps": fake_expansion_config, 137 | "json_eval": fake_eval_config, 138 | "json_answer": fake_final_answer_obj 139 | }) 140 | 141 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="json") 142 | 143 | test_goo: List[Tuple[str, Dict]] = [ 144 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}), 145 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}), 146 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'}) 147 | ] 148 | 149 | try: 150 | # This test might still fail or be unreliable due to asyncio.run issues 151 | # Mark as skipped or handle potential errors robustly if run is just a wrapper 152 | # pytest.skip("Skipping sync test for GoT due to asyncio.run wrapper issues.") 153 | out = got_instance.run("start_json?", graph_of_operations=test_goo) 154 | assert out == "RESULT_sync_json" 155 | assert len(llm.async_calls) > 0, "Expected async calls even in sync test due to wrapper" 156 | assert any(c["type"] == "_generate_json_internal_async" and c["response_model"] == "ExtractedAnswer" for c in llm.async_calls), "Final async JSON call missing" 157 | 158 | except NotImplementedError: 159 | pytest.skip("Synchronous 'run' not implemented for GraphOfThoughts.") 160 | except RuntimeError as e: 161 | if "event loop" in str(e).lower(): 162 | pytest.skip(f"Skipping sync test due to asyncio event loop issue: {e}") 163 | else: 164 | raise 165 | -------------------------------------------------------------------------------- /tests/test_least_to_most.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cogitator import LTMDecomposition, ExtractedAnswer 4 | from cogitator import LeastToMost 5 | 6 | 7 | def test_decompose_numbered(fake_llm_factory): 8 | llm = fake_llm_factory({ 9 | "json_subquestions": LTMDecomposition(subquestions=["sub1", "sub2", "sub3"]) 10 | }) 11 | ltm = LeastToMost(llm, intermediate_output_format="json") 12 | subs = ltm.decompose("anything") 13 | assert subs == ["sub1", "sub2", "sub3"] 14 | assert len(llm.sync_calls) == 1 15 | assert llm.sync_calls[0]["type"] == 
"_generate_json_internal" 16 | assert llm.sync_calls[0]["response_model"] == "LTMDecomposition" 17 | 18 | 19 | def test_decompose_bullets(fake_llm_factory): 20 | llm = fake_llm_factory({ 21 | "json_subquestions": LTMDecomposition(subquestions=["subA", "subB"]) 22 | }) 23 | ltm = LeastToMost(llm, intermediate_output_format="json") 24 | subs = ltm.decompose("anything") 25 | assert subs == ["subA", "subB"] 26 | 27 | 28 | def test_decompose_fallback_simulated(fake_llm_factory): 29 | llm = fake_llm_factory({ 30 | "json_subquestions": LTMDecomposition( 31 | subquestions=["This is a sentence", "Another sentence"]) 32 | }) 33 | ltm = LeastToMost(llm, intermediate_output_format="json") 34 | subs = ltm.decompose("anything") 35 | assert subs == ["This is a sentence", "Another sentence"] 36 | 37 | 38 | def test_decompose_empty_list_raises(fake_llm_factory): 39 | llm = fake_llm_factory({ 40 | "json_subquestions": LTMDecomposition(subquestions=[]) 41 | }) 42 | ltm = LeastToMost(llm, intermediate_output_format="json") 43 | with pytest.raises(ValueError, match="LLM returned empty subquestions list after validation."): 44 | ltm.decompose("anything") 45 | 46 | 47 | def test_decompose_invalid_json_raises(fake_llm_factory): 48 | llm = fake_llm_factory() 49 | llm.generate_json = lambda *args, **kwargs: (_ for _ in ()).throw( 50 | RuntimeError("generate_json failed... Last error: JSONDecodeError...")) 51 | ltm = LeastToMost(llm, intermediate_output_format="json") 52 | with pytest.raises(ValueError, 53 | match=r"Failed to decompose question due to LLM error: RuntimeError"): 54 | ltm.decompose("anything") 55 | 56 | 57 | def test_decompose_invalid_schema_raises(fake_llm_factory): 58 | llm = fake_llm_factory() 59 | llm.generate_json = lambda *args, **kwargs: (_ for _ in ()).throw( 60 | RuntimeError("generate_json failed... 
Last error: ValidationError...")) 61 | ltm = LeastToMost(llm, intermediate_output_format="json") 62 | with pytest.raises(ValueError, 63 | match=r"Failed to decompose question due to LLM error: RuntimeError"): 64 | ltm.decompose("anything") 65 | 66 | 67 | def test_max_subqs_trims(fake_llm_factory): 68 | many_subs = [f"sub{i + 1}" for i in range(20)] 69 | llm = fake_llm_factory({ 70 | "json_subquestions": LTMDecomposition(subquestions=many_subs) 71 | }) 72 | ltm = LeastToMost(llm, max_subqs=5, intermediate_output_format="json") 73 | subs = ltm.decompose("anything") 74 | assert len(subs) == 5 75 | assert subs == ["sub1", "sub2", "sub3", "sub4", "sub5"] 76 | 77 | 78 | @pytest.mark.asyncio 79 | async def test_decompose_async_numbered(fake_llm_factory): 80 | llm = fake_llm_factory({ 81 | "json_subquestions": LTMDecomposition(subquestions=["sub1", "sub2", "sub3"]) 82 | }) 83 | ltm = LeastToMost(llm, intermediate_output_format="json") 84 | subs = await ltm.decompose_async("anything async") 85 | assert subs == ["sub1", "sub2", "sub3"] 86 | assert len(llm.async_calls) == 1 87 | assert llm.async_calls[0]["type"] == "_generate_json_internal_async" 88 | assert llm.async_calls[0]["response_model"] == "LTMDecomposition" 89 | 90 | 91 | @pytest.mark.asyncio 92 | async def test_decompose_async_empty_list_raises(fake_llm_factory): 93 | llm = fake_llm_factory({ 94 | "json_subquestions": LTMDecomposition(subquestions=[]) 95 | }) 96 | ltm = LeastToMost(llm, intermediate_output_format="json") 97 | with pytest.raises(ValueError, 98 | match="Async LLM returned empty subquestions list after validation."): 99 | await ltm.decompose_async("anything async") 100 | 101 | 102 | @pytest.mark.asyncio 103 | async def test_decompose_async_invalid_json_raises(fake_llm_factory): 104 | llm = fake_llm_factory() 105 | 106 | async def mock_fail(*args, **kwargs): 107 | raise RuntimeError("generate_json_async failed... Last error: JSONDecodeError...") 108 | 109 | llm.generate_json_async = mock_fail 110 | ltm = LeastToMost(llm, intermediate_output_format="json") 111 | with pytest.raises(ValueError, 112 | match=r"Async decomposition failed due to LLM error: RuntimeError"): 113 | await ltm.decompose_async("anything async") 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_decompose_async_invalid_schema_raises(fake_llm_factory): 118 | llm = fake_llm_factory() 119 | 120 | async def mock_fail(*args, **kwargs): 121 | raise RuntimeError("generate_json_async failed... 
Last error: ValidationError...") 122 | 123 | llm.generate_json_async = mock_fail 124 | ltm = LeastToMost(llm, intermediate_output_format="json") 125 | with pytest.raises(ValueError, 126 | match=r"Async decomposition failed due to LLM error: RuntimeError"): 127 | await ltm.decompose_async("anything async") 128 | 129 | 130 | # --- Run tests --- 131 | 132 | def test_run_integration_text_mode(fake_llm_factory): 133 | llm = fake_llm_factory({ 134 | "json_subquestions": LTMDecomposition(subquestions=["dummy sub1", "dummy sub2"]), 135 | "sub_answer": "sub_answer_sync_text", 136 | "final_answer": "ANS_sync_text" 137 | }) 138 | ltm = LeastToMost(llm, intermediate_output_format="text") 139 | out = ltm.run("question?") 140 | assert out == "ANS_sync_text" 141 | assert any( 142 | c["type"] == "_generate_json_internal" and c["response_model"] == "LTMDecomposition" for c 143 | in llm.sync_calls) 144 | solve_calls = [c for c in llm.sync_calls if 145 | c["type"] == "generate" and "Current Subquestion:" in c["prompt"]] 146 | final_call = next((c for c in llm.sync_calls if 147 | c["type"] == "generate" and c["prompt"].startswith( 148 | "Based on the following")), None) 149 | assert len(solve_calls) == 2 150 | assert final_call is not None 151 | 152 | 153 | def test_run_integration_json_mode(fake_llm_factory): 154 | subquestions = ["dummy sub1", "dummy sub2"] 155 | question = "question?" 156 | fake_sub_answer_obj = ExtractedAnswer(final_answer="sub_answer_sync_json") 157 | fake_final_answer_obj = ExtractedAnswer(final_answer="ANS_sync_json") 158 | 159 | solve_prompt_key_1 = "Current Subquestion: dummy sub1" 160 | solve_prompt_key_2 = "Current Subquestion: dummy sub2" 161 | # The final prompt starts with "Based on..." and contains the original question 162 | final_prompt_key = f"Original Main Question: {question}" 163 | 164 | llm = fake_llm_factory({ 165 | "json_subquestions": LTMDecomposition(subquestions=subquestions), 166 | "responses_map": { 167 | 168 | solve_prompt_key_1: fake_sub_answer_obj, 169 | solve_prompt_key_2: fake_sub_answer_obj, 170 | final_prompt_key: fake_final_answer_obj 171 | } 172 | }) 173 | ltm = LeastToMost(llm, intermediate_output_format="json") 174 | out = ltm.run(question) 175 | assert out == "ANS_sync_json" 176 | 177 | assert any( 178 | c["type"] == "_generate_json_internal" and c["response_model"] == "LTMDecomposition" for c 179 | in llm.sync_calls) 180 | solve_json_calls = [c for c in llm.sync_calls if 181 | c["type"] == "_generate_json_internal" and "Current Subquestion:" in c[ 182 | "prompt"]] 183 | final_json_call = next((c for c in llm.sync_calls if 184 | c["type"] == "_generate_json_internal" and c["prompt"].startswith( 185 | "Based on the following")), None) 186 | assert len(solve_json_calls) == 2 187 | assert final_json_call is not None 188 | 189 | 190 | @pytest.mark.asyncio 191 | async def test_run_async_integration_text_mode(fake_llm_factory): 192 | llm = fake_llm_factory({ 193 | "json_subquestions": LTMDecomposition( 194 | subquestions=["dummy sub1 async", "dummy sub2 async"]), 195 | "sub_answer": "sub_answer_async_text", 196 | "final_answer": "ANS_async_text" 197 | }) 198 | ltm = LeastToMost(llm, intermediate_output_format="text") 199 | out = await ltm.run_async("question async?") 200 | assert out == "ANS_async_text" 201 | assert any( 202 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "LTMDecomposition" 203 | for c in llm.async_calls) 204 | solve_calls = [c for c in llm.async_calls if 205 | c["type"] == "generate_async" and "Current 
Subquestion:" in c["prompt"]] 206 | final_call = next((c for c in llm.async_calls if 207 | c["type"] == "generate_async" and c["prompt"].startswith( 208 | "Based on the following")), None) 209 | assert len(solve_calls) == 2 210 | assert final_call is not None 211 | 212 | 213 | @pytest.mark.asyncio 214 | async def test_run_async_integration_json_mode(fake_llm_factory): 215 | subquestions = ["dummy sub1 async", "dummy sub2 async"] 216 | question = "question async?" 217 | fake_sub_answer_obj_async = ExtractedAnswer(final_answer="sub_answer_async_json") 218 | fake_final_answer_obj_async = ExtractedAnswer(final_answer="ANS_async_json") 219 | 220 | solve_prompt_key_1 = "Current Subquestion: dummy sub1 async" 221 | solve_prompt_key_2 = "Current Subquestion: dummy sub2 async" 222 | final_prompt_key = f"Original Main Question: {question}" # Include the specific question 223 | 224 | llm = fake_llm_factory({ 225 | "json_subquestions": LTMDecomposition(subquestions=subquestions), 226 | "responses_map": { 227 | solve_prompt_key_1: fake_sub_answer_obj_async, 228 | solve_prompt_key_2: fake_sub_answer_obj_async, 229 | final_prompt_key: fake_final_answer_obj_async 230 | } 231 | }) 232 | ltm = LeastToMost(llm, intermediate_output_format="json") 233 | out = await ltm.run_async(question) 234 | assert out == "ANS_async_json" 235 | 236 | assert any( 237 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "LTMDecomposition" 238 | for c in llm.async_calls) 239 | solve_json_calls = [c for c in llm.async_calls if 240 | c["type"] == "_generate_json_internal_async" and "Current Subquestion:" in 241 | c["prompt"]] 242 | final_json_call = next((c for c in llm.async_calls if 243 | c["type"] == "_generate_json_internal_async" and c["prompt"].startswith( 244 | "Based on the following")), None) 245 | assert len(solve_json_calls) == 2 246 | assert final_json_call is not None 247 | 248 | 249 | # --- Solve tests --- 250 | @pytest.mark.asyncio 251 | async def test_solve_async_calls_generate_async_text(fake_llm_factory): 252 | expected_sub_answer = "async_sub_answer_test_text" 253 | llm = fake_llm_factory({"sub_answer": expected_sub_answer}) 254 | ltm = LeastToMost(llm, intermediate_output_format="text") 255 | solved = await ltm.solve_async("main q", ["sub1", "sub2"]) 256 | assert solved == [("sub1", expected_sub_answer), ("sub2", expected_sub_answer)] 257 | assert len(llm.async_calls) == 2 258 | assert all(c["type"] == "generate_async" for c in llm.async_calls) 259 | 260 | 261 | @pytest.mark.asyncio 262 | async def test_solve_async_calls_generate_json_async(fake_llm_factory): 263 | expected_sub_answer_obj = ExtractedAnswer(final_answer="async_sub_answer_test_json") 264 | llm = fake_llm_factory({"json_answer": expected_sub_answer_obj}) 265 | ltm = LeastToMost(llm, intermediate_output_format="json") 266 | solved = await ltm.solve_async("main q", ["sub1", "sub2"]) 267 | assert solved == [("sub1", "async_sub_answer_test_json"), 268 | ("sub2", "async_sub_answer_test_json")] 269 | assert len(llm.async_calls) == 2 270 | assert all(c["type"] == "_generate_json_internal_async" for c in llm.async_calls) 271 | assert all(c["response_model"] == "ExtractedAnswer" for c in llm.async_calls) 272 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, AsyncMock 2 | 3 | import pytest 4 | from pydantic import BaseModel, Field 5 | 6 | from cogitator import 
BaseLLM 7 | from cogitator import OllamaLLM 8 | from cogitator import OpenAILLM 9 | 10 | 11 | class DummySchema(BaseModel): 12 | name: str = Field(...) 13 | value: int = Field(..., gt=0) 14 | 15 | 16 | class SimpleSchema(BaseModel): 17 | answer: str 18 | 19 | 20 | @pytest.fixture 21 | def mock_openai_clients(mocker): 22 | mock_sync_client = MagicMock() 23 | mock_async_client = AsyncMock() 24 | mock_sync_client.chat.completions.create = MagicMock() 25 | mock_async_client.chat.completions.create = AsyncMock() 26 | mocker.patch("cogitator.model.openai.SyncOpenAI", return_value=mock_sync_client) 27 | mocker.patch("cogitator.model.openai.AsyncOpenAI", return_value=mock_async_client) 28 | return mock_sync_client, mock_async_client 29 | 30 | 31 | @pytest.fixture 32 | def mock_ollama_clients(mocker): 33 | mock_sync_client = MagicMock() 34 | mock_async_client = AsyncMock() 35 | mock_sync_client.chat = MagicMock() 36 | mock_async_client.chat = AsyncMock() 37 | mock_sync_client.list = MagicMock(return_value={'models': []}) 38 | mocker.patch("cogitator.model.ollama.Client", new_callable=MagicMock, return_value=mock_sync_client) 39 | mocker.patch("cogitator.model.ollama.AsyncClient", return_value=mock_async_client) 40 | return mock_sync_client, mock_async_client 41 | 42 | 43 | def test_base_llm_abstract_methods(): 44 | class ConcreteLLM(BaseLLM): 45 | def generate(self, prompt: str, **kwargs) -> str: return "" 46 | 47 | async def generate_async(self, prompt: str, **kwargs) -> str: return "" 48 | 49 | def generate_stream(self, prompt: str, **kwargs): yield "" 50 | 51 | async def generate_stream_async(self, prompt: str, **kwargs): yield "" 52 | 53 | def _generate_json_internal(self, prompt: str, response_model, **kwargs): return "{}", None 54 | 55 | async def _generate_json_internal_async(self, prompt: str, response_model, 56 | **kwargs): return "{}", None 57 | 58 | instance = ConcreteLLM() 59 | assert hasattr(instance, "generate") 60 | 61 | 62 | def test_extract_json_block(mock_ollama_clients): 63 | llm = OllamaLLM(model="dummy") 64 | 65 | text_fence = "```json\n{\"key\": \"value\"}\n```" 66 | text_fence_no_lang = "```\n{\"key\": \"value2\"}\n```" 67 | text_braces = "Some text {\"key\": \"value3\"} more text" 68 | text_brackets = "Some text [{\"key\": \"value4\"}] more text" 69 | text_both = "Text {\"a\":1} then [\"b\"] end" 70 | text_nested = "Text {\"a\": {\"b\": 1}} end" 71 | text_no_json = "Just plain text" 72 | text_empty = "" 73 | 74 | assert llm._extract_json_block(text_fence) == "{\"key\": \"value\"}" 75 | assert llm._extract_json_block(text_fence_no_lang) == "{\"key\": \"value2\"}" 76 | assert llm._extract_json_block(text_braces) == "{\"key\": \"value3\"}" 77 | assert llm._extract_json_block(text_brackets) == "[{\"key\": \"value4\"}]" 78 | assert llm._extract_json_block(text_both) == "{\"a\":1}" 79 | assert llm._extract_json_block(text_nested) == "{\"a\": {\"b\": 1}}" 80 | assert llm._extract_json_block(text_no_json) == "Just plain text" 81 | assert llm._extract_json_block(text_empty) == "" 82 | 83 | 84 | @pytest.mark.parametrize( 85 | "model_name, expected_mode, expect_format_present, expect_additional_props", [ 86 | ("gpt-4o", "json_schema", True, False), 87 | ("gpt-4o-mini", "json_schema", True, False), 88 | ("gpt-4-turbo", "json_object", True, None), 89 | ("gpt-3.5-turbo-1106", "json_object", True, None), 90 | ("gpt-3.5-turbo-0613", "json_schema", True, False), 91 | ("unknown-model", "json_schema", True, False), 92 | ]) 93 | def test_openai_prepare_api_params_json_modes(mock_openai_clients, 
model_name, expected_mode, 94 | expect_format_present, expect_additional_props): 95 | llm = OpenAILLM(api_key="dummy_key", model=model_name) 96 | params, mode = llm._prepare_api_params(is_json_mode=True, response_schema=DummySchema) 97 | 98 | if expect_format_present: 99 | assert "response_format" in params 100 | rf = params["response_format"] 101 | assert rf["type"] == expected_mode 102 | if expected_mode == "json_schema": 103 | assert "json_schema" in rf 104 | assert rf["json_schema"]["name"] == "DummySchema" 105 | schema = rf["json_schema"]["schema"] 106 | assert schema.get("additionalProperties") is expect_additional_props 107 | elif expected_mode == "json_object": 108 | assert "json_schema" not in rf 109 | assert expect_additional_props is None 110 | else: 111 | assert "response_format" not in params 112 | assert expect_additional_props is None 113 | if expect_format_present: 114 | assert mode == expected_mode 115 | else: 116 | assert mode is None 117 | 118 | 119 | def test_openai_prepare_api_params_no_schema_json_mode(mock_openai_clients): 120 | llm_json = OpenAILLM(api_key="d", model="gpt-4-turbo") 121 | params, mode = llm_json._prepare_api_params(is_json_mode=True, response_schema=None) 122 | assert params["response_format"]["type"] == "json_object" 123 | assert mode == "json_object" 124 | 125 | llm_no_json = OpenAILLM(api_key="d", model="gpt-3.5-turbo-0613") 126 | params, mode = llm_no_json._prepare_api_params(is_json_mode=True, response_schema=None) 127 | assert "response_format" not in params 128 | assert mode is None 129 | 130 | 131 | def test_openai_prepare_api_params_schema_generation_fails(mock_openai_clients, mocker): 132 | llm = OpenAILLM(api_key="d", model="gpt-4o") 133 | mocker.patch.object(DummySchema, "model_json_schema", side_effect=TypeError("Schema fail")) 134 | params, mode = llm._prepare_api_params(is_json_mode=True, response_schema=DummySchema) 135 | assert params["response_format"]["type"] == "json_object" 136 | assert mode == "json_object" 137 | 138 | llm_no_fallback = OpenAILLM(api_key="d", model="gpt-3.5-turbo-0613") 139 | mocker.patch.object(DummySchema, "model_json_schema", 140 | side_effect=TypeError("Schema fail again")) 141 | params_no_fallback, mode_no_fallback = llm_no_fallback._prepare_api_params(is_json_mode=True, 142 | response_schema=DummySchema) 143 | assert "response_format" not in params_no_fallback 144 | assert mode_no_fallback is None 145 | 146 | 147 | def test_ollama_init_success(mock_ollama_clients): 148 | mock_sync, mock_async = mock_ollama_clients 149 | llm = OllamaLLM(model="ollama-test", ollama_host="http://testhost:11434") 150 | assert llm.model == "ollama-test" 151 | assert llm.host == "http://testhost:11434" 152 | assert llm._client == mock_sync 153 | assert llm._async_client == mock_async 154 | 155 | # Verify constructors were called by the patcher via the fixture 156 | # Check the call args on the mock returned by the patcher 157 | from cogitator.model.ollama import Client, AsyncClient 158 | Client.assert_called_once_with(host="http://testhost:11434") 159 | AsyncClient.assert_called_once_with(host="http://testhost:11434") 160 | 161 | 162 | def test_ollama_strip_content(mock_ollama_clients): 163 | llm = OllamaLLM(model="dummy") 164 | 165 | response_dict = {"message": {"content": " Strip Me! "}} 166 | response_obj = MagicMock(message=MagicMock(content=" Strip Me Too! 
")) 167 | response_bad_obj = MagicMock(message=None) 168 | response_bad_dict = {"message": None} 169 | response_no_content = {"message": {"role": "assistant"}} 170 | 171 | assert llm._strip_content(response_dict) == "Strip Me!" 172 | assert llm._strip_content(response_obj) == "Strip Me Too!" 173 | assert llm._strip_content(response_bad_obj) == "" 174 | assert llm._strip_content(response_bad_dict) == "" 175 | assert llm._strip_content(response_no_content) == "" 176 | assert llm._strip_content(None) == "" 177 | assert llm._strip_content("string") == "" 178 | 179 | 180 | def test_ollama_prepare_options(mock_ollama_clients): 181 | llm = OllamaLLM(model="d", temperature=0.5, max_tokens=100, stop=["\n"], seed=1) 182 | 183 | opts = llm._prepare_options(temperature=0.8, seed=None, stop=["stop"], extra_param=True) 184 | assert opts == {"temperature": 0.8, "num_predict": 100, "stop": ["stop"], "extra_param": True} 185 | 186 | opts_defaults = llm._prepare_options() 187 | assert opts_defaults == {"temperature": 0.5, "num_predict": 100, "seed": 1, "stop": ["\n"]} 188 | -------------------------------------------------------------------------------- /tests/test_sc_cot.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from cogitator import ExtractedAnswer 6 | from cogitator import SelfConsistency 7 | 8 | 9 | def test_extract_answer_heuristic_with_equal_sign(): 10 | sc = SelfConsistency(llm=None) 11 | assert sc._extract_answer_heuristic("…\n4+1 = 5") == "5" 12 | assert sc._extract_answer_heuristic("Result=42.") == "42" 13 | 14 | 15 | def test_extract_answer_heuristic_with_prefixes(): 16 | sc = SelfConsistency(llm=None) 17 | assert sc._extract_answer_heuristic("Therefore, the answer is 7.") == "7" 18 | assert sc._extract_answer_heuristic("Answer: 99") == "99" 19 | assert sc._extract_answer_heuristic("Ans 13") == "13" 20 | assert sc._extract_answer_heuristic("### 123") == "123" 21 | assert sc._extract_answer_heuristic("Final Answer: foo bar") == "foo bar" 22 | 23 | 24 | def test_extract_answer_heuristic_last_line_fallback(): 25 | sc = SelfConsistency(llm=None) 26 | cot = "Step1\nSome thought\nFinal numerical answer 100" 27 | assert sc._extract_answer_heuristic(cot) == "Final numerical answer 100" 28 | 29 | 30 | def test_extract_answer_heuristic_numeric_last_line(): 31 | sc = SelfConsistency(llm=None) 32 | cot = "Step1\nSome thought\n12345" 33 | assert sc._extract_answer_heuristic(cot) == "12345" 34 | cot_dollar = "Step1\nSome thought\n$56.78" 35 | assert sc._extract_answer_heuristic(cot_dollar) == "56.78" 36 | 37 | 38 | def test_run_majority_vote_heuristic(fake_llm_factory): 39 | responses = ["...\nAns 10", "...\nFinal Answer: 10", "...\nResult=5"] 40 | llm = fake_llm_factory({"generate_sync": responses}) 41 | sc = SelfConsistency(llm=llm, n_samples=3) # Defaults to heuristic 42 | out = sc.run("dummy prompt") 43 | assert out == "10" 44 | assert len(llm.sync_calls) == 3 45 | 46 | 47 | def test_extract_answer_json_path(fake_llm_factory): 48 | expected_answer = "AnswerFromJSON" 49 | llm = fake_llm_factory({ 50 | "json_answer": ExtractedAnswer(final_answer=expected_answer) 51 | }) 52 | sc = SelfConsistency(llm=llm, internal_extraction_format="json") 53 | extracted = sc.extract_answer("Some reasoning text...") 54 | assert extracted == expected_answer 55 | assert len(llm.sync_calls) == 1 56 | assert llm.sync_calls[0]["type"] == "_generate_json_internal" 57 | assert llm.sync_calls[0]["response_model"] == "ExtractedAnswer" 58 | 59 | 
60 | def test_run_majority_vote_json(fake_llm_factory): 61 | cot_responses = ["CoT leading to 10", "Another CoT leading to 10", "CoT leading to 5"] 62 | extraction_responses = [ 63 | ExtractedAnswer(final_answer="10"), 64 | ExtractedAnswer(final_answer="10"), 65 | ExtractedAnswer(final_answer="5") 66 | ] 67 | # Configure mock: generate CoTs, then respond to extraction prompts 68 | llm = fake_llm_factory({ 69 | "generate_sync": cot_responses, 70 | "json_answer": extraction_responses # Mock should pick this for JSON extraction calls 71 | }) 72 | sc = SelfConsistency(llm=llm, n_samples=3, internal_extraction_format="json") 73 | out = sc.run("dummy prompt for json run") 74 | assert out == "10" # Verify majority vote result 75 | assert len(llm.sync_calls) == 6 # 3 generate + 3 json_internal 76 | assert sum(1 for c in llm.sync_calls if c["type"] == "generate") == 3 77 | assert sum(1 for c in llm.sync_calls if c["type"] == "_generate_json_internal") == 3 78 | 79 | 80 | def test_run_stream_not_implemented(fake_llm_factory): 81 | llm = fake_llm_factory() 82 | sc = SelfConsistency(llm=llm) 83 | with pytest.raises(NotImplementedError): 84 | next(sc.run_stream("anything")) 85 | 86 | 87 | @pytest.mark.asyncio 88 | async def test_extract_answer_async_heuristic(): 89 | sc = SelfConsistency(llm=None) 90 | assert await sc.extract_answer_async("...\nFinal Answer: ABC_async") == "ABC_async" 91 | assert await sc.extract_answer_async("...\nResult=55_async") == "55_async" 92 | 93 | 94 | @pytest.mark.asyncio 95 | async def test_run_async_majority_vote_heuristic(fake_llm_factory): 96 | responses = ["...\nAns Async10", "...\nFinal Answer: Async10", "...\nResult=Async5"] 97 | llm = fake_llm_factory({"generate_async": responses}) 98 | sc = SelfConsistency(llm=llm, n_samples=3) 99 | out = await sc.run_async("dummy async prompt") 100 | assert out == "Async10" 101 | assert len(llm.async_calls) == 3 102 | 103 | 104 | @pytest.mark.asyncio 105 | async def test_extract_answer_json_async_path(fake_llm_factory): 106 | expected_answer = "AsyncAnswerFromJSON" 107 | llm = fake_llm_factory({ 108 | "json_answer": ExtractedAnswer(final_answer=expected_answer) 109 | }) 110 | sc = SelfConsistency(llm=llm, internal_extraction_format="json") 111 | extracted = await sc.extract_answer_async("Some async reasoning text...") 112 | assert extracted == expected_answer 113 | assert len(llm.async_calls) == 1 114 | assert llm.async_calls[0]["type"] == "_generate_json_internal_async" 115 | assert llm.async_calls[0]["response_model"] == "ExtractedAnswer" 116 | 117 | 118 | @pytest.mark.asyncio 119 | async def test_run_async_majority_vote_json(fake_llm_factory): 120 | cot_responses_async = ["Async CoT 10", "Async CoT 10 again", "Async CoT 5"] 121 | extraction_responses_async = [ 122 | ExtractedAnswer(final_answer="10"), 123 | ExtractedAnswer(final_answer="10"), 124 | ExtractedAnswer(final_answer="5") 125 | ] 126 | llm = fake_llm_factory({ 127 | "generate_async": cot_responses_async, 128 | "json_answer": extraction_responses_async 129 | }) 130 | sc = SelfConsistency(llm=llm, n_samples=3, internal_extraction_format="json") 131 | out = await sc.run_async("dummy async json run prompt") 132 | assert out == "10" # Verify majority vote result 133 | assert len(llm.async_calls) == 6 # 3 generate_async + 3 json_internal_async 134 | assert sum(1 for c in llm.async_calls if c["type"] == "generate_async") == 3 135 | assert sum(1 for c in llm.async_calls if c["type"] == "_generate_json_internal_async") == 3 136 | 137 | 138 | @pytest.mark.asyncio 139 | async def 
test_run_async_with_semaphore(fake_llm_factory): 140 | responses = ["…\nAnswer: S1", "…\nAnswer: S2", "…\nAnswer: S1"] 141 | llm = fake_llm_factory({"generate_async": responses}) 142 | sc = SelfConsistency(llm=llm, n_samples=3) 143 | semaphore = asyncio.Semaphore(2) 144 | out = await sc.run_async("dummy async prompt semaphore", semaphore=semaphore) 145 | assert out == "S1" 146 | assert len(llm.async_calls) == 3 147 | 148 | 149 | @pytest.mark.asyncio 150 | async def test_run_stream_async_not_implemented(fake_llm_factory): 151 | llm = fake_llm_factory() 152 | sc = SelfConsistency(llm=llm) 153 | with pytest.raises(NotImplementedError): 154 | # Corrected: await the coroutine 155 | await sc.run_stream_async("anything") 156 | -------------------------------------------------------------------------------- /tests/test_tree_of_thoughts.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cogitator import ThoughtExpansion, EvaluationResult 4 | from cogitator import TreeOfThoughts 5 | 6 | 7 | def test_run_returns_final_and_calls_prompts(fake_llm_factory): 8 | fake_expansion = ThoughtExpansion(thoughts=["step1_sync"]) 9 | fake_eval = EvaluationResult(score=8, justification="Okay_sync") 10 | llm = fake_llm_factory({ 11 | "json_steps": fake_expansion, 12 | "json_eval": fake_eval, 13 | "final_answer": "FINAL_sync" 14 | }) 15 | tot = TreeOfThoughts(llm, max_depth=1, num_branches=1, sims=1, c_puct=1.0) 16 | out = tot.run("test?") 17 | 18 | assert out == "FINAL_sync" 19 | 20 | expand_call = next((c for c in llm.sync_calls if 21 | c["type"] == "_generate_json_internal" and "JSON Output:" in c[ 22 | "prompt"] and "thoughts" in c["prompt"]), 23 | None) 24 | eval_call = next((c for c in llm.sync_calls if 25 | c["type"] == "_generate_json_internal" and "JSON Evaluation:" in c["prompt"]), 26 | None) 27 | final_call = next((c for c in llm.sync_calls if 28 | c["type"] == "generate" and ( 29 | "Given reasoning steps" in c["prompt"] or c["prompt"].startswith( 30 | "Answer the question:"))), 31 | None) 32 | 33 | assert expand_call is not None, "Expansion call not found" 34 | assert expand_call["response_model"] == "ThoughtExpansion" 35 | assert eval_call is not None, "Evaluation call not found" 36 | assert eval_call["response_model"] == "EvaluationResult" 37 | assert final_call is not None, "Final answer generation call not found" 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_run_async_returns_final_and_calls_prompts(fake_llm_factory): 42 | fake_expansion_async = ThoughtExpansion(thoughts=["step1_async"]) 43 | fake_eval_async = EvaluationResult(score=8, justification="Okay_async") 44 | llm = fake_llm_factory({ 45 | "json_steps": fake_expansion_async, 46 | "json_eval": fake_eval_async, 47 | "final_answer": "FINAL_async" 48 | }) 49 | tot = TreeOfThoughts(llm, max_depth=1, num_branches=1, sims=1, c_puct=1.0) 50 | out = await tot.run_async("test_async?") 51 | 52 | assert out == "FINAL_async" 53 | 54 | expand_call = next((c for c in llm.async_calls if 55 | c["type"] == "_generate_json_internal_async" and "JSON Output:" in c[ 56 | "prompt"] and "thoughts" in c["prompt"]), 57 | None) 58 | eval_call = next((c for c in llm.async_calls if 59 | c["type"] == "_generate_json_internal_async" and "JSON Evaluation:" in c[ 60 | "prompt"]), 61 | None) 62 | final_call = next((c for c in llm.async_calls if 63 | c["type"] == "generate_async" and ( 64 | "Given reasoning steps" in c["prompt"] or c["prompt"].startswith( 65 | "Answer the question:"))), 66 | None) 67 | 
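    # At least one ThoughtExpansion call, one EvaluationResult call, and one final
    # generate_async call should be present for this depth-1, single-branch tree;
    # the checks below only assert presence, not exact call counts.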
68 | assert expand_call is not None, "Async expansion call not found" 69 | assert expand_call["response_model"] == "ThoughtExpansion" 70 | assert eval_call is not None, "Async evaluation call not found" 71 | assert eval_call["response_model"] == "EvaluationResult" 72 | assert final_call is not None, "Async final answer generation call not found" 73 | --------------------------------------------------------------------------------
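All of the test modules above consume a `fake_llm_factory` fixture (and, in several cases, a `patch_embedding_clustering` fixture) supplied by the shared pytest configuration, which is not reproduced in this listing. The snippet below is only a minimal sketch of how such a factory could be built — a stub LLM that records every call and replays canned responses keyed by prompt content. It is not the project's actual fixture; the class name, config keys, and behavior shown here are illustrative assumptions inferred from how the tests use the factory.

from typing import Any, Dict, List, Optional

import pytest


class _StubLLM:
    """Illustrative stand-in for a Cogitator LLM: records calls, replays canned responses."""

    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
        self.config: Dict[str, Any] = dict(config or {})
        self.sync_calls: List[Dict[str, Any]] = []
        self.async_calls: List[Dict[str, Any]] = []

    def _canned(self, prompt: str, key: str) -> Any:
        # A prompt-substring match in "responses_map" wins; otherwise fall back to
        # the per-method default key ("generate_sync" / "generate_async").
        for prefix, response in self.config.get("responses_map", {}).items():
            if prefix in prompt:
                return response
        value = self.config.get(key, "")
        if isinstance(value, list):
            # Hand out list entries one at a time so multi-sample tests
            # (e.g. SelfConsistency with n_samples=3) see one response per call.
            return value.pop(0) if value else ""
        return value

    def generate(self, prompt: str, **kwargs: Any) -> str:
        self.sync_calls.append({"type": "generate", "prompt": prompt})
        return str(self._canned(prompt, "generate_sync"))

    async def generate_async(self, prompt: str, **kwargs: Any) -> str:
        self.async_calls.append({"type": "generate_async", "prompt": prompt})
        return str(self._canned(prompt, "generate_async"))


@pytest.fixture
def fake_llm_factory():
    # Each test builds its own stub with test-specific canned responses.
    def _factory(config: Optional[Dict[str, Any]] = None) -> _StubLLM:
        return _StubLLM(config)

    return _factory

A fuller stub would presumably also implement the `_generate_json_internal` / `_generate_json_internal_async` pair and record the `response_model` name for each call, since many assertions above filter recorded calls by that field; that part is omitted here for brevity.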