├── .editorconfig
├── .gitattributes
├── .github
│   └── workflows
│       ├── docs.yml
│       ├── lints.yml
│       ├── publish.yml
│       └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── benches.yml
├── benches
│   ├── README.md
│   ├── __init__.py
│   ├── evaluate.py
│   ├── extractors.py
│   ├── run.py
│   └── shared.py
├── codecov.yml
├── cogitator
│   ├── __init__.py
│   ├── clustering.py
│   ├── embedding.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── ollama.py
│   │   ├── openai.py
│   │   └── py.typed
│   ├── py.typed
│   ├── schemas.py
│   ├── strategies
│   │   ├── __init__.py
│   │   ├── auto_cot.py
│   │   ├── cdw_cot.py
│   │   ├── graph_of_thoughts.py
│   │   ├── least_to_most.py
│   │   ├── py.typed
│   │   ├── sc_cot.py
│   │   └── tree_of_thoughts.py
│   └── utils.py
├── docs
│   ├── api
│   │   ├── clustering.md
│   │   ├── embedding.md
│   │   ├── model
│   │   │   ├── base.md
│   │   │   ├── ollama.md
│   │   │   └── openai.md
│   │   ├── schemas.md
│   │   ├── strategies
│   │   │   ├── auto_cot.md
│   │   │   ├── cdw_cot.md
│   │   │   ├── graph_of_thoughts.md
│   │   │   ├── least_to_most.md
│   │   │   ├── sc_cot.md
│   │   │   └── tree_of_thoughts.md
│   │   └── utils.md
│   ├── assets
│   │   └── images
│   │       ├── cogitator_v1.dot
│   │       ├── cogitator_v1.svg
│   │       ├── cogitator_v2.dot
│   │       ├── cogitator_v2.svg
│   │       └── make_figures.sh
│   ├── benchmarking.md
│   ├── contributing.md
│   └── index.md
├── examples
│   ├── README.md
│   ├── __init__.py
│   ├── run_auto_cot.py
│   ├── run_cdw_cot.py
│   ├── run_graph_of_thoughts.py
│   ├── run_least_to_most.py
│   ├── run_sc_cot.py
│   ├── run_simple_example.py
│   ├── run_tree_of_thoughts.py
│   └── shared.py
├── logo.svg
├── mkdocs.yml
├── poetry.toml
├── pyproject.toml
└── tests
    ├── conftest.py
    ├── extra_tests
    │   ├── test_benches.py
    │   └── test_examples.py
    ├── test_auto_cot.py
    ├── test_cdw_cot.py
    ├── test_graph_of_thoughts.py
    ├── test_least_to_most.py
    ├── test_model.py
    ├── test_sc_cot.py
    └── test_tree_of_thoughts.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig is awesome: https://EditorConfig.org
2 |
3 | # Top-most EditorConfig file
4 | root = true
5 |
6 | # Global settings (applicable to all files unless overridden)
7 | [*]
8 | charset = utf-8 # Default character encoding
9 | end_of_line = lf # Use LF for line endings (Unix-style)
10 | indent_style = space # Use spaces for indentation
11 | indent_size = 4 # Default indentation size
12 | insert_final_newline = true # Make sure files end with a newline
13 | trim_trailing_whitespace = true # Remove trailing whitespace
14 |
15 | [*.py]
16 | max_line_length = 100
17 |
18 | # Markdown files
19 | [*.md]
20 | max_line_length = 130
21 | trim_trailing_whitespace = false
22 |
23 | # Bash scripts
24 | [*.sh]
25 | indent_size = 2
26 |
27 | # YAML files
28 | [*.{yml,yaml}]
29 | indent_size = 2
30 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Common document and text file formats
2 | *.docx filter=lfs diff=lfs merge=lfs -text
3 | *.doc filter=lfs diff=lfs merge=lfs -text
4 | *.pdf filter=lfs diff=lfs merge=lfs -text
5 | *.xls filter=lfs diff=lfs merge=lfs -text
6 | *.xlsx filter=lfs diff=lfs merge=lfs -text
7 | *.ppt filter=lfs diff=lfs merge=lfs -text
8 | *.pptx filter=lfs diff=lfs merge=lfs -text
9 |
10 | # Common image formats
11 | *.jpg filter=lfs diff=lfs merge=lfs -text
12 | *.jpeg filter=lfs diff=lfs merge=lfs -text
13 | *.png filter=lfs diff=lfs merge=lfs -text
14 | *.gif filter=lfs diff=lfs merge=lfs -text
15 | *.bmp filter=lfs diff=lfs merge=lfs -text
16 | *.tiff filter=lfs diff=lfs merge=lfs -text
17 | *.tif filter=lfs diff=lfs merge=lfs -text
18 |
19 | # Common compressed file formats
20 | *.zip filter=lfs diff=lfs merge=lfs -text
21 | *.gz filter=lfs diff=lfs merge=lfs -text
22 | *.tar filter=lfs diff=lfs merge=lfs -text
23 | *.tgz filter=lfs diff=lfs merge=lfs -text
24 | *.bz2 filter=lfs diff=lfs merge=lfs -text
25 | *.7z filter=lfs diff=lfs merge=lfs -text
26 | *.rar filter=lfs diff=lfs merge=lfs -text
27 |
28 | # Common file formats in machine learning projects
29 | *.bin filter=lfs diff=lfs merge=lfs -text
30 | *.model filter=lfs diff=lfs merge=lfs -text
31 | *.h5 filter=lfs diff=lfs merge=lfs -text
32 | *.tfrecord filter=lfs diff=lfs merge=lfs -text
33 | *.hdf5 filter=lfs diff=lfs merge=lfs -text
34 | *.keras filter=lfs diff=lfs merge=lfs -text
35 | *.pth filter=lfs diff=lfs merge=lfs -text
36 | *.pt filter=lfs diff=lfs merge=lfs -text
37 | *.joblib filter=lfs diff=lfs merge=lfs -text
38 | *.pkl filter=lfs diff=lfs merge=lfs -text
39 | *.pickle filter=lfs diff=lfs merge=lfs -text
40 | *.npy filter=lfs diff=lfs merge=lfs -text
41 |
42 | # Common audio and video formats
43 | *.mp3 filter=lfs diff=lfs merge=lfs -text
44 | *.mp4 filter=lfs diff=lfs merge=lfs -text
45 | *.wav filter=lfs diff=lfs merge=lfs -text
46 | *.avi filter=lfs diff=lfs merge=lfs -text
47 | *.mov filter=lfs diff=lfs merge=lfs -text
48 | *.flac filter=lfs diff=lfs merge=lfs -text
49 | *.mkv filter=lfs diff=lfs merge=lfs -text
50 | *.webm filter=lfs diff=lfs merge=lfs -text
51 | *.ogg filter=lfs diff=lfs merge=lfs -text
52 | *.ogv filter=lfs diff=lfs merge=lfs -text
53 |
54 | # Common data transfer formats
55 | #*.csv filter=lfs diff=lfs merge=lfs -text
56 | #*.tsv filter=lfs diff=lfs merge=lfs -text
57 | #*.json filter=lfs diff=lfs merge=lfs -text
58 | #*.xml filter=lfs diff=lfs merge=lfs -text
59 | *.parquet filter=lfs diff=lfs merge=lfs -text
60 | *.feather filter=lfs diff=lfs merge=lfs -text
61 | *.msgpack filter=lfs diff=lfs merge=lfs -text
62 | *.avro filter=lfs diff=lfs merge=lfs -text
63 | *.arrow filter=lfs diff=lfs merge=lfs -text
64 | *.orc filter=lfs diff=lfs merge=lfs -text
65 |
66 | # Exclude files from language stats (GitHub Linguist)
67 | *.ipynb linguist-vendored
68 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Build and Deploy Docs
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | permissions:
7 | contents: read
8 | pages: write
9 | id-token: write
10 |
11 | # Only allow one deployment to run at a time
12 | concurrency:
13 | group: "pages"
14 | cancel-in-progress: false
15 |
16 | jobs:
17 | deploy:
18 | environment:
19 | name: github-pages
20 | url: ${{ steps.deployment.outputs.page_url }}
21 | runs-on: ubuntu-latest
22 | steps:
23 | - name: Checkout Repository
24 | uses: actions/checkout@v4
25 |
26 | - name: Set up Python
27 | uses: actions/setup-python@v5
28 | with:
29 | python-version: '3.11'
30 |
31 | - name: Install Poetry
32 | uses: snok/install-poetry@v1
33 | with:
34 | virtualenvs-create: true
35 | virtualenvs-in-project: true
36 |
37 | - name: Load cached venv
38 | id: cached-poetry-dependencies
39 | uses: actions/cache@v4
40 | with:
41 | path: .venv
42 | key: venv-${{ runner.os }}-py3.11-${{ hashFiles('**/poetry.lock') }}
43 |
44 | - name: Install Dependencies
45 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
46 | run: make install
47 |
48 | - name: Build Documentation
49 | run: make docs
50 |
51 | - name: Setup Pages
52 | uses: actions/configure-pages@v5
53 |
54 | - name: Upload Documentation as Artifact
55 | uses: actions/upload-pages-artifact@v3
56 | with:
57 | path: './site'
58 |
59 | - name: Deploy to GitHub Pages
60 | id: deployment
61 | uses: actions/deploy-pages@v4
62 |
--------------------------------------------------------------------------------
/.github/workflows/lints.yml:
--------------------------------------------------------------------------------
1 | name: Run Linters
2 |
3 | on:
4 | workflow_dispatch:
5 | workflow_call:
6 | push:
7 | tags:
8 | - 'v*'
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | build:
15 | runs-on: ubuntu-latest
16 |
17 | strategy:
18 | matrix:
19 | python-version: [ "3.10", "3.11", "3.12", "3.13" ]
20 |
21 | steps:
22 | - name: Checkout Repository
23 | uses: actions/checkout@v4
24 |
25 | - name: Set Up Python ${{ matrix.python-version }}
26 | uses: actions/setup-python@v4
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 |
30 | - name: Install Dependencies
31 | run: |
32 | make setup
33 | make install
34 |
35 | - name: Run Linter Checks
36 | run: |
37 | make lint
38 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | tags:
7 | - 'v*'
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 |
14 | call_tests:
15 | uses: ./.github/workflows/tests.yml
16 |
17 | publish_to_pypi:
18 | runs-on: ubuntu-latest
19 | needs: call_tests
20 |
21 | steps:
22 | - name: Checkout Repository
23 | uses: actions/checkout@v4
24 |
25 | - name: Set Up Python
26 | uses: actions/setup-python@v4
27 | with:
28 | python-version: "3.10"
29 |
30 | - name: Install Dependencies
31 | run: |
32 | make setup
33 | make install
34 |
35 | - name: Build and Publish to PyPI
36 | run: |
37 | PYPI_TOKEN=${{ secrets.PYPI_API_TOKEN }} make publish
38 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on:
4 | workflow_dispatch:
5 | workflow_call:
6 | pull_request:
7 | branches:
8 | - main
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | build:
15 | runs-on: ubuntu-latest
16 |
17 | strategy:
18 | matrix:
19 | python-version: [ "3.10", "3.11", "3.12", "3.13" ]
20 |
21 | steps:
22 | - name: Checkout Repository
23 | uses: actions/checkout@v4
24 |
25 | - name: Set Up Python ${{ matrix.python-version }}
26 | uses: actions/setup-python@v4
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 |
30 | - name: Install Dependencies
31 | run: |
32 | make setup
33 | make install
34 |
35 | - name: Run Tests with Coverage
36 | run: |
37 | make test
38 |
39 | - name: Upload coverage reports to Codecov
40 | uses: codecov/codecov-action@v5
41 | with:
42 | token: ${{ secrets.CODECOV_TOKEN }}
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python specific
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Virtual environments
7 | .env/
8 | env/
9 | .venv/
10 | venv/
11 |
12 | # Packaging and distribution files
13 | .Python
14 | build/
15 | dist/
16 | *.egg-info/
17 | *.egg
18 | MANIFEST
19 |
20 | # Dependency directories
21 | develop-eggs/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | .installed.cfg
32 |
33 | # Test and coverage reports
34 | htmlcov/
35 | .tox/
36 | .coverage
37 | .coverage.*
38 | .cache
39 | nosetests.xml
40 | coverage.xml
41 | *.cover
42 | .hypothesis/
43 | .pytest_cache/
44 | .benchmarks/
45 |
46 | # IDE specific files and directories
47 | .idea/
48 | *.iml
49 | .vscode/
50 |
51 | # Jupyter Notebook files
52 | .ipynb_checkpoints
53 |
54 | # Temporary files created by editors and the system and folders to ignore
55 | *.swp
56 | *~
57 | *.bak
58 | *.tmp
59 | temp/
60 | output/
61 | tmp/
62 | tmp2/
63 | out/
64 | out2/
65 |
66 | # Database files (SQLite, DuckDB, etc.)
67 | *.duckdb
68 | *.db
69 | *.wal
70 | *.sqlite
71 |
72 | # Dependency lock files (uncomment to ignore)
73 | poetry.lock
74 |
75 | # Miscellaneous files and directories to ignore
76 | # Add any additional file patterns or directory names that should be ignored down here
77 | *_output.txt
78 | .env
79 | benches/data/
80 | coverage_*_report
81 | *.jsonl
82 | site/
83 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - id: check-toml
9 | - id: check-added-large-files
10 | - id: check-merge-conflict
11 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | We adhere to the [Python Software Foundation Code of Conduct](https://policies.python.org/python.org/code-of-conduct).
4 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guidelines
2 |
3 | Thank you for considering contributing to this project!
4 | Contributions are always welcome and appreciated.
5 |
6 | ## How to Contribute
7 |
8 | Please check the [issue tracker](https://github.com/habedi/cogitator/issues) to see if there is an issue you
9 | would like to work on or if it has already been resolved.
10 |
11 | ### Reporting Bugs
12 |
13 | 1. Open an issue on the [issue tracker](https://github.com/habedi/cogitator/issues).
14 | 2. Include information such as steps to reproduce, expected/actual behavior, and relevant logs or screenshots.
15 |
16 | ### Suggesting Features
17 |
18 | 1. Open an issue on the [issue tracker](https://github.com/habedi/cogitator/issues).
19 | 2. Provide details about the feature, its purpose, and potential implementation ideas.
20 |
21 | ## Submitting Pull Requests
22 |
23 | - Ensure all tests pass before submitting a pull request.
24 | - Write a clear description of the changes you made and the reasons behind them.
25 |
26 | > [!IMPORTANT]
27 | > It's assumed that by submitting a pull request, you agree to license your contributions under the project's license.
28 |
29 | ## Development Workflow
30 |
31 | ### Prerequisites
32 |
33 | Install GNU Make on your system if it's not already installed.
34 |
35 | ```shell
36 | # For Debian-based systems like Debian, Ubuntu, etc.
37 | sudo apt-get install make
38 | ```
39 |
40 | - Use the `make setup` command to install the development dependencies.
41 |
42 | ### Code Style
43 |
44 | - Use the `make format` command to format the code.
45 |
46 | ### Running Tests
47 |
48 | - Use the `make test` command to run the tests.
49 |
50 | ### Running Linter Checks
51 |
52 | - Use the `make lint` command to run the linter checks.
53 |
54 | ### See Available Commands
55 |
56 | - Run `make help` to see all available commands for managing different tasks.
57 |
58 | ## Code of Conduct
59 |
60 | We adhere to the [Python Software Foundation Code of Conduct](https://policies.python.org/python.org/code-of-conduct).
61 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Hassan Abedi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Variables
2 | PYTHON = python
3 | PIP = pip
4 | POETRY = poetry
5 | SHELL = /bin/bash
6 | BENCH_DIR = benches
7 | EXAMPLE_DIR = examples
8 | OLLAMA_MODEL ?= gemma3:12b
9 | OPENAI_MODEL ?= gpt-4o-mini
10 |
11 | # Default target
12 | .DEFAULT_GOAL := help
13 |
14 | # Help target
15 | .PHONY: help
16 | help: ## Show help messages for all available targets
17 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; \
18 | {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
19 |
20 | # Setup and installation
21 | .PHONY: setup
22 | setup: ## Install system dependencies
23 | sudo apt-get update
24 | sudo apt-get install -y python3-pip
25 | $(PIP) install poetry
26 |
27 | .PHONY: install
28 | install: ## Install Python dependencies
29 | $(POETRY) install --with dev
30 |
31 | # Testing and linting
32 | .PHONY: test
33 | test: ## Run the tests
34 | CUDA_VISIBLE_DEVICES="" $(POETRY) run pytest
35 |
36 | .PHONY: lint
37 | lint: ## Run the linter checks
38 | $(POETRY) run ruff check --fix
39 |
40 | .PHONY: format
41 | format: ## Format the Python files
42 | $(POETRY) run ruff format .
43 |
44 | .PHONY: typecheck
45 | typecheck: ## Typecheck the Python files
46 | $(POETRY) run mypy .
47 |
48 | # Cleaning
49 | .PHONY: clean
50 | clean: ## Remove build artifacts, caches, and temporary files
51 | find . -type f -name '*.pyc' -delete
52 | find . -type d -name '__pycache__' -exec rm -r {} +
53 | rm -rf site dist
54 | rm -rf .mypy_cache .pytest_cache .ruff_cache .coverage htmlcov coverage.xml junit
55 |
56 | # Build and publish
57 | .PHONY: build
58 | build: ## Build the wheel and source distribution
59 | $(POETRY) build
60 |
61 | .PHONY: publish
62 | publish: ## Publish the library to PyPI (requires PYPI_TOKEN to be set)
63 | $(POETRY) config pypi-token.pypi $(PYPI_TOKEN)
64 | $(POETRY) publish --build
65 |
66 | .PHONY: example-openai
67 | example-openai: ## Run the examples using OpenAI (needs OPENAI_API_KEY to be set)
68 | @for script in $(EXAMPLE_DIR)/run_*.py; do \
69 | echo "Running $$script --provider openai --openai-key ******** --model-name $(OPENAI_MODEL) --use-async"; \
70 | $(POETRY) run python $$script --provider openai --openai-key $(OPENAI_API_KEY) --model-name $(OPENAI_MODEL) --use-async; \
71 | done
72 |
73 | example-ollama: ## Run the examples using Ollama
74 | @echo "Running examples with Ollama provider (Model: $(OLLAMA_MODEL))"
75 | @for script in $(EXAMPLE_DIR)/run_*.py; do \
76 | echo "Running $$script --provider ollama --model-name $(OLLAMA_MODEL)"; \
77 | $(POETRY) run python $$script --provider ollama --model-name $(OLLAMA_MODEL); \
78 | done
79 |
80 | .PHONY: docs
81 | docs: ## Generate the project documentation
82 | $(POETRY) run mkdocs build
83 |
84 | # All-in-one target
85 | .PHONY: all
86 | all: install lint typecheck test build ## Install Python dependencies, run lint, typecheck, tests, and build the library
87 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Cogitator
8 |
9 | [](https://github.com/habedi/cogitator/actions/workflows/tests.yml)
10 | [](https://codecov.io/gh/habedi/cogitator)
11 | [](https://www.codefactor.io/repository/github/habedi/cogitator)
12 | [](https://github.com/habedi/cogitator)
13 | [](https://pypi.org/project/cogitator)
14 | [](https://github.com/habedi/cogitator)
15 |
16 | [](https://github.com/habedi/cogitator/blob/main/LICENSE)
17 | [](https://habedi.github.io/cogitator)
18 | [](https://doi.org/10.5281/zenodo.15331821)
19 |
20 | A Python Toolkit for Chain-of-Thought Prompting
21 |
22 |
23 |
24 | ---
25 |
26 | Cogitator is a Python toolkit for experimenting with
27 | [chain-of-thought (CoT) prompting](https://arxiv.org/abs/2201.11903)
28 | methods in large language models (LLMs).
29 | CoT prompting improves LLM performance on complex tasks (like question-answering, reasoning, and problem-solving)
30 | by guiding the models to generate intermediate reasoning steps before arriving at the final answer.
31 | Additionally, it can be used to improve the interpretability of LLMs by providing insight into the model's reasoning process.
32 | The toolkit aims to make it easier to use popular CoT strategies and frameworks for research or to integrate them into AI
33 | applications.
34 |
35 | ### Features
36 |
37 | * Provides a unified sync/async API for CoT strategies
38 | * Supports using OpenAI and Ollama as LLM providers
39 | * Supports structured model output with Pydantic validation
40 | * Includes a customizable benchmarking framework (see [benches](benches))
41 | * Includes implementations of popular CoT strategies and frameworks like
42 | - [Self-Consistency CoT (ICLR 2023)](https://arxiv.org/abs/2203.11171)
43 | - [Automatic CoT (ICLR 2023)](https://arxiv.org/abs/2210.03493)
44 | - [Least-to-Most Prompting (ICLR 2023)](https://arxiv.org/abs/2205.10625)
45 | - [Tree of Thoughts (NeurIPS 2023)](https://arxiv.org/abs/2305.10601)
46 | - [Graph of Thoughts (AAAI 2024)](https://arxiv.org/abs/2308.09687)
47 | - [Clustered Distance-Weighted CoT (AAAI 2025)](https://arxiv.org/abs/2501.12226)
48 |
49 | ---
50 |
51 | ### Getting Started
52 |
53 | You can install Cogitator with
54 |
55 | ```bash
56 | pip install cogitator
57 | ```
58 |
59 | Or, if you want to install the latest version from source with the examples and benchmarks included
60 |
61 | ```bash
62 | git clone https://github.com/habedi/cogitator && cd cogitator
63 |
64 | # Set up Python environment
65 | pip install poetry
66 | poetry install --with dev
67 |
68 | # Run the tests to make sure everything is working (optional)
69 | poetry run pytest
70 | ```
71 |
72 | #### Examples
73 |
74 | Below is a simple example of using the Self-Consistency CoT with Ollama.
75 |
76 | ```python
77 | import logging
78 | from cogitator import SelfConsistency, OllamaLLM
79 |
80 | # Step 1: Configure logging (optional, but helpful)
81 | logging.basicConfig(level=logging.INFO)
82 | logging.getLogger("httpx").setLevel(logging.WARNING) # Suppress HTTPX logs
83 |
84 | # Step 2: Initialize the LLM (using Ollama)
85 | # Needs Ollama running locally with the model pulled (e.g., `ollama pull gemma3:4b`)
86 | try:
87 | llm = OllamaLLM(model="gemma3:4b")
88 | except Exception as e:
89 | print(f"Error initializing Ollama LLM: {e}")
90 | print("Please make sure Ollama is running and the model is pulled.")
91 | exit(1)
92 |
93 | # Step 3: Choose a CoT strategy (Self-Consistency in this case)
94 | # Self-Consistency generates multiple reasoning paths and finds the most common answer
95 | sc_strategy = SelfConsistency(
96 | llm,
97 | n_samples=5, # Number of reasoning paths to generate
98 | temperature=0.7 # Higher temperature can lead to more diverse answers
99 | )
100 |
101 | # Step 4: Define the prompt (with a basic CoT trigger)
102 | question = "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?"
103 | prompt = f"Q: {question}\nA: Let's think step by step."
104 |
105 | # Step 5: Run the CoT prompting strategy
106 | print(f"\nQuestion: {question}")
107 | print("Running Self-Consistency CoT...")
108 | final_answer = sc_strategy.run(prompt) # Returns the most consistent (repeated) answer
109 |
110 | # Expected output: $0.05 or 0.05 (may vary slightly based on model and temperature)
111 | print(f"\nCogitator's Answer (Self-Consistency): {final_answer}")
112 | ```
113 |
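The strategies also expose an asynchronous interface, which the benchmark runner uses to issue many LLM calls concurrently. The sketch below assumes the async entry point is named `run_async` and mirrors `run`; treat the method name and signature as an assumption and check the API documentation before relying on it.

```python
# Hedged sketch of async usage; `run_async` is assumed to mirror the sync `run` shown above.
import asyncio

from cogitator import OllamaLLM, SelfConsistency


async def main() -> None:
    llm = OllamaLLM(model="gemma3:4b")  # needs a running Ollama server with the model pulled
    sc = SelfConsistency(llm, n_samples=5, temperature=0.7)
    question = "A notebook costs $2 and a pen costs $1. How much do 3 notebooks and 2 pens cost?"
    # Await the strategy instead of blocking, which is useful inside async applications.
    answer = await sc.run_async(f"Q: {question}\nA: Let's think step by step.")
    print(answer)


asyncio.run(main())
```
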
114 | Check out the [examples](examples) directory for more examples.
115 |
116 | ---
117 |
118 | ### Documentation
119 |
120 | Cogitator documentation is available [here](https://habedi.github.io/cogitator).
121 |
122 | ---
123 |
124 | ### Benchmarking Framework
125 |
126 | This project includes a customizable and extensible benchmarking framework to evaluate the performance of different
127 | CoT strategies on various datasets like [GSM8K](https://arxiv.org/abs/2110.14168) and
128 | [StrategyQA](https://arxiv.org/abs/2101.02235).
129 |
130 | Check out the [benches](benches) directory for more details about the framework and how it could be used.
131 |
132 | ---
133 |
134 | ### Contributing
135 |
136 | See [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to make a contribution.
137 |
138 | ### Citations
139 |
140 | If you find this project useful, please give it a star!
141 | If you have any questions or feedback, please use the discussion section of the repository or open an issue.
142 | If you use this project in your research, please consider citing using the following information:
143 |
144 | ```bibtex
145 | @software{abedi_cogitator_2025,
146 | author = {Abedi Firouzjaei, Hassan},
147 | title = {{Cogitator: A Python Toolkit for Chain-of-Thought Prompting}},
148 | year = {2025--},
149 | publisher = {Zenodo},
150 | doi = {10.5281/zenodo.15331821},
151 | url = {https://github.com/habedi/cogitator}
152 | }
153 | ```
154 |
155 | ### Logo
156 |
157 | The logo is named "Cognition" and was originally created by [vectordoodle](https://www.svgrepo.com/author/vectordoodle).
158 |
159 | ### License
160 |
161 | Cogitator is licensed under the [MIT License](LICENSE).
162 |
--------------------------------------------------------------------------------
/benches.yml:
--------------------------------------------------------------------------------
1 | common:
2 | debug: false
3 | openai_key_env_var: "OPENAI_API_KEY"
4 |
5 | generation:
6 | dataset: "gsm8k"
7 | cutoff: 50
8 | provider: "ollama"
9 | model_name: "gemma3:4b"
10 | ollama_host: null
11 | use_async: true
12 | concurrency: 3
13 | use_json_strategies: false
14 | output_file: "benchmark_results.jsonl"
15 | llm_params:
16 | max_tokens: 2048
17 | seed: 33
18 | temperature: 0.7
19 |
20 | evaluation:
21 | input_file: "benchmark_results.jsonl"
22 | extractor:
23 | type: "heuristic"
24 | provider: "ollama"
25 | model_name: "gemma3:12b"
26 | ollama_host: null
27 | llm_params:
28 | max_tokens: 64
29 | seed: 42
30 | temperature: 0.1
31 | show_details: false
32 | concurrency: 3
33 |
34 | strategies:
35 | ZeroShotCoT:
36 | enabled: true
37 |
38 | AutoCoT:
39 | enabled: true
40 | n_demos: 5
41 | max_q_tokens: 100
42 | max_steps: 8
43 | max_retries: 3
44 | prompt_template: "Let's think step-by-step."
45 |
46 | CDWCoT:
47 | enabled: true
48 | pool_size: 10
49 | n_clusters: 5
50 | lr: 0.1
51 | sample_size: 10
52 | max_grad_norm: 1.0
53 | init_pool_retries: 1
54 | train_params:
55 | epochs: 5
56 | patience: 2
57 | val_split: 0.2
58 |
59 | SelfConsistency:
60 | enabled: true
61 | n_samples: 10
62 | temperature: 0.8
63 | stop: null
64 | internal_extraction_format: "heuristic"
65 |
66 | LeastToMost:
67 | enabled: true
68 | max_subqs: 10
69 | intermediate_output_format: "text"
70 |
71 | TreeOfThoughts:
72 | enabled: true
73 | max_depth: 3
74 | num_branches: 3
75 | sims: 10
76 | c_puct: 1.0
77 |
78 | GraphOfThoughts:
79 | enabled: true
80 | graph_of_operations:
81 | - [ "Generate", { k: 3, target_set: "frontier", output_set: "generated_1", prompt_key: "expand" } ]
82 | - [ "Score", { target_set: "generated_1", prompt_key: "evaluate" } ]
83 | - [ "KeepBest", { N: 1, target_set: "generated_1", output_set: "frontier" } ]
84 | - [ "Generate", { k: 2, target_set: "frontier", output_set: "generated_2", prompt_key: "expand" } ]
85 | - [ "Score", { target_set: "generated_2", prompt_key: "evaluate" } ]
86 | - [ "KeepBest", { N: 1, target_set: "generated_2", output_set: "frontier" } ]
87 | final_answer_format: "text"
88 |
--------------------------------------------------------------------------------
/benches/README.md:
--------------------------------------------------------------------------------
1 | ## Benchmarks
2 |
3 | Benchmarks are primarily run using the `benches/run.py` script, which handles the generation phase.
4 | A separate script, `benches/evaluate.py`, is used afterward to calculate accuracy from the generated results.
5 |
6 | Run the generation script from the project root directory:
7 |
8 | ```bash
9 | poetry run python benches/run.py [OPTIONS]
10 | ```
11 |
12 | Available Options for `run.py`:
13 |
14 | * `--dataset <name>`: Dataset to use (default: `gsm8k`). Available options: `gsm8k`, `multiarith`, `aqua`, `csqa`,
15 | `strategyqa`, `coin`, and `letter`.
16 | * `--cutoff <n>`: Number of samples to load from the dataset (-1 for all; default: `50`). These samples are used
17 | for strategy setup (if needed) as well as for generation and testing.
18 | * `--provider <name>`: LLM provider (`ollama` or `openai`; default: `ollama`).
19 | * `--model-name <name>`: Model name for the provider (default: `gemma2:9b` for ollama, `gpt-4o-mini` for openai). Verify model
20 | availability.
21 | * `--openai-key <key>`: OpenAI API key (needed for `--provider openai`; can use the `OPENAI_API_KEY` environment variable if
22 | it is set).
23 | * `--use-async`: Use asynchronous execution for LLM calls (default: sync). Highly recommended for speed.
24 | * `--concurrency <n>`: Max concurrent LLM requests when using `--use-async` (default: `3`).
25 | * `--use-json-strategies`: Use JSON mode within strategies where applicable (LtM, GoT, SC) (default: disabled).
26 | * `--output-file <path>`: File to save raw generation results in JSONL format (default: `benchmark_results.jsonl`).
27 | * `--debug`: Enable debug logging for more verbose output.
28 |
29 | Check out OpenAI's [API documentation](https://platform.openai.com/docs/api-reference) for more details on the models
30 | and their capabilities. Use `ollama list` to see the available models for the `ollama` provider.
31 |
32 | ### Benchmark Workflow (`run.py`)
33 |
34 | The `run.py` script executes the following steps:
35 |
36 | 1. **Configuration:** Loads settings from `benches.yml` and merges them with command-line arguments (CLI > YAML > Defaults).
37 | 2. **Dataset Loading:** Loads the specified dataset subset based on the final configuration.
38 | 3. **Model & CoT Strategies:** Sets up the language model and instances of enabled CoT strategies, configured according to
39 | `benches.yml`.
40 | 4. **Setup Phase (One-Time Cost per Run):** Before generating answers for the test questions, strategies that need fitting
41 | or training (like AutoCoT and CDWCoT) perform this step *once* using the loaded dataset samples and configured
42 | parameters.
43 | 5. **Generation Phase (Per Question):** The script iterates through each loaded question:
44 | * For each question, it executes all *enabled* CoT strategies using their configured parameters.
45 | * If run synchronously, strategies execute one after another. If async, calls run concurrently.
46 | * The raw text output from the LLM and the execution time are recorded.
47 | 6. **Output:** Results are saved line-by-line in JSONL format to the specified output file.
48 |
49 | See `poetry run python benches/run.py --help` to see all available options.
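Each line of the output file is a self-contained JSON record, so results can be inspected with a few lines of Python. The field names below (`strategy`, `question_id`, `raw_output`) are illustrative guesses, not the exact schema written by `run.py`; check a generated file for the real keys.

```python
# Rough sketch for skimming the raw generation results (field names are assumptions).
import json

with open("benchmark_results.jsonl", encoding="utf-8") as fh:
    for line in fh:
        record = json.loads(line)
        # Print a compact summary of each record.
        print(record.get("strategy"), record.get("question_id"), str(record.get("raw_output"))[:80])
```
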
50 |
51 | ### Evaluation (`evaluate.py`)
52 |
53 | After `run.py` generates the result file, use `evaluate.py` to calculate metrics:
54 |
55 | ```bash
56 | poetry run python benches/evaluate.py --input-file <results_file.jsonl> [EVAL_OPTIONS]
57 | ```
58 |
59 | This script reads the JSONL file, loads its configuration from `benches.yml`, merges it with CLI options, extracts the
60 | final answer from the raw model output (using the configured extractor type: heuristic or LLM), compares it to the
61 | correct answer, and calculates the accuracy for each CoT strategy present in the result file. It then shows a summary
62 | table. See `poetry run python benches/evaluate.py --help` for evaluation-specific options.
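The heuristic extractors used in this step live in `benches/extractors.py`. As a rough illustration of what the heuristic path does (run from the project root), the snippet below is a simplified sketch rather than the actual evaluation loop:

```python
# Simplified sketch of heuristic answer extraction, not the real evaluate.py logic.
from benches.extractors import extract_answer_heuristic_custom

raw_output = "The bat costs $1.05 and the ball costs $0.05.\n#### 0.05"
predicted = extract_answer_heuristic_custom(raw_output, dataset_name="gsm8k")
print(predicted)  # "0.05", or "[EXTRACTION_HEURISTIC_FAILURE]" if no pattern matches
gold = "0.05"
print(predicted == gold)  # True: a correct prediction for this record
```
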
63 |
64 | ### Configuration (`benches.yml`)
65 |
66 | Benchmark runs are configured using `benches.yml` in the project root, combined with command-line arguments.
67 | **Configuration Precedence:**
68 |
69 | 1. **Command-Line Arguments:** Highest priority (e.g., `--dataset`, `--provider`).
70 | 2. **`benches.yml`:** Values from this file are used if not specified via CLI.
71 | 3. **Code Defaults:** Lowest priority, used if not set in CLI or YAML.
72 |
73 | **YAML Structure:**
74 |
75 | * **`common`**: Shared settings.
76 | * `debug`: `true` or `false` for verbose logging.
77 | * `openai_key_env_var`: Name of the environment variable holding the OpenAI key.
78 | * **`generation`**: Settings for the generation script (`run.py`).
79 | * `dataset`: Name of the dataset (e.g., `gsm8k`).
80 | * `cutoff`: Max number of samples to use (-1 for all).
81 | * `provider`: `ollama` or `openai`.
82 | * `model_name`: Specific model for the provider.
83 | * `ollama_host`: (**Optional**) Specify the host address for the Ollama server (e.g., `http://192.168.1.100:11434`). If `null`
84 | or omitted, uses `OLLAMA_HOST` env var or defaults to `http://localhost:11434`.
85 | * `use_async`: `true` to run LLM calls concurrently.
86 | * `concurrency`: Max parallel requests for async runs.
87 | * `use_json_strategies`: `true` or `false`. Default for strategies supporting JSON output (can be overridden per strategy).
88 | * `output_file`: Path to save raw results (JSONL).
89 | * `llm_params`: Global LLM settings (`max_tokens`, `seed`, `temperature`) applied unless overridden per strategy.
90 | * **`evaluation`**: Settings for the evaluation script (`evaluate.py`).
91 | * `input_file`: Path to the results JSONL file (defaults to `generation.output_file`).
92 | * `extractor`: Configures how final answers are extracted.
93 | * `type`: `heuristic` or `llm`.
94 | * `provider`, `model_name`: Settings for the LLM extractor if `type` is `llm`.
95 | * `ollama_host`: Specify the Ollama host for the extractor LLM, if using `type: llm` and
96 | `provider: ollama`. Defaults apply if null/omitted.
97 | * `llm_params`: Settings for the LLM extractor if `type` is `llm`.
98 | * `show_details`: `true` to print per-question evaluation details.
99 | * `concurrency`: Max parallel requests for the LLM extractor.
100 | * **`strategies`**: Configure individual CoT strategies.
101 | * Each key is the strategy's class name (e.g., `AutoCoT`).
102 | * Including a section enables the strategy by default. Add `enabled: false` to disable.
103 | * Set strategy-specific parameters (e.g., `n_demos`, `pool_size`, `n_samples`, `max_depth`).
104 | * Strategy-specific LLM parameters (like `temperature` for `SelfConsistency`) or format choices (`internal_extraction_format`,
105 | `intermediate_output_format`, `final_answer_format`) set here override global settings from the `generation` section for
106 | that specific strategy.
107 |
108 | See the example `benches.yml` in the repository for detailed options.
109 |
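To make the precedence concrete, the snippet below sketches how CLI values override YAML values, which in turn override code defaults. It is an illustration of the rule, not the actual merging code in `benches/run.py`.

```python
# Illustrative sketch of the CLI > YAML > defaults precedence (not the real implementation).
import yaml

code_defaults = {"dataset": "gsm8k", "provider": "ollama", "cutoff": 50}

with open("benches.yml", encoding="utf-8") as fh:
    yaml_cfg = yaml.safe_load(fh).get("generation", {}) or {}

cli_overrides = {"dataset": "aqua"}  # e.g. the user passed --dataset aqua

merged = {**code_defaults, **{k: v for k, v in yaml_cfg.items() if v is not None}, **cli_overrides}
print(merged["dataset"])  # "aqua": the CLI value wins over YAML and the code default
```
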
110 | > [!NOTE]
111 | > For parameters like OpenAI keys, it is recommended to specify the *environment variable name* in `benches.yml` (e.g.,
112 | `openai_key_env_var: "MY_API_KEY"`) rather than pasting the key directly into the file.
113 | > The scripts will then read the key from the specified environment variable.
114 | > You can still override this by passing `--openai-key` on the command line.
115 |
116 | ### Example Usages (`run.py`)
117 |
118 | * Run using configuration defined in `benches.yml`:
119 | ```bash
120 | # Assumes benches.yml is configured as desired
121 | poetry run python benches/run.py
122 | ```
123 |
124 | * Run using `benches.yml` but override the dataset via command line:
125 | ```bash
126 | poetry run python benches/run.py --dataset aqua
127 | ```
128 |
129 | ### Dependencies
130 |
131 | To run the benchmarks, install the development dependencies along with Cogitator itself.
132 |
133 | ```bash
134 | poetry install --with dev
135 | ```
136 |
137 | Additionally, any model used in the benchmarks must be available.
138 | Make sure the Ollama server is **running** and pull the desired models using `ollama pull <model_name>`.
139 | Make sure the OpenAI key is set correctly if using the OpenAI models.
140 |
141 | ### More Examples
142 |
143 | ```bash
144 | # Run using benches.yml (assuming it's configured for Ollama, gemma2:9b, aqua, async, etc.)
145 | poetry run python benches/run.py --output-file my_ollama_results.jsonl
146 |
147 | # Evaluate the results using heuristic (as configured in benches.yml or default)
148 | poetry run python benches/evaluate.py --input-file my_ollama_results.jsonl --show-details
149 |
150 | # Evaluate the results using LLM extractor (override benches.yml extractor setting)
151 | poetry run python benches/evaluate.py --extractor-type llm --provider ollama --model-name llama3 --input-file my_ollama_results.jsonl
152 |
153 | # Run specifically with OpenAI, overriding YAML if necessary
154 | poetry run python benches/run.py --provider openai --model-name gpt-4o-mini --dataset csqa --cutoff 10 --use-async --output-file my_openai_results.jsonl
155 |
156 | # Evaluate the OpenAI results using an OpenAI extractor model
157 | poetry run python benches/evaluate.py --input-file my_openai_results.jsonl --extractor-type llm --provider openai --model-name gpt-4o-mini
158 | ```
159 |
160 | ## Performance Metric
161 |
162 | Accuracy is the primary metric reported by the `evaluate.py` script.
163 | It is defined as the percentage of correctly answered questions out of the total number of successfully extracted answers for a
164 | given CoT strategy.
165 |
166 | > [!NOTE]
167 | > This definition means accuracy reflects performance *only* on runs where the final answer could be successfully extracted.
168 | > Runs resulting in extraction errors (e.g., the extractor fails to find an answer in the raw output) are excluded from the
169 | > accuracy calculation, which is important when comparing strategies with different extraction success rates.
170 |
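In other words, the denominator only counts records whose answers were extracted successfully. A small sketch of that definition (the real `evaluate.py` implementation and its field names may differ):

```python
# Sketch of the accuracy definition: extraction failures are dropped from the denominator.
def strategy_accuracy(records: list[dict]) -> float:
    scored = [r for r in records if not str(r["predicted"]).startswith("[EXTRACTION")]
    if not scored:
        return 0.0
    correct = sum(1 for r in scored if r["predicted"] == r["gold"])
    return 100.0 * correct / len(scored)


print(strategy_accuracy([
    {"predicted": "0.05", "gold": "0.05"},
    {"predicted": "[EXTRACTION_HEURISTIC_FAILURE]", "gold": "7"},
]))  # 100.0 — the failed extraction is excluded rather than counted as wrong
```
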
171 | ## Datasets
172 |
173 | The following datasets can be used as values for the `--dataset` argument of the `run.py` script.
174 |
175 | | Dataset Name | Source Link | Category Tags | Description |
176 | |:-------------|:-------------------------------------------------------------------------------------|:--------------------------|:--------------------------------------------------|
177 | | `gsm8k` | [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) | `math` | Grade school math word problems |
178 | | `multiarith` | [ChilleD/MultiArith](https://huggingface.co/datasets/ChilleD/MultiArith) | `math` | Multi-step arithmetic problems |
179 | | `aqua` | [deepmind/aqua_rat](https://huggingface.co/datasets/deepmind/aqua_rat) | `math` | Algebraic word problems with rationales |
180 | | `csqa` | [tau/commonsense_qa](https://huggingface.co/datasets/tau/commonsense_qa) | `commonsense` | Multiple-choice commonsense questions |
181 | | `strategyqa` | [ChilleD/StrategyQA](https://huggingface.co/datasets/ChilleD/StrategyQA) | `commonsense`, `symbolic` | Yes and no questions requiring implicit reasoning |
182 | | `coin` | [skrishna/coin_flip](https://huggingface.co/datasets/skrishna/coin_flip) | `symbolic` | Symbolic tasks involving state tracking |
183 | | `letter` | [ChilleD/LastLetterConcat](https://huggingface.co/datasets/ChilleD/LastLetterConcat) | `symbolic`, `text` | Extract and concatenate last letters from words |
184 |
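All of these datasets are hosted on the Hugging Face Hub. The benchmark's own loading code lives in `benches/shared.py`, but as a rough sketch you can pull one of them directly with the `datasets` library:

```python
# Minimal sketch of loading GSM8K directly from the Hugging Face Hub.
from datasets import load_dataset

gsm8k = load_dataset("openai/gsm8k", "main", split="test")
example = gsm8k[0]
print(example["question"])
print(example["answer"])  # reference solutions end with a line like "#### 18"
```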
--------------------------------------------------------------------------------
/benches/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/benches/__init__.py
--------------------------------------------------------------------------------
/benches/extractors.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from typing import Optional
4 |
5 | logger_extractors = logging.getLogger("benchmark_extractors")
6 |
7 |
8 | def _extract_label_heuristic(text: str) -> Optional[str]:
9 | text = str(text).strip()
10 | lines = text.strip().splitlines()
11 |
12 | mcq_patterns = [
13 | r'(?:final answer|answer|choice|option) is\s*:?\s*\(([A-Ea-e])\)',
14 | r'(?:final answer|answer|choice|option) is\s*:?\s*([A-Ea-e])\b',
15 | r'\(([A-Ea-e])\)\s*is the correct answer',
16 | r'\b([A-Ea-e])\)\s*$',
17 | r'^([A-Ea-e])\)$',
18 | r'correct label \(character before \'\)\'\)\n([A-Ea-e])',
19 | r'\b(?:Answer|Choice|Option):\s*([A-Ea-e])\b',
20 | ]
21 |
22 | if lines:
23 | last_line = lines[-1].strip()
24 | for pattern in mcq_patterns:
25 | match = re.search(pattern, last_line, re.IGNORECASE)
26 | if match:
27 | ans = match.group(1).upper()
28 | logger_extractors.debug(f"Extracted MCQ label '{ans}' from last line")
29 | return ans
30 | if re.fullmatch(r'[A-Ea-e]', last_line):
31 | ans = last_line.upper()
32 | logger_extractors.debug(f"Extracted MCQ label '{ans}' as last line content")
33 | return ans
34 |
35 | for pattern in mcq_patterns:
36 | match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
37 | if match:
38 | ans = match.group(1).upper()
39 | logger_extractors.debug(f"Extracted MCQ label '{ans}' from full text")
40 | return ans
41 |
42 | return None
43 |
44 |
45 | def _extract_numerical_heuristic(text: str) -> Optional[str]:
46 | text = str(text).strip()
47 | lines = text.strip().splitlines()
48 |
49 | gsm_pattern = r'####\s*([+-]?\d+(?:,\d+)*\.?\d*)'
50 | boxed_pattern = r'\\boxed\{([+-]?\d+(?:,\d+)*\.?\d*)\}'
51 |
52 | match = re.search(gsm_pattern, text)
53 | if match:
54 | extracted_num = match.group(1).replace(",", "")
55 | logger_extractors.debug(f"Extracted numerical answer '{extracted_num}' using GSM pattern")
56 | return extracted_num
57 |
58 | match = re.search(boxed_pattern, text)
59 | if match:
60 | extracted_num = match.group(1).replace(",", "")
61 | logger_extractors.debug(f"Extracted numerical answer '{extracted_num}' using boxed pattern")
62 | return extracted_num
63 |
64 | num_pattern_loose = r'[+-]?\d+(?:,\d+)*(?:\.\d+)?'
65 | numerical_patterns = [
66 | r'(?:final answer is|the final answer is|final answer:|answer:|the answer is)\s*:?\s*(' + num_pattern_loose + r')\b',
67 | r'(?:is|equals|result is|=)\s*(' + num_pattern_loose + r')\s*\.?\s*$',
68 | ]
69 |
70 | if lines:
71 | last_line = lines[-1].strip()
72 | for pattern in numerical_patterns:
73 | match = re.search(pattern, last_line, re.IGNORECASE)
74 | if match:
75 | extracted_num = match.group(1).replace(",", "")
76 | logger_extractors.debug(
77 | f"Extracted numerical answer '{extracted_num}' from last line pattern")
78 | return extracted_num
79 | if re.fullmatch(num_pattern_loose, last_line.replace("$", "")):
80 | extracted_num = last_line.replace(",", "").replace("$", "")
81 | logger_extractors.debug(
82 | f"Extracted numerical answer '{extracted_num}' as last line content")
83 | return extracted_num
84 |
85 | for pattern in numerical_patterns:
86 | match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
87 | if match:
88 | extracted_num = match.group(1).replace(",", "")
89 | logger_extractors.debug(
90 | f"Extracted numerical answer '{extracted_num}' from full text pattern")
91 | return extracted_num
92 |
93 | numbers = re.findall(num_pattern_loose, text)
94 | if numbers:
95 | last_num_str = numbers[-1].replace(",", "")
96 | logger_extractors.debug(f"Extracted numerical answer '{last_num_str}' as last number found")
97 | return last_num_str
98 |
99 | return None
100 |
101 |
102 | def _extract_boolean_heuristic(text: str) -> Optional[str]:
103 | text = str(text).strip().lower()
104 | lines = text.strip().splitlines()
105 |
106 | if lines:
107 | last_line = lines[-1].strip().lower().rstrip(".")
108 | if last_line == "yes": return "yes"
109 | if last_line == "no": return "no"
110 | if last_line == "true": return "yes"
111 | if last_line == "false": return "no"
112 | match = re.search(r'\b(?:answer|result) is\s+(yes|no|true|false)\b', last_line)
113 | if match:
114 | ans = match.group(1)
115 | logger_extractors.debug(f"Extracted boolean '{ans}' from last line pattern")
116 | return "yes" if ans == "true" else "no" if ans == "false" else ans
117 |
118 | match = re.search(r'\b(?:final answer|answer|result) is\s+(yes|no|true|false)\b', text)
119 | if match:
120 | ans = match.group(1)
121 | logger_extractors.debug(f"Extracted boolean '{ans}' from full text pattern")
122 | return "yes" if ans == "true" else "no" if ans == "false" else ans
123 |
124 | if "yes" in text.split()[-5:]: return "yes"
125 | if "no" in text.split()[-5:]: return "no"
126 |
127 | return None
128 |
129 |
130 | def _extract_letter_concat_heuristic(text: str) -> Optional[str]:
131 | text = str(text).strip()
132 | lines = text.strip().splitlines()
133 |
134 | pattern = r'(?:final answer|answer|result) is\s*:?\s*([a-zA-Z]+)\b'
135 |
136 | if lines:
137 | last_line = lines[-1].strip().rstrip(".")
138 | match = re.search(pattern, last_line, re.IGNORECASE)
139 | if match:
140 | ans = match.group(1)
141 | logger_extractors.debug(f"Extracted letter concat '{ans}' from last line pattern")
142 | return ans
143 | if re.fullmatch(r'[a-zA-Z]+', last_line):
144 | logger_extractors.debug(f"Extracted letter concat '{last_line}' as last line content")
145 | return last_line
146 |
147 | match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
148 | if match:
149 | ans = match.group(1)
150 | logger_extractors.debug(f"Extracted letter concat '{ans}' from full text pattern")
151 | return ans
152 |
153 | if lines and re.fullmatch(r'[a-zA-Z]{2,}', lines[-1].strip()):
154 | logger_extractors.debug(
155 | f"Extracted letter concat '{lines[-1].strip()}' as fallback last line")
156 | return lines[-1].strip()
157 |
158 | return None
159 |
160 |
161 | def extract_answer_heuristic_custom(raw_output: str, dataset_name: str) -> str:
162 | if not raw_output or str(raw_output).strip() == "[ERROR]" or str(raw_output).strip().startswith(
163 | "[ERROR:"):
164 | return "[ERROR]"
165 |
166 | text = str(raw_output).strip()
167 | extracted: Optional[str] = None
168 |
169 | if dataset_name in ["aqua", "csqa"]:
170 | extracted = _extract_label_heuristic(text)
171 | elif dataset_name in ["gsm8k", "multiarith"]:
172 | extracted = _extract_numerical_heuristic(text)
173 | elif dataset_name in ["strategyqa", "coin"]:
174 | extracted = _extract_boolean_heuristic(text)
175 | elif dataset_name == "letter":
176 | extracted = _extract_letter_concat_heuristic(text)
177 | else:
178 | logger_extractors.warning(
179 | f"No specific heuristic defined for dataset '{dataset_name}'. Using generic fallback.")
180 | extracted = _extract_numerical_heuristic(text)
181 | if extracted is None:
182 | extracted = _extract_label_heuristic(text)
183 |
184 | if extracted is None:
185 | logger_extractors.warning(
186 | f"Could not extract answer for dataset '{dataset_name}' using custom heuristic from: '{text[:150]}...'")
187 | return "[EXTRACTION_HEURISTIC_FAILURE]"
188 |
189 | return extracted.strip()
190 |
191 |
192 | MCQ_EXTRACTION_PROMPT = """
193 | Original Question:
194 | {question}
195 |
196 | LLM Reasoning and Output:
197 | \"\"\"
198 | {raw_output}
199 | \"\"\"
200 |
201 | Analyze the LLM Reasoning and Output based on the Original Question.
202 | The original question is multiple choice with options typically labeled A, B, C, D, E.
203 | Extract ONLY the single capital letter corresponding to the final choice identified in the reasoning or output.
204 | If the reasoning calculates a value, ensure you extract the letter label associated with that value in the options.
205 | If no specific choice label is clearly identified as the final answer, return null.
206 | Return the result as a JSON object with a single key "final_answer" containing the final answer label (A, B, C, D, or E) as a string, or null if not found.
207 |
208 | JSON Output:
209 | """
210 |
211 | NUMERICAL_EXTRACTION_PROMPT = """
212 | Original Question:
213 | {question}
214 |
215 | LLM Reasoning and Output:
216 | \"\"\"
217 | {raw_output}
218 | \"\"\"
219 |
220 | Analyze the LLM Reasoning and Output based on the Original Question.
221 | Extract ONLY the final numerical answer stated in the text. Ignore intermediate calculations.
222 | Look for patterns like "Final Answer: [number]", "#### [number]", or the last number mentioned if it seems conclusive.
223 | Return the result as a JSON object with a single key "final_answer" containing the final numerical answer as a string (e.g., "15", "36.5", "77.5"), or null if no definitive numerical answer is found.
224 |
225 | JSON Output:
226 | """
227 |
228 | BOOLEAN_EXTRACTION_PROMPT = """
229 | Original Question:
230 | {question}
231 |
232 | LLM Reasoning and Output:
233 | \"\"\"
234 | {raw_output}
235 | \"\"\"
236 |
237 | Analyze the LLM Reasoning and Output based on the Original Question.
238 | Extract ONLY the final boolean answer (yes or no) stated in the text.
239 | Look for explicit statements like "Answer: yes", "Final Answer: no", or the concluding yes/no statement.
240 | Return the result as a JSON object with a single key "final_answer" containing either the string "yes" or "no", or null if no definitive boolean answer is found.
241 |
242 | JSON Output:
243 | """
244 |
245 | TEXT_EXTRACTION_PROMPT = """
246 | Original Question:
247 | {question}
248 |
249 | LLM Reasoning and Output:
250 | \"\"\"
251 | {raw_output}
252 | \"\"\"
253 |
254 | Analyze the LLM Reasoning and Output based on the Original Question.
255 | Extract ONLY the final short text answer stated in the text (e.g., a concatenated string of letters for the 'letter' dataset).
256 | Look for patterns like "Answer: [text]" or the concluding text segment if it seems to be the final answer.
257 | Return the result as a JSON object with a single key "final_answer" containing the final text answer as a string, or null if no definitive text answer is found.
258 |
259 | JSON Output:
260 | """
261 |
262 |
263 | def get_llm_extraction_prompt(question: str, raw_output: str, dataset_name: str) -> str:
264 | template: str
265 | if dataset_name in ["aqua", "csqa"]:
266 | template = MCQ_EXTRACTION_PROMPT
267 | elif dataset_name in ["gsm8k", "multiarith"]:
268 | template = NUMERICAL_EXTRACTION_PROMPT
269 | elif dataset_name in ["strategyqa", "coin"]:
270 | template = BOOLEAN_EXTRACTION_PROMPT
271 | elif dataset_name == "letter":
272 | template = TEXT_EXTRACTION_PROMPT
273 | else:
274 | logger_extractors.warning(
275 | f"No specific LLM prompt template defined for dataset '{dataset_name}'. Using generic numerical fallback.")
276 | template = NUMERICAL_EXTRACTION_PROMPT
277 |
278 | return template.format(question=question, raw_output=raw_output)
279 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | ignore:
2 | - "examples/*"
3 | - "benches/*"
4 |
--------------------------------------------------------------------------------
/cogitator/__init__.py:
--------------------------------------------------------------------------------
1 | """Cogitator: A Python Toolkit for Chain-of-Thought Prompting.
2 |
3 | This package provides implementations of various chain-of-thought (CoT) prompting
4 | strategies and frameworks, along with supporting utilities like LLM provider interfaces,
5 | embedding models, clustering algorithms, and data validation schemas.
6 | It aims to make it easier to try and integrate CoT methods into AI applications.
7 | """
8 |
9 | import importlib.metadata
10 | import logging
11 |
12 | from .clustering import BaseClusterer, KMeansClusterer
13 | from .embedding import BaseEmbedder, SentenceTransformerEmbedder
14 | from .model import BaseLLM, OllamaLLM, OpenAILLM
15 | from .schemas import (
16 | EvaluationResult,
17 | ExtractedAnswer,
18 | LTMDecomposition,
19 | ThoughtExpansion,
20 | )
21 | from .strategies import (
22 | AutoCoT,
23 | CDWCoT,
24 | GraphOfThoughts,
25 | LeastToMost,
26 | SelfConsistency,
27 | TreeOfThoughts,
28 | )
29 | from .utils import accuracy, approx_token_length, count_steps, exact_match
30 |
31 | _logger = logging.getLogger(__name__)
32 | try:
33 | __version__ = importlib.metadata.version("cogitator")
34 | except importlib.metadata.PackageNotFoundError:
35 | __version__ = "0.0.0-unknown"
36 | _logger.warning(
37 | "Could not determine package version using importlib.metadata. "
38 | "Is the library installed correctly?"
39 | )
40 |
41 | __all__ = [
42 | "AutoCoT",
43 | "BaseClusterer",
44 | "BaseEmbedder",
45 | "BaseLLM",
46 | "CDWCoT",
47 | "EvaluationResult",
48 | "ExtractedAnswer",
49 | "GraphOfThoughts",
50 | "KMeansClusterer",
51 | "LTMDecomposition",
52 | "LeastToMost",
53 | "OllamaLLM",
54 | "OpenAILLM",
55 | "SelfConsistency",
56 | "SentenceTransformerEmbedder",
57 | "ThoughtExpansion",
58 | "TreeOfThoughts",
59 | "accuracy",
60 | "approx_token_length",
61 | "count_steps",
62 | "exact_match",
63 | ]
64 |
--------------------------------------------------------------------------------
/cogitator/clustering.py:
--------------------------------------------------------------------------------
1 | """Provides abstractions and implementations for clustering algorithms."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any, Tuple
5 |
6 | import numpy as np
7 | from sklearn.cluster import KMeans
8 |
9 |
10 | class BaseClusterer(ABC):
11 | """Abstract base class for clustering algorithms."""
12 |
13 | @abstractmethod
14 | def cluster(
15 | self, embeddings: np.ndarray, n_clusters: int, **kwargs: Any
16 | ) -> Tuple[np.ndarray, np.ndarray]:
17 | """Clusters the given embeddings into a specified number of clusters.
18 |
19 | Args:
20 | embeddings: A NumPy array where each row is an embedding vector.
21 | n_clusters: The desired number of clusters.
22 | **kwargs: Additional keyword arguments specific to the clustering implementation.
23 |
24 | Returns:
25 | A tuple containing:
26 | - A NumPy array of cluster labels assigned to each embedding.
27 | - A NumPy array of cluster centers.
28 | """
29 | ...
30 |
31 |
32 | class KMeansClusterer(BaseClusterer):
33 | """A clustering implementation using the K-Means algorithm from scikit-learn."""
34 |
35 | def cluster(
36 | self, embeddings: np.ndarray, n_clusters: int, **kwargs: Any
37 | ) -> Tuple[np.ndarray, np.ndarray]:
38 | """Clusters embeddings using K-Means.
39 |
40 | Args:
41 | embeddings: The embeddings to cluster (shape: [n_samples, n_features]).
42 | n_clusters: The number of clusters to form.
43 | **kwargs: Additional arguments for `sklearn.cluster.KMeans`.
44 | Supported args include `random_seed` (or `seed`) and `n_init`.
45 |
46 | Returns:
47 | A tuple containing:
48 | - labels (np.ndarray): Integer labels array (shape: [n_samples,]).
49 | - centers (np.ndarray): Coordinates of cluster centers (shape: [n_clusters, n_features]).
50 |
51 | Raises:
52 | ValueError: If `n_clusters` is invalid or embeddings are incompatible.
53 | """
54 | random_seed = kwargs.get("random_seed") or kwargs.get("seed")
55 | n_init = kwargs.get("n_init", "auto")
56 | kmeans = KMeans(
57 | n_clusters=n_clusters,
58 | random_state=random_seed,
59 | n_init=n_init,
60 | init="k-means++",
61 | )
62 | labels = kmeans.fit_predict(embeddings)
63 | return labels, kmeans.cluster_centers_
64 |
--------------------------------------------------------------------------------
/cogitator/embedding.py:
--------------------------------------------------------------------------------
1 | """Provides abstractions and implementations for text embedding models."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import List, Optional
5 |
6 | import numpy as np
7 | from sentence_transformers import SentenceTransformer
8 |
9 |
10 | class BaseEmbedder(ABC):
11 | """Abstract base class for text embedding models."""
12 |
13 | @abstractmethod
14 | def encode(self, texts: List[str]) -> List[np.ndarray]:
15 | """Encodes a list of texts into embedding vectors.
16 |
17 | Args:
18 | texts: A list of strings to encode.
19 |
20 | Returns:
21 | A list of NumPy arrays, where each array is the embedding vector for
22 | the corresponding text.
23 | """
24 | ...
25 |
26 |
27 | class SentenceTransformerEmbedder(BaseEmbedder):
28 | """An embedder implementation using the sentence-transformers library.
29 |
30 | This class uses a singleton pattern to avoid reloading the model multiple times.
31 | """
32 |
33 | _instance: Optional["SentenceTransformerEmbedder"] = None
34 | _model: Optional[SentenceTransformer] = None
35 |
36 | def __new__(cls, model_name: str = "all-MiniLM-L6-v2") -> "SentenceTransformerEmbedder":
37 | """Creates or returns the singleton instance of the embedder.
38 |
39 | Args:
40 | model_name: The name of the sentence-transformer model to load.
41 | This argument is only used during the first instantiation.
42 |
43 | Returns:
44 | The singleton instance of SentenceTransformerEmbedder.
45 | """
46 | if cls._instance is None:
47 | cls._instance = super(SentenceTransformerEmbedder, cls).__new__(cls)
48 | cls._model = SentenceTransformer(model_name)
49 | return cls._instance
50 |
51 | def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
52 | """Initializes the SentenceTransformerEmbedder instance.
53 |
54 | Note: Due to the singleton pattern implemented in `__new__`, the
55 | `model_name` argument here is effectively ignored after the first
56 | instantiation. The model loaded is determined by the `model_name`
57 | passed during the first call to `__new__` or `__init__`.
58 |
59 | Args:
60 | model_name: The name of the sentence-transformer model. Defaults to
61 | "all-MiniLM-L6-v2".
62 | """
63 | pass
64 |
65 | def encode(self, texts: List[str]) -> List[np.ndarray]:
66 | """Encodes a list of texts using the loaded sentence-transformer model.
67 |
68 | Args:
69 | texts: The list of strings to encode.
70 |
71 | Returns:
72 | A list of NumPy ndarray embeddings.
73 |
74 | Raises:
75 | RuntimeError: If the embedding model has not been initialized correctly.
76 | """
77 | if self._model is None:
78 | raise RuntimeError("Embedder model not initialized.")
79 | embeddings: List[np.ndarray] = self._model.encode(
80 | texts, convert_to_numpy=True, show_progress_bar=False
81 | )
82 | return embeddings
83 |
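Usage note (illustrative, not part of `embedding.py`): a minimal sketch of encoding a couple of texts; a second construction reuses the singleton instance, as described in the docstrings above.

```python
from cogitator.embedding import SentenceTransformerEmbedder

embedder = SentenceTransformerEmbedder()  # loads "all-MiniLM-L6-v2" on first use
vectors = embedder.encode(["What is 7 * 6?", "Name a prime number."])
print(len(vectors), vectors[0].shape)  # two embeddings, each a 1-D NumPy vector

# Because of the singleton pattern, a model name passed later is ignored:
assert SentenceTransformerEmbedder("any-other-model-name") is embedder
```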
--------------------------------------------------------------------------------
/cogitator/model/__init__.py:
--------------------------------------------------------------------------------
1 | """Provides interfaces and implementations for LLM providers.
2 |
3 | This subpackage defines the abstract base class `BaseLLM` and implementations for interacting with
4 | different LLM services like OpenAI and Ollama.
5 | """
6 |
7 | from .base import BaseLLM
8 | from .ollama import OllamaLLM
9 | from .openai import OpenAILLM
10 |
11 | __all__ = [
12 | "BaseLLM",
13 | "OllamaLLM",
14 | "OpenAILLM",
15 | ]
16 |
--------------------------------------------------------------------------------
/cogitator/model/base.py:
--------------------------------------------------------------------------------
1 | """Defines the abstract base class for LLM providers."""
2 |
3 | import asyncio
4 | import json
5 | import logging
6 | import re
7 | import time
8 | from abc import ABC, abstractmethod
9 | from typing import Any, AsyncIterator, Iterator, Optional, Tuple, Type
10 |
11 | from pydantic import BaseModel, ValidationError
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class BaseLLM(ABC):
17 | """Abstract base class defining the interface for LLM providers."""
18 |
19 | def __init__(self) -> None:
20 | """Initializes token count storage."""
21 | self._last_prompt_tokens: Optional[int] = None
22 | self._last_completion_tokens: Optional[int] = None
23 |
24 | def get_last_prompt_tokens(self) -> Optional[int]:
25 | """Returns the token count for the last prompt, if available."""
26 | return self._last_prompt_tokens
27 |
28 | def get_last_completion_tokens(self) -> Optional[int]:
29 | """Returns the token count for the last completion, if available."""
30 | return self._last_completion_tokens
31 |
32 | def _reset_token_counts(self) -> None:
33 | """Resets the stored token counts."""
34 | self._last_prompt_tokens = None
35 | self._last_completion_tokens = None
36 |
37 | @abstractmethod
38 | def generate(self, prompt: str, **kwargs: Any) -> str:
39 | """Generates a single text completion for the given prompt.
40 |
41 | Args:
42 | prompt: The input text prompt.
43 | **kwargs: Additional provider-specific parameters (e.g., temperature,
44 | max_tokens, stop sequences, seed).
45 |
46 | Returns:
47 | The generated text completion as a string.
48 |
49 | Raises:
50 | RuntimeError: If the generation fails after retries or due to API errors.
51 | """
52 | ...
53 |
54 | @abstractmethod
55 | async def generate_async(self, prompt: str, **kwargs: Any) -> str:
56 | """Asynchronously generates a single text completion for the given prompt.
57 |
58 | Args:
59 | prompt: The input text prompt.
60 | **kwargs: Additional provider-specific parameters.
61 |
62 | Returns:
63 | The generated text completion as a string.
64 |
65 | Raises:
66 | RuntimeError: If the asynchronous generation fails.
67 | """
68 | ...
69 |
70 | @abstractmethod
71 | def generate_stream(self, prompt: str, **kwargs: Any) -> Iterator[str]:
72 | """Generates a stream of text chunks for the given prompt.
73 |
74 | Args:
75 | prompt: The input text prompt.
76 | **kwargs: Additional provider-specific parameters.
77 |
78 | Yields:
79 | Strings representing chunks of the generated text.
80 |
81 | Raises:
82 | RuntimeError: If starting the stream generation fails.
83 | """
84 | ...
85 |
86 | @abstractmethod
87 | async def generate_stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]:
88 | """Asynchronously generates a stream of text chunks for the given prompt.
89 |
90 | Args:
91 | prompt: The input text prompt.
92 | **kwargs: Additional provider-specific parameters.
93 |
94 | Yields:
95 | Strings representing chunks of the generated text asynchronously.
96 |
97 | Raises:
98 | RuntimeError: If starting the asynchronous stream generation fails.
99 | """
100 | ...
101 |
102 | @abstractmethod
103 | def _generate_json_internal(
104 | self, prompt: str, response_model: Type[BaseModel], **kwargs: Any
105 | ) -> Tuple[str, Optional[str]]:
106 | """Internal method to generate raw JSON output string from the LLM.
107 |
108 | This method handles the actual API call for JSON generation, potentially
109 | using provider-specific features like JSON mode or schema enforcement.
110 | It should also handle updating the internal token counts.
111 |
112 | Args:
113 | prompt: The input prompt, potentially instructing JSON format.
114 | response_model: The Pydantic model class for the expected response structure.
115 | **kwargs: Additional provider-specific parameters.
116 |
117 | Returns:
118 | A tuple containing:
119 | - The raw string response from the LLM, expected to be JSON.
120 | - An optional string indicating the JSON generation mode used (e.g.,
121 | 'json_schema', 'json_object', 'heuristic'), or None if extraction
122 | is needed.
123 |
124 | Raises:
125 | RuntimeError: If the underlying LLM call fails.
126 | """
127 | ...
128 |
129 | @abstractmethod
130 | async def _generate_json_internal_async(
131 | self, prompt: str, response_model: Type[BaseModel], **kwargs: Any
132 | ) -> Tuple[str, Optional[str]]:
133 | """Asynchronous internal method to generate raw JSON output string from the LLM.
134 |
135 | It should also handle updating the internal token counts.
136 |
137 | Args:
138 | prompt: The input prompt, potentially instructing JSON format.
139 | response_model: The Pydantic model class for the expected response structure.
140 | **kwargs: Additional provider-specific parameters.
141 |
142 | Returns:
143 | A tuple containing:
144 | - The raw string response from the LLM, expected to be JSON.
145 | - An optional string indicating the JSON generation mode used.
146 |
147 | Raises:
148 | RuntimeError: If the underlying asynchronous LLM call fails.
149 | """
150 | ...
151 |
152 | def _extract_json_block(self, text: str) -> str:
153 | """Extracts the first JSON object or array from a string.
154 |
155 | Handles JSON enclosed in markdown code fences (```json ... ``` or ``` ... ```)
156 | or finds the first substring starting with '{' and ending with '}' or
157 | starting with '[' and ending with ']'.
158 |
159 | Args:
160 | text: The string possibly containing a JSON block.
161 |
162 | Returns:
163 | The extracted JSON string, or the original text if no block is found.
164 | """
165 | fence_match = re.search(
166 | r"```(?:json)?\s*(\{.*\}|\[.*\])\s*```", text, re.DOTALL | re.IGNORECASE
167 | )
168 | if fence_match:
169 | return fence_match.group(1)
170 |
171 | # Find the first standalone JSON object or array
172 | first_obj_start = text.find("{")
173 | first_arr_start = text.find("[")
174 |
175 | if first_obj_start == -1 and first_arr_start == -1:
176 | return text # No JSON start found
177 |
178 | start_index = -1
179 | if first_obj_start != -1 and first_arr_start != -1:
180 | start_index = min(first_obj_start, first_arr_start)
181 | elif first_obj_start != -1:
182 | start_index = first_obj_start
183 | else: # first_arr_start != -1
184 | start_index = first_arr_start
185 |
186 | # Attempt to find the matching end brace/bracket
187 | # This is a simplified approach and might fail for complex nested structures
188 | # if they appear outside the main intended JSON block.
189 | json_str = text[start_index:]
190 | try:
191 | # Try parsing to find the end implicitly
192 | parsed_obj, end_index = json.JSONDecoder().raw_decode(json_str)
193 | return json_str[:end_index]
194 | except json.JSONDecodeError:
195 | # Fallback: Search for the last brace/bracket if raw_decode fails
196 | # This is less reliable.
197 | last_brace = text.rfind("}")
198 | last_bracket = text.rfind("]")
199 | end_index = max(last_brace, last_bracket)
200 | if end_index > start_index:
201 | potential_json = text[start_index : end_index + 1]
202 | # Final check if this substring is valid JSON
203 | try:
204 | json.loads(potential_json)
205 | return potential_json
206 | except json.JSONDecodeError:
207 | pass # Fall through if this substring isn't valid
208 |
209 | # If parsing/fallback fails, return the original text
210 | return text
211 |
212 | def generate_json(
213 | self, prompt: str, response_model: Type[BaseModel], retries: int = 2, **kwargs: Any
214 | ) -> BaseModel:
215 | """Generates a response and parses it into a Pydantic model instance.
216 |
217 | Uses `_generate_json_internal` and attempts to parse the result.
218 | Retries on validation or decoding errors. Also updates internal token counts.
219 |
220 | Args:
221 | prompt: The input prompt, often instructing the LLM to respond in JSON.
222 | response_model: The Pydantic model class to validate the response against.
223 | retries: The number of times to retry on parsing/validation failure.
224 | **kwargs: Additional provider-specific parameters for generation.
225 |
226 | Returns:
227 | An instance of the `response_model` populated with data from the LLM response.
228 |
229 | Raises:
230 | RuntimeError: If parsing fails after all retries.
231 | ValidationError: If the final response does not match the `response_model`.
232 | json.JSONDecodeError: If the final response is not valid JSON.
233 | """
234 | last_error = None
235 | temp = kwargs.pop("temperature", 0.1)
236 | json_kwargs = {**kwargs, "temperature": temp}
237 | self._reset_token_counts() # Reset before attempts
238 |
239 | for attempt in range(retries + 1):
240 | raw = ""
241 | block = ""
242 | mode_used = None
243 | try:
244 | # _generate_json_internal is responsible for updating token counts
245 | raw, mode_used = self._generate_json_internal(prompt, response_model, **json_kwargs)
246 |
247 | if mode_used in ["json_schema", "json_object", "ollama_schema_format"]:
248 | # Assume the provider handled JSON enforcement
249 | block = raw
250 | else:
251 | # Fallback to extracting JSON block heuristically
252 | block = self._extract_json_block(raw)
253 |
254 | validated_model = response_model.model_validate_json(block.strip())
255 | # Token counts should have been set by _generate_json_internal
256 | return validated_model
257 | except (json.JSONDecodeError, ValidationError) as ve:
258 | last_error = ve
259 | logger.warning(
260 | "JSON validation/decode error %d/%d (mode: %s): %s\nBlock: %.200s\nRaw: %.200s",
261 | attempt + 1,
262 | retries + 1,
263 | mode_used,
264 | ve,
265 | block,
266 | raw,
267 | )
268 | self._reset_token_counts() # Reset counts on error
269 | except Exception as e:
270 | last_error = e
271 | logger.error(
272 | "Error generating JSON %d/%d (mode: %s): %s",
273 | attempt + 1,
274 | retries + 1,
275 | mode_used,
276 | e,
277 | exc_info=True,
278 | )
279 | self._reset_token_counts() # Reset counts on error
280 |
281 | if attempt < retries:
282 | sleep_time = 2**attempt
283 | logger.info(f"Retrying JSON generation in {sleep_time} seconds...")
284 | time.sleep(sleep_time)
285 | self._reset_token_counts() # Reset before retry
286 |
287 | # If loop finishes without success
288 | raise RuntimeError(
289 | f"generate_json failed after {retries + 1} attempts. Last error: {type(last_error).__name__}: {last_error}"
290 | )
291 |
292 | async def generate_json_async(
293 | self, prompt: str, response_model: Type[BaseModel], retries: int = 2, **kwargs: Any
294 | ) -> BaseModel:
295 | """Asynchronously generates a response and parses it into a Pydantic model instance.
296 |
297 | Uses `_generate_json_internal_async` and attempts to parse the result.
298 | Retries on validation or decoding errors. Also updates internal token counts.
299 |
300 | Args:
301 | prompt: The input prompt, often instructing the LLM to respond in JSON.
302 | response_model: The Pydantic model class to validate the response against.
303 | retries: The number of times to retry on parsing/validation failure.
304 | **kwargs: Additional provider-specific parameters for generation.
305 |
306 | Returns:
307 | An instance of the `response_model` populated with data from the LLM response.
308 |
309 | Raises:
310 | RuntimeError: If parsing fails after all retries.
311 | ValidationError: If the final response does not match the `response_model`.
312 | json.JSONDecodeError: If the final response is not valid JSON.
313 | """
314 | last_error = None
315 | temp = kwargs.pop("temperature", 0.1)
316 | json_kwargs = {**kwargs, "temperature": temp}
317 | self._reset_token_counts() # Reset before attempts
318 |
319 | for attempt in range(retries + 1):
320 | raw = ""
321 | block = ""
322 | mode_used = None
323 | try:
324 | # _generate_json_internal_async is responsible for updating token counts
325 | raw, mode_used = await self._generate_json_internal_async(
326 | prompt, response_model, **json_kwargs
327 | )
328 |
329 | if mode_used in ["json_schema", "json_object", "ollama_schema_format"]:
330 | block = raw
331 | else:
332 | block = self._extract_json_block(raw)
333 |
334 | validated_model = response_model.model_validate_json(block.strip())
335 | # Token counts should have been set by _generate_json_internal_async
336 | return validated_model
337 | except (json.JSONDecodeError, ValidationError) as ve:
338 | last_error = ve
339 | logger.warning(
340 | "Async JSON validation/decode error %d/%d (mode: %s): %s\nBlock: %.200s\nRaw: %.200s",
341 | attempt + 1,
342 | retries + 1,
343 | mode_used,
344 | ve,
345 | block,
346 | raw,
347 | )
348 | self._reset_token_counts() # Reset counts on error
349 | except Exception as e:
350 | last_error = e
351 | logger.error(
352 | "Error generating JSON async %d/%d (mode: %s): %s",
353 | attempt + 1,
354 | retries + 1,
355 | mode_used,
356 | e,
357 | exc_info=True,
358 | )
359 | self._reset_token_counts() # Reset counts on error
360 |
361 | if attempt < retries:
362 | sleep_time = 2**attempt
363 | logger.info(f"Retrying async JSON generation in {sleep_time} seconds...")
364 | await asyncio.sleep(sleep_time)
365 | self._reset_token_counts() # Reset before retry
366 |
367 | # If loop finishes without success
368 | raise RuntimeError(
369 | f"generate_json_async failed after {retries + 1} attempts. Last error: {type(last_error).__name__}: {last_error}"
370 | )
371 |
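Usage note (illustrative, not part of `base.py`): a minimal sketch of `generate_json` with the `ExtractedAnswer` schema. Here `llm` is assumed to be an already-configured `BaseLLM` subclass instance (e.g., an `OpenAILLM` or `OllamaLLM` constructed elsewhere).

```python
from cogitator.schemas import ExtractedAnswer

# `llm` is an assumed, pre-configured BaseLLM subclass instance.
result = llm.generate_json(
    "What is 12 + 30? Respond as a JSON object with a single key 'final_answer'.",
    response_model=ExtractedAnswer,
    retries=2,
)
print(result.final_answer)

# Token counts for the last call, if the provider reported them.
print(llm.get_last_prompt_tokens(), llm.get_last_completion_tokens())
```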
--------------------------------------------------------------------------------
/cogitator/model/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/cogitator/model/py.typed
--------------------------------------------------------------------------------
/cogitator/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/cogitator/py.typed
--------------------------------------------------------------------------------
/cogitator/schemas.py:
--------------------------------------------------------------------------------
1 | """Defines Pydantic models for structured data exchange within Cogitator."""
2 |
3 | from typing import List, Optional, Union
4 |
5 | from pydantic import BaseModel, Field
6 |
7 |
8 | class LTMDecomposition(BaseModel):
9 | """Schema for the output of the Least-to-Most decomposition step."""
10 |
11 | subquestions: List[str] = Field(..., description="List of sequential subquestions")
12 |
13 |
14 | class ThoughtExpansion(BaseModel):
15 | """Schema for the output of a thought expansion step (e.g., in ToT)."""
16 |
17 | thoughts: List[str] = Field(..., description="List of distinct reasoning steps or thoughts")
18 |
19 |
20 | class EvaluationResult(BaseModel):
21 | """Schema for the output of an evaluation step (e.g., in ToT, GoT)."""
22 |
23 | score: int = Field(..., description="Quality score from 1 to 10")
24 | justification: str = Field(..., description="Brief justification for the score")
25 |
26 |
27 | class ExtractedAnswer(BaseModel):
28 | """Schema for the final extracted answer from a reasoning chain."""
29 |
30 | final_answer: Optional[Union[str, int, float]] = Field(
31 | ..., description="The final extracted answer"
32 | )
33 |
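Usage note (illustrative, not part of `schemas.py`): the schemas are ordinary Pydantic v2 models, so they can be validated directly from JSON strings, which is what `BaseLLM.generate_json` does internally.

```python
from cogitator.schemas import EvaluationResult, ExtractedAnswer, ThoughtExpansion

answer = ExtractedAnswer.model_validate_json('{"final_answer": "42"}')
print(answer.final_answer)  # prints: 42

expansion = ThoughtExpansion.model_validate_json('{"thoughts": ["Try 6 * 7", "Check by addition"]}')
print(expansion.thoughts)

evaluation = EvaluationResult(score=8, justification="Coherent and arithmetically correct.")
print(evaluation.model_dump())
```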
--------------------------------------------------------------------------------
/cogitator/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | """Initializes the strategies subpackage.
2 |
3 | Exports the main CoT strategy classes implemented within this subpackage, making them
4 | available for direct import from `cogitator.strategies`. This includes various CoT and related reasoning frameworks.
5 | """
6 |
7 | from .auto_cot import AutoCoT
8 | from .cdw_cot import CDWCoT
9 | from .graph_of_thoughts import GraphOfThoughts
10 | from .least_to_most import LeastToMost
11 | from .sc_cot import SelfConsistency
12 | from .tree_of_thoughts import TreeOfThoughts
13 |
14 | __all__ = [
15 | "AutoCoT",
16 | "CDWCoT",
17 | "GraphOfThoughts",
18 | "LeastToMost",
19 | "SelfConsistency",
20 | "TreeOfThoughts",
21 | ]
22 |
--------------------------------------------------------------------------------
/cogitator/strategies/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/cogitator/strategies/py.typed
--------------------------------------------------------------------------------
/cogitator/strategies/sc_cot.py:
--------------------------------------------------------------------------------
1 | """Implements the Self-Consistency Chain-of-Thought (SC-CoT) strategy."""
2 |
3 | import asyncio
4 | import logging
5 | import re
6 | from collections import Counter
7 | from typing import Any, AsyncIterator, Iterator, List, Literal, Optional
8 |
9 | from ..model import BaseLLM
10 | from ..schemas import ExtractedAnswer
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class SelfConsistency:
16 | """Implements the Self-Consistency Chain-of-Thought (SC-CoT) strategy.
17 |
18 | Self-Consistency improves CoT prompting by generating multiple diverse
19 | reasoning paths (using sampling with temperature > 0) and then selecting
20 | the most consistent answer among the paths via majority voting.
21 |
22 | Reference:
 23 |         Wang et al. (2023). "Self-Consistency Improves Chain of Thought Reasoning in Language Models".
24 | https://arxiv.org/abs/2203.11171
25 | """
26 |
27 | def __init__(
28 | self,
29 | llm: BaseLLM,
30 | n_samples: int = 10,
31 | temperature: float = 0.8,
32 | max_tokens: int = 256,
33 | stop: Optional[List[str]] = None,
34 | internal_extraction_format: Literal["heuristic", "json"] = "heuristic",
35 | answer_extraction_prompt: Optional[str] = None,
36 | seed: Optional[int] = None,
37 | **gen_kwargs: Any,
38 | ) -> None:
39 | """Initializes the SelfConsistency strategy handler.
40 |
41 | Args:
42 | llm: The language model instance.
43 | n_samples: The number of reasoning paths (samples) to generate.
44 | temperature: Sampling temperature for generating diverse paths. Should be > 0.
45 | max_tokens: Maximum tokens for each generated reasoning path.
46 | stop: Optional stop sequences for LLM generation.
47 | internal_extraction_format: Method for extracting the final answer from
48 | each CoT path ('heuristic' or 'json').
49 | answer_extraction_prompt: Prompt template used only if `internal_extraction_format`
50 | is 'json'. Must include {cot}. Expects JSON output
51 | matching ExtractedAnswer schema.
52 | seed: Base random seed for LLM sampling (each sample may use seed + i).
53 | **gen_kwargs: Additional default keyword arguments for LLM generation calls.
54 | """
55 | self.llm = llm
56 | self.n_samples = n_samples
57 | self.temperature = temperature
58 | self.max_tokens = max_tokens
59 | self.stop = stop
60 | self.internal_extraction_format = internal_extraction_format
61 | self.seed = seed
62 | self.gen_kwargs = gen_kwargs
63 |
64 | if self.internal_extraction_format == "json":
65 | self.answer_extraction_prompt = (
66 | answer_extraction_prompt
67 | or "Analyze the following reasoning chain and extract the final numerical or short answer. "
68 | "Return the result as a JSON object with a single key 'final_answer' containing the answer as a string.\n\n"
69 | "Reasoning Chain:\n{cot}\n\nJSON Answer:"
70 | )
71 | else:
72 | self.answer_extraction_prompt = None
73 |
74 | def _extract_answer_heuristic(self, cot: str) -> str:
75 | """Extracts the final answer from a CoT string using heuristics.
76 |
77 | Searches for common patterns like "answer is X", lines starting with "Answer:",
78 | numeric lines, etc., working from the end of the CoT string upwards.
79 |
80 | Args:
81 | cot: The Chain-of-Thought reasoning string.
82 |
83 | Returns:
84 | The extracted answer string, or the last line as a fallback.
85 | """
86 | lines = cot.strip().splitlines()
87 | for line in reversed(lines):
88 | text = line.strip().rstrip(".")
89 | if "=" in text:
90 | parts = text.split("=", 1)
91 | if len(parts) > 1:
92 | answer = parts[1].strip().lstrip("$").strip()
93 | logger.debug(f"Heuristically extracted answer (equals): '{answer}'")
94 | return answer
95 | m0 = re.search(r"(?i)\bthe answer is\s+(\S+)", text)
96 | if m0:
97 | answer = m0.group(1).lstrip("$").strip()
98 | logger.debug(f"Heuristically extracted answer (the answer is): '{answer}'")
99 | return answer
100 | m1 = re.match(r"(?i)^(?:Answer|Final Answer|Ans)\b[: ]\s*(.+)$", text)
101 | if m1:
102 | answer = m1.group(1).strip()
103 | logger.debug(f"Heuristically extracted answer (Prefix): '{answer}'")
104 | return answer
105 | m2 = re.match(r"^#+\s*([+-]?\d+(?:\.\d+)?)$", text)
106 | if m2:
107 | answer = m2.group(1)
108 | logger.debug(f"Heuristically extracted answer (Header): '{answer}'")
109 | return answer
110 | m3 = re.match(r"^\*{1,2}A[: ]\s*(.+?)\*{0,2}$", text, re.IGNORECASE)
111 | if m3:
112 | answer = m3.group(1).strip()
113 | logger.debug(f"Heuristically extracted answer (Markdown A:): '{answer}'")
114 | return answer
115 | m4 = re.search(r":\s*([+-]?\d+(?:\.\d+)?|[A-Za-z]+)\s*$", text)
116 | if m4:
117 | answer = m4.group(1).strip()
118 | logger.debug(f"Heuristically extracted answer (Colon End): '{answer}'")
119 | return answer
120 | if re.fullmatch(r"\$?[+-]?\d+(?:\.\d+)?", text):
121 | answer = text.lstrip("$")
122 | logger.debug(f"Heuristically extracted answer (Numeric Line): '{answer}'")
123 | return answer
124 | fallback_answer = lines[-1].strip() if lines else ""
125 | logger.debug(f"Heuristically extracted answer (Fallback): '{fallback_answer}'")
126 | return fallback_answer
127 |
128 | def _extract_answer_json(self, cot: str, **kwargs: Any) -> str:
129 | """Extracts the final answer using an LLM call with JSON parsing.
130 |
131 | Uses the `answer_extraction_prompt` and expects a JSON response matching
132 | the `ExtractedAnswer` schema. Falls back to heuristic extraction on failure.
133 |
134 | Args:
135 | cot: The Chain-of-Thought reasoning string.
136 | **kwargs: Additional arguments passed to the LLM `generate_json` call.
137 |
138 | Returns:
139 | The extracted answer string.
140 | """
141 | if not self.answer_extraction_prompt:
142 | logger.warning("JSON extraction requested but prompt is not configured.")
143 | return self._extract_answer_heuristic(cot)
144 |
145 | prompt = self.answer_extraction_prompt.format(cot=cot)
146 | logger.debug("Attempting JSON extraction with prompt:\n%s", prompt)
147 | try:
148 | local_kwargs = kwargs.copy()
149 | result = self.llm.generate_json(
150 | prompt,
151 | response_model=ExtractedAnswer,
152 | max_tokens=local_kwargs.pop("max_tokens", self.max_tokens),
153 | seed=local_kwargs.pop("seed", self.seed),
154 | **local_kwargs,
155 | )
156 | answer = str(result.final_answer).strip()
157 | logger.debug(f"JSON extracted answer: '{answer}'")
158 | return answer
159 | except Exception as e:
160 | logger.error("JSON extraction failed: %s", e, exc_info=True)
161 | logger.warning("JSON extraction failed, falling back to heuristic.")
162 | return self._extract_answer_heuristic(cot)
163 |
164 | async def _extract_answer_json_async(self, cot: str, **kwargs: Any) -> str:
165 | """Asynchronously extracts the final answer using an LLM call with JSON parsing.
166 |
167 | Similar to `_extract_answer_json` but uses async LLM calls.
168 |
169 | Args:
170 | cot: The Chain-of-Thought reasoning string.
171 | **kwargs: Additional arguments passed to the async LLM `generate_json_async` call.
172 |
173 | Returns:
174 | The extracted answer string.
175 | """
176 | if not self.answer_extraction_prompt:
177 | logger.warning("Async JSON extraction requested but prompt is not configured.")
178 | return self._extract_answer_heuristic(cot)
179 |
180 | prompt = self.answer_extraction_prompt.format(cot=cot)
181 | logger.debug("Attempting async JSON extraction with prompt:\n%s", prompt)
182 | try:
183 | local_kwargs = kwargs.copy()
184 | result = await self.llm.generate_json_async(
185 | prompt,
186 | response_model=ExtractedAnswer,
187 | max_tokens=local_kwargs.pop("max_tokens", self.max_tokens),
188 | seed=local_kwargs.pop("seed", self.seed),
189 | **local_kwargs,
190 | )
191 | answer = str(result.final_answer).strip()
192 | logger.debug(f"Async JSON extracted answer: '{answer}'")
193 | return answer
194 | except Exception as e:
195 | logger.error("Async JSON extraction failed: %s", e, exc_info=True)
196 | logger.warning("Async JSON extraction failed, falling back to heuristic.")
197 | return self._extract_answer_heuristic(cot)
198 |
199 | def extract_answer(self, cot: str, **kwargs: Any) -> str:
200 | """Extracts the final answer from a CoT string based on the configured method.
201 |
202 | Delegates to either `_extract_answer_heuristic` or `_extract_answer_json`.
203 |
204 | Args:
205 | cot: The Chain-of-Thought reasoning string.
206 | **kwargs: Arguments passed to the underlying extraction method (if JSON).
207 |
208 | Returns:
209 | The extracted answer string.
210 | """
211 | if self.internal_extraction_format == "json":
212 | return self._extract_answer_json(cot, **kwargs)
213 | return self._extract_answer_heuristic(cot)
214 |
215 | async def extract_answer_async(self, cot: str, **kwargs: Any) -> str:
216 | """Asynchronously extracts the final answer based on the configured method.
217 |
218 | Delegates to `_extract_answer_heuristic` or `_extract_answer_json_async`.
219 |
220 | Args:
221 | cot: The Chain-of-Thought reasoning string.
222 | **kwargs: Arguments passed to the underlying async extraction method (if JSON).
223 |
224 | Returns:
225 | The extracted answer string.
226 | """
227 | if self.internal_extraction_format == "json":
228 | return await self._extract_answer_json_async(cot, **kwargs)
229 | return self._extract_answer_heuristic(cot)
230 |
231 | def run(self, prompt: str, **kwargs: Any) -> str:
232 | """Executes the Self-Consistency strategy.
233 |
234 | Generates `n_samples` reasoning paths using the LLM with the specified
235 | temperature. Extracts the final answer from each path and returns the
236 | most frequent answer (majority vote).
237 |
238 | Args:
239 | prompt: The input prompt for the LLM.
240 | **kwargs: Additional arguments passed to the LLM generation and
241 | answer extraction calls.
242 |
243 | Returns:
244 | The most consistent answer string among the generated paths, or an
245 | empty string if no valid answers are generated.
246 | """
247 | answers: List[str] = []
248 | combined_kwargs = {**self.gen_kwargs, **kwargs}
249 |
250 | for i in range(self.n_samples):
251 | try:
252 | iter_seed = (self.seed + i) if self.seed is not None else None
253 | current_gen_kwargs = combined_kwargs.copy()
254 | cot = self.llm.generate(
255 | prompt,
256 | temperature=current_gen_kwargs.pop("temperature", self.temperature),
257 | max_tokens=current_gen_kwargs.pop("max_tokens", self.max_tokens),
258 | stop=current_gen_kwargs.pop("stop", self.stop),
259 | seed=iter_seed,
260 | **current_gen_kwargs,
261 | )
262 | logger.debug(f"Raw CoT sample {i}: {cot}")
263 | ans = self.extract_answer(cot, **kwargs)
264 | if ans:
265 | answers.append(ans)
266 | else:
267 | logger.debug(f"Sample {i} produced empty answer after extraction.")
268 | except Exception as e:
269 | logger.error(f"Error during SC sample {i}: {e}", exc_info=True)
270 |
271 | if not answers:
272 | logger.warning("SelfConsistency generated no valid answers.")
273 | return ""
274 |
275 | try:
276 | count = Counter(answers)
277 | top_answer, _ = count.most_common(1)[0]
278 | logger.debug(f"SelfConsistency vote counts: {count}")
279 | return top_answer
280 | except IndexError:
281 | logger.error("Could not determine most common answer despite having answers.")
282 | return ""
283 |
284 | async def run_async(
285 | self, prompt: str, semaphore: Optional[asyncio.Semaphore] = None, **kwargs: Any
286 | ) -> str:
287 | """Asynchronously executes the Self-Consistency strategy.
288 |
289 | Generates `n_samples` reasoning paths concurrently using async LLM calls.
290 | Extracts answers asynchronously and returns the majority vote answer.
291 |
292 | Args:
293 | prompt: The input prompt for the LLM.
294 | semaphore: Optional asyncio.Semaphore to limit concurrent LLM calls.
295 | **kwargs: Additional arguments passed to the async LLM generation and
296 | answer extraction calls.
297 |
298 | Returns:
299 | The most consistent answer string, or an empty string if none are generated.
300 | """
301 | combined_kwargs = {**self.gen_kwargs, **kwargs}
302 |
303 | async def sample(i: int) -> Optional[str]:
304 | sample_kwargs = combined_kwargs.copy()
305 | iter_seed = (self.seed + i) if self.seed is not None else None
306 | gen_args = {
307 | "temperature": sample_kwargs.pop("temperature", self.temperature),
308 | "max_tokens": sample_kwargs.pop("max_tokens", self.max_tokens),
309 | "stop": sample_kwargs.pop("stop", self.stop),
310 | "seed": iter_seed,
311 | **sample_kwargs,
312 | }
313 | extraction_kwargs = kwargs.copy()
314 |
315 | task_semaphore = semaphore
316 | if task_semaphore:
317 | await task_semaphore.acquire()
318 | try:
319 | cot = await self.llm.generate_async(prompt, **gen_args)
320 | logger.debug(f"Raw async CoT sample {i}: {cot}")
321 | ans = await self.extract_answer_async(cot, **extraction_kwargs)
322 | if not ans:
323 | logger.debug(f"Async sample {i} produced empty answer after extraction.")
324 | return ans
325 | except Exception as e:
326 | logger.error(f"Error during async SC sample {i}: {e}", exc_info=True)
327 | return None
328 | finally:
329 | if task_semaphore:
330 | task_semaphore.release()
331 |
332 | results = await asyncio.gather(*(sample(i) for i in range(self.n_samples)))
333 | answers = [a for a in results if a is not None and a != ""]
334 | if not answers:
335 | logger.warning("SelfConsistency (async) generated no valid answers.")
336 | return ""
337 |
338 | try:
339 | count = Counter(answers)
340 | top_answer, _ = count.most_common(1)[0]
341 | logger.debug(f"SelfConsistency async vote counts: {count}")
342 | return top_answer
343 | except IndexError:
344 | logger.error("Could not determine most common async answer despite having answers.")
345 | return ""
346 |
347 | def run_stream(self, prompt: str) -> Iterator[str]:
348 | """Streaming is not supported for Self-Consistency."""
349 | raise NotImplementedError("Streaming not supported for SelfConsistency.")
350 |
351 | async def run_stream_async(self, prompt: str) -> AsyncIterator[str]:
352 | """Streaming is not supported for Self-Consistency."""
353 | raise NotImplementedError("Streaming not supported for SelfConsistency.")
354 |
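Usage note (illustrative, not part of `sc_cot.py`): an async sketch that samples several reasoning paths with a concurrency limit. `llm` is assumed to be an already-configured `BaseLLM` instance (e.g., `OpenAILLM` or `OllamaLLM`).

```python
import asyncio

from cogitator.strategies import SelfConsistency

# `llm` is an assumed, pre-configured BaseLLM instance.
sc = SelfConsistency(llm, n_samples=5, temperature=0.8, seed=123)

answer = asyncio.run(
    sc.run_async(
        "Q: A pen costs 2 dollars and a notebook costs 3 dollars. "
        "How much do two pens and one notebook cost? Let's think step by step.",
        semaphore=asyncio.Semaphore(2),  # at most 2 concurrent LLM calls
    )
)
print(answer)  # the majority-vote answer across the sampled paths
```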
--------------------------------------------------------------------------------
/cogitator/utils.py:
--------------------------------------------------------------------------------
1 | """Provides utility functions used across the Cogitator library."""
2 |
3 | import re
4 |
5 |
6 | def count_steps(cot: str) -> int:
7 | """Counts the number of reasoning steps in a Chain-of-Thought string.
8 |
9 | Identifies steps based on lines starting with digits followed by a period/paren
10 | or lines starting with list markers like -, *, •.
11 |
12 | Args:
13 | cot: The string containing the Chain-of-Thought reasoning.
14 |
15 | Returns:
16 | The integer count of identified reasoning steps.
17 | """
18 | return sum(1 for line in cot.splitlines() if re.match(r"^(\d+[\.\)]|[-*•])\s+", line.strip()))
19 |
20 |
21 | def approx_token_length(text: str) -> int:
22 | """Approximates the number of tokens in a string.
23 |
24 | Counts sequences of word characters and any non-whitespace, non-word characters
25 | as separate tokens. This provides a rough estimate, not a precise token count
26 | based on a specific tokenizer model.
27 |
28 | Args:
29 | text: The input string.
30 |
31 | Returns:
32 | An approximate integer count of tokens.
33 | """
34 | return len(re.findall(r"\w+|[^\w\s]", text))
35 |
36 |
37 | def exact_match(pred: str, gold: str) -> bool:
38 | """Performs case-insensitive exact matching between two strings.
39 |
40 | Strips leading/trailing whitespace and converts both strings to lowercase
41 | before comparison.
42 |
43 | Args:
44 | pred: The predicted string.
45 | gold: The ground truth (gold standard) string.
46 |
47 | Returns:
48 | True if the normalized strings are identical, False otherwise.
49 | """
50 | return pred.strip().lower() == gold.strip().lower()
51 |
52 |
53 | def accuracy(preds: list[str], golds: list[str]) -> float:
54 | """Calculates the exact match accuracy between lists of predictions and golds.
55 |
 56 |     Uses the `exact_match` function for comparison. If the lists differ in length,
 57 |     only the paired items are compared (`zip` with `strict=False`), but the score is
 58 |     still divided by the number of gold answers. If `golds` is empty, returns 0.0.
59 |
60 | Args:
61 | preds: A list of predicted strings.
62 | golds: A list of ground truth strings.
63 |
64 | Returns:
65 | The accuracy score as a float between 0.0 and 1.0.
66 | """
67 | if not golds:
68 | return 0.0
69 | matches = sum(exact_match(p, g) for p, g in zip(preds, golds, strict=False))
70 | return matches / len(golds)
71 |
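Worked example (illustrative, not part of `utils.py`) showing what the helpers above return for a tiny Chain-of-Thought string.

```python
from cogitator.utils import accuracy, approx_token_length, count_steps, exact_match

cot = "1. Compute 2 + 3 = 5\n2. Double it: 5 * 2 = 10\nAnswer: 10"

print(count_steps(cot))           # 2 -> the two numbered lines count as steps
print(approx_token_length(cot))   # rough token estimate (regex-based, not tokenizer-exact)
print(exact_match(" 10 ", "10"))  # True -> comparison is stripped and case-insensitive
print(accuracy(["10", "7"], ["10", "8"]))  # 0.5 -> one of two predictions matches
```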
--------------------------------------------------------------------------------
/docs/api/clustering.md:
--------------------------------------------------------------------------------
1 | # Clustering Module
2 |
3 | This module defines abstractions and implementations for clustering algorithms used by some of the strategies (e.g., AutoCoT and CDWCoT).
4 |
5 | ::: cogitator.clustering.BaseClusterer
6 | options:
7 | show_root_heading: false
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
12 | ::: cogitator.clustering.KMeansClusterer
13 | options:
14 | show_root_heading: false
15 | show_source: true
16 | members_order: source
17 | heading_level: 2
18 |
--------------------------------------------------------------------------------
/docs/api/embedding.md:
--------------------------------------------------------------------------------
1 | # Text Embedding Module
2 |
3 | This module provides abstractions and implementations for text embedding models.
4 |
5 | ::: cogitator.embedding.BaseEmbedder
6 | options:
7 | show_root_heading: false
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
12 | ::: cogitator.embedding.SentenceTransformerEmbedder
13 | options:
14 | show_root_heading: false
15 | show_source: true
16 | members_order: source
17 | heading_level: 2
18 |
--------------------------------------------------------------------------------
/docs/api/model/base.md:
--------------------------------------------------------------------------------
1 | # LLM Provider Interface
2 |
3 | This module defines the abstract base class for all LLM providers.
4 |
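A new provider can be added by subclassing `BaseLLM` and implementing its abstract methods. The snippet below is a minimal, hypothetical `EchoLLM` sketch (for illustration only; it is not part of the library) that simply echoes the prompt back:

```python
from typing import Any, AsyncIterator, Iterator, Optional, Tuple, Type

from pydantic import BaseModel

from cogitator.model import BaseLLM


class EchoLLM(BaseLLM):
    """Toy provider that echoes the prompt back (hypothetical example)."""

    def generate(self, prompt: str, **kwargs: Any) -> str:
        return prompt

    async def generate_async(self, prompt: str, **kwargs: Any) -> str:
        return prompt

    def generate_stream(self, prompt: str, **kwargs: Any) -> Iterator[str]:
        yield prompt

    async def generate_stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]:
        yield prompt

    def _generate_json_internal(
        self, prompt: str, response_model: Type[BaseModel], **kwargs: Any
    ) -> Tuple[str, Optional[str]]:
        # A real provider would call its API here and report the JSON mode used.
        return '{"final_answer": "42"}', None

    async def _generate_json_internal_async(
        self, prompt: str, response_model: Type[BaseModel], **kwargs: Any
    ) -> Tuple[str, Optional[str]]:
        return '{"final_answer": "42"}', None
```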
5 | ::: cogitator.model.base.BaseLLM
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/model/ollama.md:
--------------------------------------------------------------------------------
1 | # Ollama LLM Implementation
2 |
3 | Implementation of the BaseLLM interface for Ollama models.
4 |
5 | ::: cogitator.model.ollama.OllamaLLM
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/model/openai.md:
--------------------------------------------------------------------------------
1 | # OpenAI LLM Implementation
2 |
3 | Implementation of the BaseLLM interface for OpenAI models.
4 |
5 | ::: cogitator.model.openai.OpenAILLM
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/schemas.md:
--------------------------------------------------------------------------------
1 | # Schemas for Structured Data
2 |
3 | This module defines Pydantic models used for structuring and validating intermediate and final outputs from LLMs.
4 |
5 | ::: cogitator.schemas.LTMDecomposition
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
12 | ::: cogitator.schemas.ThoughtExpansion
13 | options:
14 | show_root_heading: true
15 | show_source: true
16 | members_order: source
17 | heading_level: 2
18 |
19 | ::: cogitator.schemas.EvaluationResult
20 | options:
21 | show_root_heading: true
22 | show_source: true
23 | members_order: source
24 | heading_level: 2
25 |
26 | ::: cogitator.schemas.ExtractedAnswer
27 | options:
28 | show_root_heading: true
29 | show_source: true
30 | members_order: source
31 | heading_level: 2
32 |
--------------------------------------------------------------------------------
/docs/api/strategies/auto_cot.md:
--------------------------------------------------------------------------------
1 | # Automatic Chain-of-Thought
2 |
3 | An implementation of the automatic chain-of-thought (CoT) prompting strategy from [this paper](https://arxiv.org/abs/2210.03493).
4 |
5 | ::: cogitator.strategies.auto_cot.AutoCoT
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/strategies/cdw_cot.md:
--------------------------------------------------------------------------------
1 | # Clustered Distance-Weighted Chain-of-Thought
2 |
3 | An implementation of the clustered distance-weighted CoT framework from [this paper](https://arxiv.org/abs/2501.12226).
4 |
5 | ::: cogitator.strategies.cdw_cot.CDWCoT
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/strategies/graph_of_thoughts.md:
--------------------------------------------------------------------------------
1 | # Graph of Thoughts Framework
2 |
3 | An implementation of the Graph of Thoughts (GoT) reasoning framework from [this paper](https://arxiv.org/abs/2308.09687).
4 |
5 | The implementation represents the reasoning process as a graph where nodes are thoughts and edges represent transformations.
6 | The flow of reasoning is controlled by a **Graph of Operations (GoO)**.
7 |
8 | ## Defining the Graph of Operations (GoO)
9 |
10 | To use `GraphOfThoughts`, you must provide a `graph_of_operations` argument to the `run_async` method.
11 | This argument is a list of tuples, where each tuple defines an operation step:
12 |
13 | `graph_of_operations: List[Tuple[str, Dict]]`
14 |
15 | * The first element of the tuple is the **name** of the operation (e.g., `'Generate'`, `'Score'`, `'KeepBest'`).
16 | * The second element is a **dictionary** containing the parameters specific to that operation.
17 |
18 | **Example GoO:**
19 |
20 | ```python
21 | from cogitator import ThoughtExpansion
22 |
23 | EXAMPLE_GOO = [
24 | # Step 1: Generate 3 new thoughts from the initial question (in 'frontier' set)
25 | # Store results in the 'generated_thoughts' set. Use the 'expand' prompt. Expect ThoughtExpansion schema.
26 | ('Generate', {'k': 3, 'target_set': 'frontier', 'output_set': 'generated_thoughts', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}),
27 |
28 | # Step 2: Score the thoughts generated in the previous step. Use 'evaluate' prompt.
29 | ('Score', {'target_set': 'generated_thoughts', 'prompt_key': 'evaluate'}),
30 |
31 | # Step 3: Keep only the single best-scoring thought from the previous step.
32 | # Put the result back into the 'frontier' set for potential further steps or final answer generation.
33 | ('KeepBest', {'N': 1, 'target_set': 'generated_thoughts', 'output_set': 'frontier'})
34 | ]
35 | ```
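The GoO is then passed to `run_async`. The call below is only a hypothetical sketch: `got` is assumed to be an already-constructed `GraphOfThoughts` instance, and every argument name other than `graph_of_operations` is an assumption rather than documented API.

```python
import asyncio

# `got` is an assumed, pre-configured GraphOfThoughts instance.
final_answer = asyncio.run(
    got.run_async("What is 6 * 7?", graph_of_operations=EXAMPLE_GOO)
)
print(final_answer)
```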
36 |
37 | ## Main Class (`GraphOfThoughts`)
38 |
39 | ::: cogitator.strategies.graph_of_thoughts.GraphOfThoughts
40 | options:
41 | show_root_heading: true
42 | show_source: true
43 | members_order: source
44 | heading_level: 3
45 |
46 | ## Available Operations
47 |
48 | Here are the standard operations available.
49 | You can create custom operations by subclassing `GoTOperation`.
50 |
51 | ### Base Operation Class
52 |
53 | ::: cogitator.strategies.graph_of_thoughts.GoTOperation
54 | options:
55 | show_root_heading: true
56 | show_source: false
57 | members_order: source
58 | heading_level: 3
59 |
60 | ### Generate Operation
61 |
62 | ::: cogitator.strategies.graph_of_thoughts.GenerateOp
63 | options:
64 | show_root_heading: true
65 | show_source: false
66 | members_order: source
67 | heading_level: 3
68 |
69 | ### Score Operation
70 |
71 | ::: cogitator.strategies.graph_of_thoughts.ScoreOp
72 | options:
73 | show_root_heading: true
74 | show_source: false
75 | members_order: source
76 | heading_level: 3
77 |
78 | ### KeepBest Operation
79 |
80 | ::: cogitator.strategies.graph_of_thoughts.KeepBestOp
81 | options:
82 | show_root_heading: true
83 | show_source: false
84 | members_order: source
85 | heading_level: 3
86 |
87 | ### Aggregate Operation
88 |
89 | ::: cogitator.strategies.graph_of_thoughts.AggregateOp
90 | options:
91 | show_root_heading: true
92 | show_source: false
93 | members_order: source
94 | heading_level: 3
95 |
96 | ## Internal State (Advanced)
97 |
98 | These classes manage the internal graph structure.
99 |
100 | ### GoTNode
101 |
102 | ::: cogitator.strategies.graph_of_thoughts.GoTNode
103 | options:
104 | show_root_heading: true
105 | show_source: false
106 | members: ["__init__"]
107 | heading_level: 3
108 |
109 | ### GraphReasoningState
110 |
111 | ::: cogitator.strategies.graph_of_thoughts.GraphReasoningState
112 | options:
113 | show_root_heading: true
114 | show_source: false
115 | heading_level: 3
116 |
--------------------------------------------------------------------------------
/docs/api/strategies/least_to_most.md:
--------------------------------------------------------------------------------
1 | # Least-to-Most Prompting
2 |
3 | An implementation of the least-to-most prompting strategy from [this paper](https://arxiv.org/abs/2205.10625).
4 |
5 | ::: cogitator.strategies.least_to_most.LeastToMost
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/strategies/sc_cot.md:
--------------------------------------------------------------------------------
1 | # Self-Consistency Prompting
2 |
3 | An implementation of the self-consistency prompting strategy from [this paper](https://arxiv.org/abs/2203.11171).
4 |
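A minimal usage sketch, assuming `llm` is an already-configured `BaseLLM` instance (for example an `OpenAILLM` or `OllamaLLM`; see the model provider pages):

```python
from cogitator.strategies import SelfConsistency

# `llm` is an assumed, pre-configured BaseLLM instance.
sc = SelfConsistency(llm, n_samples=10, temperature=0.8)
answer = sc.run("Q: What is 17 + 25? Let's think step by step.")
print(answer)  # the most frequent answer across the sampled reasoning paths
```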
5 | ::: cogitator.strategies.sc_cot.SelfConsistency
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/strategies/tree_of_thoughts.md:
--------------------------------------------------------------------------------
1 | # Tree of Thoughts
2 |
3 | An implementation of the tree of thoughts CoT framework from [this paper](https://arxiv.org/abs/2305.10601).
4 |
5 | ::: cogitator.strategies.tree_of_thoughts.TreeOfThoughts
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
--------------------------------------------------------------------------------
/docs/api/utils.md:
--------------------------------------------------------------------------------
1 | # Utility Functions
2 |
3 | This module provides various utility functions used throughout the Cogitator library, such as metrics calculation and text processing helpers.
4 |
5 | ::: cogitator.utils.count_steps
6 | options:
7 | show_root_heading: true
8 | show_source: true
9 | members_order: source
10 | heading_level: 2
11 |
12 | ::: cogitator.utils.approx_token_length
13 | options:
14 | show_root_heading: true
15 | show_source: true
16 | members_order: source
17 | heading_level: 2
18 |
19 | ::: cogitator.utils.exact_match
20 | options:
21 | show_root_heading: true
22 | show_source: true
23 | members_order: source
24 | heading_level: 2
25 |
26 | ::: cogitator.utils.accuracy
27 | options:
28 | show_root_heading: true
29 | show_source: true
30 | members_order: source
31 | heading_level: 2
32 |
--------------------------------------------------------------------------------
/docs/assets/images/cogitator_v1.dot:
--------------------------------------------------------------------------------
1 | digraph CogitatorWorkflow {
2 | fontname = "Helvetica,Arial,sans-serif"
3 | layout = dot
4 | rankdir = TB // Top-to-Bottom workflow layout
5 | node [
6 | fontname = "Helvetica,Arial,sans-serif",
7 | shape = box,
8 | style = "filled,rounded",
9 | color = "grey",
10 | fillcolor = "white",
11 | penwidth = 2
12 | ]
13 | edge [
14 | fontname = "Helvetica,Arial,sans-serif",
15 | color = "black"
16 | ]
17 |
18 | // Cluster: User Inputs
19 | subgraph cluster_input {
20 | label = "Inputs"
21 | style = "dashed"
22 | color = "lightgrey"
23 | question [label = "Question / Prompt", fillcolor = "lightyellow"]
24 | strategy_choice [label = "Strategy Choice\n(e.g., SC, LtM, GoT, AutoCoT)", fillcolor = "lightyellow"]
25 | llm_choice [label = "LLM Choice\n(Provider, Model Name)", fillcolor = "lightyellow"]
26 | training_data [label = "Training Data\n(Optional: Questions, Answers)", fillcolor = "lightyellow", shape = note]
27 | }
28 |
29 | // Cluster: Cogitator Library Core Components
30 | subgraph cluster_core {
31 | label = "Cogitator Library"
32 | style = "dashed"
33 | color = "lightgrey"
34 | strategy [label = "Selected CoT Strategy\n(e.g., AutoCoT instance)", fillcolor = "lightblue"]
35 | llm_interface [label = "LLM Interface\n(BaseLLM: OpenAI/Ollama)", fillcolor = "lightblue"]
36 | schemas [label = "Pydantic Schemas\n(Structured Output Validation)", fillcolor = "lightgrey", shape = component]
37 | embedding [label = "Embedding Model\n(Optional Usage)", fillcolor = "lightblue", shape = component]
38 | clustering [label = "Clustering Algorithm\n(Optional Usage)", fillcolor = "lightblue", shape = component]
39 | extraction [label = "Answer Extraction Logic\n(Heuristic / LLM-based)", fillcolor = "lightblue", shape = component]
40 | }
41 |
42 | // Cluster: External Dependencies / Services
43 | subgraph cluster_external {
44 | label = "External Services / Models"
45 | style = "dashed"
46 | color = "lightgrey"
47 | llm_backend [label = "LLM Backend\n(OpenAI API / Ollama Server)", fillcolor = "lightpink"]
48 | embedding_backend [label = "Embedding Backend\n(e.g., Sentence Transformers Lib)", fillcolor = "lightpink", shape = cylinder] // Representing the underlying model/lib
49 | }
50 |
51 | // Cluster: Final Output
52 | subgraph cluster_output {
53 | label = "Output"
54 | style = "dashed"
55 | color = "lightgrey"
56 | final_answer [label = "Final Answer / Result", fillcolor = "lightgreen"]
57 | }
58 |
59 | // --- Edges Defining the Flow ---
60 |
61 | // Inputs to Initialization
62 | question -> strategy [label = "is main input to"]
63 | strategy_choice -> strategy [label = "determines instance of"]
64 | llm_choice -> llm_interface [label = "configures"]
65 | training_data -> strategy [label = "used by some for fit/train\n(e.g., AutoCoT, CDWCoT)", style = dashed]
66 |
67 | // Strategy Orchestration
68 | strategy -> llm_interface [label = "makes calls via"]
69 | strategy -> schemas [label = "uses for JSON modes\n(LtM, ToT, GoT, SC, etc.)", style = dashed]
70 | strategy -> embedding [label = "uses sometimes\n(AutoCoT, CDWCoT, GoT)", style = dashed]
71 | strategy -> clustering [label = "uses sometimes\n(AutoCoT, CDWCoT)", style = dashed]
72 | strategy -> extraction [label = "uses sometimes\n(SC, LtM, GoT)", style = dashed]
73 |
74 | // LLM Interaction
75 | llm_interface -> llm_backend [label = "communicates with"]
76 | llm_backend -> llm_interface [label = "returns generation to"]
77 | llm_interface -> strategy [label = "provides results to"]
78 |
79 | // Embedding Interaction (Optional Path)
80 | embedding -> embedding_backend [label = "wraps / uses"]
81 | embedding_backend -> embedding [label = "provides embeddings"]
82 |
83 | // Extraction Interaction (Optional Path)
84 | extraction -> llm_interface [label = "can call LLM for extraction", style = dotted]
85 |
86 | // Final Output
87 | strategy -> final_answer [label = "produces"]
88 |
89 | // Optional: Ranking hints if needed (often not necessary with TB layout)
90 | // { rank=same; question; strategy_choice; llm_choice; training_data }
91 | // { rank=same; llm_backend; embedding_backend }
92 | }
93 |
--------------------------------------------------------------------------------
/docs/assets/images/cogitator_v2.dot:
--------------------------------------------------------------------------------
1 | digraph SimplifiedCogitatorWorkflow {
2 | fontname = "Helvetica,Arial,sans-serif"
3 | layout = dot
4 | rankdir = LR
5 | ranksep = 0.9;
6 | nodesep = 0.7;
7 | splines = true;
8 | compound = true;
9 |
10 | node [
11 | fontname = "Helvetica,Arial,sans-serif",
12 | shape = box,
13 | style = "filled,rounded",
14 | color = "grey",
15 | fillcolor = "white",
16 | penwidth = 1
17 | ]
18 | edge [
19 | fontname = "Helvetica,Arial,sans-serif",
20 | color = "black",
21 | fontsize = 8,
22 | labeldistance = 2.0
23 | ]
24 |
25 | subgraph cluster_input {
26 | label = "Inputs"
27 | style = "dashed"
28 | color = "lightgrey"
29 | margin = 18
30 | question [label = "1. Question / Prompt", fillcolor = "oldlace"]
31 | config [label = "2. Configuration\n(Strategy Choice, LLM Choice)", fillcolor = "oldlace"]
32 | }
33 |
34 | subgraph cluster_core {
35 | label = "Cogitator"
36 | style = "dashed"
37 | color = "lightgrey"
38 | margin = 18
39 |     strategy [label = <
40 |         <b>3. Selected CoT Strategy</b><br/>
41 |         Orchestrates steps:<br align="left"/>
42 |         - Prompt Formatting<br align="left"/>
43 |         - LLM Calls<br align="left"/>
44 |         - Intermediate Processing<br align="left"/>
45 |         (Decomposition, Expansion,<br align="left"/>
46 |         Evaluation, Extraction,<br align="left"/>
47 |         Embedding, Clustering...)<br align="left"/>
48 |     >, fillcolor = "lightblue", shape = box]
49 | }
50 |
51 | subgraph cluster_external {
52 | label = "LLM Service"
53 | style = "dashed"
54 | color = "lightgrey"
55 | margin = 18
56 | llm [label = "4. Model Provider\n(e.g., OpenAI API / Ollama)", fillcolor ="oldlace"]
57 | }
58 |
59 | subgraph cluster_output {
60 | label = "Output"
61 | style = "dashed"
62 | color = "lightgrey"
63 | margin = 18
64 | final_answer [label = "5. Final Answer", fillcolor = "oldlace"]
65 | }
66 |
67 | question -> strategy [lhead = cluster_core]
68 | config -> strategy [lhead = cluster_core]
69 | config -> llm [lhead= cluster_external]
70 |
71 | strategy -> llm [minlen = 2]
72 | llm -> strategy [minlen = 2]
73 |
74 | strategy -> final_answer [lhead = cluster_output, ltail = cluster_core, minlen = 2]
75 |
76 | }
77 |
--------------------------------------------------------------------------------
/docs/assets/images/cogitator_v2.svg:
--------------------------------------------------------------------------------
[SVG markup omitted: rendered workflow diagram generated from cogitator_v2.dot]
--------------------------------------------------------------------------------
/docs/assets/images/make_figures.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # You need to have Graphviz installed to run this script
4 | # On Debian-based OSes, you can install it using: sudo apt-get install graphviz
5 |
6 | # Directory containing .dot files (with default value)
7 | ASSET_DIR=${1:-"."}
8 |
9 | # Make figures from .dot files
10 | for f in "${ASSET_DIR}"/*.dot; do
11 | dot -Tsvg "$f" -o "${f%.dot}.svg"
12 | done
13 |
--------------------------------------------------------------------------------
/docs/benchmarking.md:
--------------------------------------------------------------------------------
1 | ## Benchmarking Framework
2 |
3 | Cogitator includes a framework for running benchmarks to compare the performance of different Chain-of-Thought methods
4 | on various datasets.
5 |
6 | ### Overview
7 |
8 | The framework consists of two main scripts:
9 |
10 | 1. **`benches/run.py`**: Handles the generation phase. It loads datasets, configures LLMs and CoT strategies based on
11 | `benches.yml` and CLI arguments, runs the strategies on dataset questions, and saves the raw LLM outputs and timings to
12 | a JSONL file.
13 | 2. **`benches/evaluate.py`**: Handles the evaluation phase. It reads the results file generated by `run.py`, extracts
14 | final answers from the raw outputs (using either heuristics or an LLM), compares them to gold answers, and calculates
15 | accuracy metrics for each method.
16 |
17 | ### Configuration
18 |
19 | Benchmarks are configured using the `benches.yml` file located in the project root; its settings can be overridden by
20 | command-line arguments passed to the scripts.
21 |
22 | ### How to Run
23 |
24 | Detailed instructions on how to configure and run the benchmarks are available in the benchmarking README
25 | ([`benches/README.md`](https://github.com/habedi/cogitator/blob/main/benches/README.md)).
26 | This includes:
27 |
28 | * Command-line options for `run.py` and `evaluate.py`.
29 | * Detailed explanation of the `benches.yml` configuration structure.
30 | * The benchmark workflow.
31 | * Example usage commands.
32 | * Available datasets and dependencies.
33 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | ## Contributing
2 |
3 | Thank you for considering contributing to Cogitator!
4 | We welcome contributions from the community, whether it is code, documentation, or feedback.
5 | Please refer to the main contribution guidelines document
6 | ([`CONTRIBUTING.md`](https://github.com/habedi/cogitator/blob/main/CONTRIBUTING.md))
7 | in the repository's root directory for details on how to contribute, report bugs, suggest features, and the development workflow.
8 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Cogitator Documentation
2 |
3 | ## Overview
4 | Cogitator is a Python toolkit for experimenting with and using
5 | [chain-of-thought (CoT) prompting](https://arxiv.org/abs/2201.11903)
6 | methods in large language models (LLMs).
7 | CoT prompting improves LLM performance on complex tasks (like question-answering, reasoning, and problem-solving)
8 | by guiding the models to generate intermediate reasoning steps before arriving at the final answer.
9 | Additionally, it can be used to improve the interpretability of LLMs by providing insight into the model's reasoning process.
10 | The toolkit aims to make it easier to use popular CoT strategies and frameworks for research or to integrate them into AI
11 | applications.
12 |
13 | ### Features
14 |
15 | * Provides a unified sync/async API for CoT strategies
16 | * Supports using OpenAI and Ollama as LLM providers
17 | * Supports structured model output with Pydantic validation
18 | * Includes a customizable benchmarking framework (see [benches](https://github.com/habedi/cogitator/blob/main/benches))
19 | * Includes implementations of popular CoT strategies and frameworks like
20 | - [Self-Consistency CoT (ICLR 2023)](https://arxiv.org/abs/2203.11171)
21 | - [Automatic CoT (ICLR 2023)](https://arxiv.org/abs/2210.03493)
22 | - [Least-to-Most Prompting (ICLR 2023)](https://arxiv.org/abs/2205.10625)
23 | - [Tree of Thoughts (NeurIPS 2023)](https://arxiv.org/abs/2305.10601)
24 | - [Graph of Thoughts (AAAI 2024)](https://arxiv.org/abs/2308.09687)
25 | - [Clustered Distance-Weighted CoT (AAAI 2025)](https://arxiv.org/abs/2501.12226)
26 |
27 | The diagram below shows a high-level overview of Cogitator's workflow.
28 |
29 | 
30 |
31 | ## Installation
32 |
33 | Cogitator can be installed via pip using the following command:
34 |
35 | ```bash
36 | pip install cogitator
37 | ```
38 |
39 | To run the unit tests and benchmarks, clone the repository and install the development dependencies with the following commands:
40 |
41 | ```bash
42 | git clone https://github.com/habedi/cogitator && cd cogitator
43 |
44 | # Set up Python environment
45 | pip install poetry
46 | poetry install --with dev
47 |
48 | # Run the tests to make sure everything is working (optional)
49 | poetry run pytest
50 | ```
51 |
52 | ## Examples
53 |
54 | Check the [examples](https://github.com/habedi/cogitator/blob/main/examples) directory to see how to use the library with
55 | different LLM providers and CoT strategies.
56 |
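For instance, here is a minimal sketch of the sync and async usage pattern, adapted from `examples/run_least_to_most.py`
(it assumes a local Ollama server with the `gemma3:4b` model already pulled):

```python
import asyncio

from cogitator import LeastToMost, OllamaLLM

llm = OllamaLLM(model="gemma3:4b")
ltm = LeastToMost(llm, intermediate_output_format="json")

# Synchronous usage
print(ltm.run("If x plus 7 equals 12, what is x?"))


# Asynchronous usage (a semaphore bounds the number of concurrent LLM calls)
async def main() -> None:
    semaphore = asyncio.Semaphore(5)
    answer = await ltm.run_async(
        "A box has 3 red balls and 5 blue balls. How many balls in total?",
        semaphore=semaphore,
    )
    print(answer)


asyncio.run(main())
```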
57 | ## API Reference
58 |
59 | The Cogitator library's functionality is organized into several modules:
60 |
61 | * **LLM Providers (`cogitator.model`)**
62 | * [`BaseLLM`](api/model/base.md): Base LLM provider class that defines a common interface for all providers.
63 | * [`OpenAILLM`](api/model/openai.md): LLM provider implementation for using OpenAI models (like gpt-4o-mini and gpt-4o).
64 | * [`OllamaLLM`](api/model/ollama.md): LLM provider implementation for using Ollama models (like Llama, Gemma, and Qwen).
65 |
66 | * **CoT Strategies (`cogitator.strategies`)**
67 | * [`AutoCoT`](api/strategies/auto_cot.md): An implementation of the automatic CoT prompting strategy.
68 | * [`CDWCoT`](api/strategies/cdw_cot.md): An implementation of the clustered distance-weighted CoT framework.
69 | * [`GraphOfThoughts`](api/strategies/graph_of_thoughts.md): An implementation of the graph of thoughts CoT framework.
70 | * [`LeastToMost`](api/strategies/least_to_most.md): An implementation of the least-to-most prompting strategy.
71 | * [`SelfConsistency`](api/strategies/sc_cot.md): An implementation of the self-consistency prompting strategy.
72 | * [`TreeOfThoughts`](api/strategies/tree_of_thoughts.md): An implementation of the tree of thoughts CoT framework.
73 |
74 | * **Data Formatting and Validation (`cogitator.schemas`)**
75 |   * [`Schemas`](api/schemas.md): A set of Pydantic models used for validating the structure of LLM outputs.
76 |
77 | * **Utilities**
78 |   * [`Embedding`](api/embedding.md): A set of tools for embedding prompt text, used by strategies like AutoCoT and CDWCoT.
79 |   * [`Clustering`](api/clustering.md): Includes clustering algorithms for grouping similar embeddings, used during the training phase of strategies like AutoCoT and CDWCoT.
80 | * [`Functions`](api/utils.md): A set of utility functions for working with the library.
81 |
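Strategies that need a preparation step before answering questions (such as `AutoCoT` and `CDWCoT`) follow a
fit-then-run pattern. Below is a minimal sketch adapted from `examples/run_auto_cot.py` (again assuming a local Ollama
server with `gemma3:4b` pulled):

```python
from cogitator import AutoCoT, OllamaLLM

llm = OllamaLLM(model="gemma3:4b")
auto = AutoCoT(llm, n_demos=4, max_q_tokens=100, max_steps=8, max_retries=3)

# AutoCoT first builds a small pool of demonstrations from the training questions,
# then reuses those demonstrations when answering new questions.
train_questions = [
    "A merchant had 10 apples. He sold 3. How many remain?",
    "There are 7 days in a week. How many days in 3 weeks?",
    "If you buy 4 pens at $2 each, what's the total cost?",
    "A car travels 60 km in 1 hour. How far in 2.5 hours?",
    "A rectangle is 3 by 5. What is its area?",
    "5 birds are on a wire. 2 fly away. How many left?",
    "A baker made 12 buns and packed 3 per box. How many boxes?",
    "You read 20 pages per day. How many pages in 5 days?",
]
auto.fit(train_questions)
print(auto.run("John has 8 oranges and gives 3 away. How many does he have?"))
```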
82 | ## Extra Resources
83 |
84 | * **[Benchmarking](benchmarking.md):** Learn how to configure and run the performance evaluation framework.
85 | * **[Contributing](contributing.md):** Find guidelines for contributing to the Cogitator project.
86 |
87 |
88 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | ## Examples
2 |
3 | | File | Description |
4 | |------------------------------------------------------|--------------------------------------------------------------------------|
5 | | [run_simple_example.py](run_simple_example.py) | A simple end-to-end example of using the Cogitator library |
6 | | [run_least_to_most.py](run_least_to_most.py) | Example of using the Least-to-Most prompting strategy |
7 | | [run_sc_cot.py](run_sc_cot.py) | Example of using the Self-Consistency prompting strategy |
8 | | [run_auto_cot.py](run_auto_cot.py) | Example of using the Automatic CoT prompting strategy |
9 | | [run_tree_of_thoughts.py](run_tree_of_thoughts.py) | Example of using the Tree of Thoughts prompting framework |
10 | | [run_graph_of_thoughts.py](run_graph_of_thoughts.py) | Example of using the Graph of Thoughts prompting framework |
11 | | [run_cdw_cot.py](run_cdw_cot.py) | Example of using the Clustered Distance-Weighted CoT prompting framework |
12 | | [shared.py](shared.py)                                | Shared utilities and settings for the examples                           |
13 |
14 | ## Running Examples
15 |
16 | ```bash
17 | # Run the Least-to-Most example (OpenAI)
18 | python examples/run_least_to_most.py --provider openai --model-name gpt-4.1-nano
19 | ```
20 |
21 | ```bash
22 | # Run the Self-Consistency example (Ollama)
23 | python examples/run_sc_cot.py --provider ollama --model-name gemma3:4b
24 | ```
25 |
26 | ```bash
27 | # Run all examples (Ollama)
28 | make example-ollama
29 | ```
30 |
31 | ```bash
32 | # Run all examples (OpenAI)
33 | make example-openai
34 | ```
35 |
36 | Note that the examples should be run from the root directory of the repository.
37 | Additionally, to use `gemma3:4b` (or any other model like `gemma3:12b`) with Ollama, it must be pulled (or downloaded) first.
38 |
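For example, with the Ollama CLI installed:

```bash
ollama pull gemma3:4b
```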
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habedi/cogitator/68ed941e6561baafccd9d19d5fc2d75a68ccc00a/examples/__init__.py
--------------------------------------------------------------------------------
/examples/run_auto_cot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import logging
5 |
6 | from cogitator import AutoCoT, BaseLLM
7 | from examples.shared import get_llm, run_main, setup_logging
8 |
9 | setup_logging()
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_auto_cot(llm: BaseLLM) -> AutoCoT:
14 | return AutoCoT(
15 | llm,
16 | n_demos=4,
17 | max_q_tokens=100,
18 | max_steps=8,
19 | max_retries=3,
20 | )
21 |
22 |
23 | TRAIN_QUESTIONS = [
24 | "A merchant had 10 apples. He sold 3. How many remain?",
25 | "There are 7 days in a week. How many days in 3 weeks?",
26 | "If you buy 4 pens at $2 each, what's the total cost?",
27 | "A car travels 60 km in 1 hour. How far in 2.5 hours?",
28 | "A rectangle is 3 by 5. What is its area?",
29 | "5 birds are on a wire. 2 fly away. How many left?",
30 | "A baker made 12 buns and packed 3 per box. How many boxes?",
31 | "You read 20 pages per day. How many pages in 5 days?",
32 | ]
33 |
34 | QUESTIONS = [
35 | "John has 8 oranges and gives 3 away. How many does he have?",
36 | "You run 5 km per day. How far in 7 days?",
37 | ]
38 |
39 |
40 | async def main_async(args: argparse.Namespace):
41 | llm = get_llm(args.provider, args.model_name, args.openai_key)
42 | auto = setup_auto_cot(llm)
43 | semaphore = asyncio.Semaphore(5)
44 |
45 | logger.info("Fitting AutoCoT asynchronously...")
46 | await auto.fit_async(TRAIN_QUESTIONS, semaphore=semaphore)
47 |
48 | logger.info("Running test questions asynchronously...")
49 | tasks = [auto.run_async(q) for q in QUESTIONS]
50 | answers = await asyncio.gather(*tasks)
51 |
52 | for q, a in zip(QUESTIONS, answers):
53 | print(f"Q: {q}\nA: {a}\n")
54 |
55 |
56 | def main_sync(args: argparse.Namespace):
57 | llm = get_llm(args.provider, args.model_name, args.openai_key)
58 | auto = setup_auto_cot(llm)
59 |
60 | logger.info("Fitting AutoCoT synchronously...")
61 | auto.fit(TRAIN_QUESTIONS)
62 |
63 | logger.info("Running test questions synchronously...")
64 | for q in QUESTIONS:
65 | result = auto.run(q)
66 | print(f"Q: {q}\nA: {result}\n")
67 |
68 |
69 | if __name__ == "__main__":
70 | run_main(main_sync, main_async, "Run Auto-CoT example")
71 |
--------------------------------------------------------------------------------
/examples/run_cdw_cot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import logging
5 |
6 | from cogitator import BaseLLM, CDWCoT
7 | from examples.shared import get_llm, run_main, setup_logging
8 |
9 | setup_logging()
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_cdw_cot(llm: BaseLLM) -> CDWCoT:
14 | return CDWCoT(llm, pool_size=8, n_clusters=2, lr=0.1, sample_size=3)
15 |
16 |
17 | TRAIN_QUESTIONS = [
18 | "A pot has 5 liters. You add 2 liters. How many liters now?",
19 | "If x+3=7, what is x?",
20 | "There are 4 pens in a box. How many pens in 3 boxes?",
21 | "You walk 2 km in 30 minutes. Distance in 1 hour?",
22 | "5 apples + 3 apples = ?",
23 | "Solve 2y = 10",
24 | "Area of 2x4 rectangle?",
25 | "Cost of 3 items at $5 each?",
26 | ]
27 | TRAIN_ANSWERS = ["7", "4", "12", "4", "8", "5", "8", "15"]
28 |
29 | QUESTIONS = ["If you have 3 boxes of 5 pens each, how many pens?",
30 | "Solve for y: y - 2 = 4"]
31 |
32 |
33 | async def main_async(args: argparse.Namespace):
34 | llm = get_llm(args.provider, args.model_name, args.openai_key)
35 | cdw = setup_cdw_cot(llm)
36 | semaphore = asyncio.Semaphore(5)
37 |
38 | logger.info("Initializing CDW-CoT pool asynchronously...")
39 | try:
40 | await cdw.init_pool_async(TRAIN_QUESTIONS, TRAIN_ANSWERS, semaphore=semaphore)
41 | logger.info("Training CDW-CoT asynchronously...")
42 | await cdw.train_async(val_split=0.4, epochs=5, patience=3, semaphore=semaphore)
43 | logger.info("Running test questions asynchronously...")
44 | tasks = [cdw.run_async(q, semaphore=semaphore) for q in QUESTIONS]
45 | answers = await asyncio.gather(*tasks)
46 | for q, a in zip(QUESTIONS, answers):
47 | print(f"Q: {q}\nA: {a}\n")
48 | except Exception as e:
49 | logger.error(f"CDW-CoT async example failed: {e}", exc_info=True)
50 |
51 |
52 | def main_sync(args: argparse.Namespace):
53 | llm = get_llm(args.provider, args.model_name, args.openai_key)
54 | cdw = setup_cdw_cot(llm)
55 |
56 | logger.info("Initializing CDW-CoT pool synchronously...")
57 | try:
58 | cdw.init_pool(TRAIN_QUESTIONS, TRAIN_ANSWERS)
59 | logger.info("Training CDW-CoT synchronously...")
60 | cdw.train(val_split=0.4, epochs=5, patience=3)
61 | logger.info("Running test questions synchronously...")
62 | for q in QUESTIONS:
63 | out = cdw.run(q)
64 | print(f"Q: {q}\nA: {out}\n")
65 | except Exception as e:
66 | logger.error(f"CDW-CoT sync example failed: {e}", exc_info=True)
67 |
68 |
69 | if __name__ == "__main__":
70 | run_main(main_sync, main_async, "Run CDW-CoT example")
71 |
--------------------------------------------------------------------------------
/examples/run_graph_of_thoughts.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import logging
5 | from typing import List, Dict, Tuple
6 |
7 | from cogitator import BaseLLM, GraphOfThoughts, ThoughtExpansion
8 | from examples.shared import get_llm, run_main, setup_logging
9 |
10 | setup_logging()
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def setup_got(llm: BaseLLM) -> GraphOfThoughts:
15 | return GraphOfThoughts(
16 | llm,
17 | final_answer_format="json", # "text" or "json"
18 | # embedder can be added here if needed by GoO
19 | )
20 |
21 |
22 | # Define a sample Graph of Operations (GoO)
23 | # This defines the steps GoT will take. Adjust as needed for the problem.
24 | # Example: Generate thoughts, score them, keep the best one.
25 | EXAMPLE_GOO: List[Tuple[str, Dict]] = [
26 | ('Generate',
27 | {'k': 3, 'target_set': 'frontier', 'output_set': 'generated_thoughts', 'prompt_key': 'expand',
28 | 'response_schema': ThoughtExpansion}),
29 | ('Score', {'target_set': 'generated_thoughts', 'prompt_key': 'evaluate'}),
30 | ('KeepBest', {'N': 1, 'target_set': 'generated_thoughts', 'output_set': 'frontier'})
31 | # Keep the best node in the frontier for the final answer
32 | ]
33 |
34 | QUESTIONS = [
35 | "A baker made 2 dozen cookies (24) and sold 8. How many left?",
36 | "If 7 times z equals 56, what is z?",
37 | ]
38 |
39 |
40 | async def main_async(args: argparse.Namespace):
41 | llm = get_llm(args.provider, args.model_name, args.openai_key)
42 | got = setup_got(llm)
43 | semaphore = asyncio.Semaphore(5) # Concurrency limit for LLM calls
44 |
45 | logger.info("Running GraphOfThoughts asynchronously...")
46 | tasks = [got.run_async(q, graph_of_operations=EXAMPLE_GOO, semaphore=semaphore) for q in
47 | QUESTIONS]
48 | answers = await asyncio.gather(*tasks)
49 |
50 | for q, a in zip(QUESTIONS, answers):
51 | print(f"Q: {q}\nA: {a}\n")
52 |
53 |
54 | def main_sync(args: argparse.Namespace):
55 | llm = get_llm(args.provider, args.model_name, args.openai_key)
56 | got = setup_got(llm)
57 |
58 | logger.info("Running GraphOfThoughts synchronously...")
59 | for q in QUESTIONS:
60 | try:
61 | a = got.run(q, graph_of_operations=EXAMPLE_GOO)
62 | print(f"Q: {q}\nA: {a}\n")
63 | except NotImplementedError as e:
64 | logger.debug(f"GraphOfThoughts sync run failed correctly: {e}", exc_info=True)
65 | print("GraphOfThoughts run failed correctly."
66 | " The implementation does not support synchronous execution.")
67 |
68 |
69 | if __name__ == "__main__":
70 | run_main(main_sync, main_async, "Run Graph-of-Thoughts example")
71 |
--------------------------------------------------------------------------------
/examples/run_least_to_most.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import logging
5 |
6 | from cogitator import LeastToMost
7 | from examples.shared import get_llm, run_main, setup_logging
8 |
9 | setup_logging()
10 | logger = logging.getLogger(__name__)
11 |
12 | QUESTIONS = [
13 | "A box has 3 red balls and 5 blue balls. How many balls in total?",
14 | "If x plus 7 equals 12, what is x?",
15 | ]
16 |
17 |
18 | async def main_async(args: argparse.Namespace):
19 | llm = get_llm(args.provider, args.model_name, args.openai_key)
20 | ltm = LeastToMost(llm, intermediate_output_format="json")
21 | semaphore = asyncio.Semaphore(5)
22 |
23 | logger.info("Running LeastToMost asynchronously...")
24 | tasks = [ltm.run_async(q, semaphore=semaphore) for q in QUESTIONS]
25 | answers = await asyncio.gather(*tasks)
26 |
27 | for q, a in zip(QUESTIONS, answers):
28 | print(f"Q: {q}\nA: {a}\n")
29 |
30 |
31 | def main_sync(args: argparse.Namespace):
32 | llm = get_llm(args.provider, args.model_name, args.openai_key)
33 | ltm = LeastToMost(llm, intermediate_output_format="json")
34 |
35 | logger.info("Running LeastToMost synchronously...")
36 | for q in QUESTIONS:
37 | a = ltm.run(q)
38 | print(f"Q: {q}\nA: {a}\n")
39 |
40 |
41 | if __name__ == "__main__":
42 | run_main(main_sync, main_async, "Run Least-to-Most example")
43 |
--------------------------------------------------------------------------------
/examples/run_sc_cot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import logging
5 |
6 | from cogitator import BaseLLM, SelfConsistency
7 | from examples.shared import get_llm, run_main, setup_logging
8 |
9 | setup_logging()
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_sc(llm: BaseLLM) -> SelfConsistency:
14 | return SelfConsistency(
15 | llm,
16 | n_samples=10,
17 | temperature=0.8, # More randomness leads to more diverse answers
18 | max_tokens=200,
19 | internal_extraction_format="json",
20 | # "heuristic" or "json" (JSON option is more robust but might not be supported by all LLMs)
21 | )
22 |
23 |
24 | QUESTIONS = [
25 | "Q: A farmer had 17 sheep. All but 9 died. How many sheep are left?\nA: Let's think step by step.",
26 | "Q: If a train travels 60 miles in 1 hour, how far can it travel in 2.5 hours?\nA: Let's break it down."
27 | ]
28 |
29 |
30 | async def main_async(args: argparse.Namespace):
31 | llm = get_llm(args.provider, args.model_name, args.openai_key)
32 | sc = setup_sc(llm)
33 | semaphore = asyncio.Semaphore(5)
34 |
35 | logger.info("Running SelfConsistency concurrently for multiple questions...")
36 |
37 | async def run_single_question(prompt: str):
38 | logger.info(f"Processing async: {prompt[:50]}...")
39 | answer = await sc.run_async(prompt, semaphore=semaphore)
40 | print(f"\nPrompt: {prompt}")
41 | print(f"Final Answer (async self-consistency): {answer}")
42 | return answer
43 |
44 | tasks = [run_single_question(q) for q in QUESTIONS]
45 | await asyncio.gather(*tasks)
46 | logger.info("Async processing complete.")
47 |
48 |
49 | def main_sync(args: argparse.Namespace):
50 | llm = get_llm(args.provider, args.model_name, args.openai_key)
51 | sc = setup_sc(llm)
52 |
53 | logger.info("Running SelfConsistency sequentially for multiple questions...")
54 | for i, prompt in enumerate(QUESTIONS):
55 | logger.info(f"Processing sync Q{i + 1}: {prompt[:50]}...")
56 | answer = sc.run(prompt)
57 | print(f"\nPrompt: {prompt}")
58 | print(f"Final Answer (sync self-consistency): {answer}")
59 | logger.info("Sync processing complete.")
60 |
61 |
62 | if __name__ == "__main__":
63 | run_main(main_sync, main_async, "Run Self-Consistency example")
64 |
--------------------------------------------------------------------------------
/examples/run_simple_example.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from cogitator import SelfConsistency, OllamaLLM
4 |
5 | # Step 1: Configure logging (optional, but helpful)
6 | logging.basicConfig(level=logging.INFO)
7 | logging.getLogger("httpx").setLevel(logging.WARNING) # Suppress HTTPX logs
8 |
9 | # Step 2: Initialize the LLM (using Ollama)
10 | # Needs Ollama running locally with the model pulled (e.g., `ollama pull gemma3:4b`)
11 | try:
12 | llm = OllamaLLM(model="gemma3:4b")
13 | except Exception as e:
14 | print(f"Error initializing Ollama LLM: {e}")
15 | print("Please make sure Ollama is running and the model is pulled.")
16 | exit(1)
17 |
18 | # Step 3: Choose a CoT strategy (Self-Consistency in this case)
19 | # Self-Consistency generates multiple reasoning paths and finds the most common answer
20 | sc_strategy = SelfConsistency(
21 | llm,
22 | n_samples=5, # Number of reasoning paths to generate
23 | temperature=0.7 # Higher temperature can lead to more diverse answers
24 | )
25 |
26 | # Step 4: Define the prompt (with a basic CoT trigger)
27 | question = "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?"
28 | prompt = f"Q: {question}\nA: Let's think step by step."
29 |
30 | # Step 5: Run the CoT prompting sc_strategy
31 | print(f"\nQuestion: {question}")
32 | print("Running Self-Consistency CoT...")
33 | final_answer = sc_strategy.run(prompt) # Returns the most consistent (repeated) answer
34 |
35 | # Expected output: $0.05 or 0.05 (may vary slightly based on model and temperature)
36 | print(f"\nCogitator's Answer (Self-Consistency): {final_answer}")
37 |
--------------------------------------------------------------------------------
/examples/run_tree_of_thoughts.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import logging
5 |
6 | from cogitator import BaseLLM, TreeOfThoughts
7 | from examples.shared import get_llm, run_main, setup_logging
8 |
9 | setup_logging()
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_tot(llm: BaseLLM) -> TreeOfThoughts:
14 | return TreeOfThoughts(llm, max_depth=2, num_branches=2, sims=4, c_puct=1.0)
15 |
16 |
17 | QUESTIONS = [
18 | "A garden has 4 rows of 6 plants each. How many plants total?",
19 | "If y minus 5 equals 10, what is y?",
20 | ]
21 |
22 |
23 | async def main_async(args: argparse.Namespace):
24 | llm = get_llm(args.provider, args.model_name, args.openai_key)
25 | tot = setup_tot(llm)
26 | semaphore = asyncio.Semaphore(5)
27 |
28 | logger.info("Running TreeOfThoughts asynchronously...")
29 | tasks = [tot.run_async(q, semaphore=semaphore) for q in QUESTIONS]
30 | answers = await asyncio.gather(*tasks)
31 |
32 | for q, a in zip(QUESTIONS, answers):
33 | print(f"Q: {q}\nA: {a}\n")
34 |
35 |
36 | def main_sync(args: argparse.Namespace):
37 | llm = get_llm(args.provider, args.model_name, args.openai_key)
38 | tot = setup_tot(llm)
39 |
40 | logger.info("Running TreeOfThoughts synchronously...")
41 | for q in QUESTIONS:
42 | a = tot.run(q)
43 | print(f"Q: {q}\nA: {a}\n")
44 |
45 |
46 | if __name__ == "__main__":
47 | run_main(main_sync, main_async, "Run Tree-of-Thoughts example")
48 |
--------------------------------------------------------------------------------
/examples/shared.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import logging
4 | import os
5 | from typing import Any, Callable, Coroutine, Optional
6 |
7 | from cogitator import BaseLLM, OllamaLLM, OpenAILLM
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def setup_logging() -> None:
13 | logging.basicConfig(
14 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
15 | )
16 | logging.getLogger("httpx").setLevel(logging.WARNING)
17 | logging.getLogger("sentence_transformers.SentenceTransformer").setLevel(logging.WARNING)
18 |
19 |
20 | def get_llm(provider: str, model_name: str, openai_key: Optional[str] = None) -> BaseLLM:
21 | logger.info(f"Initializing LLM for examples: provider={provider}, model={model_name}")
22 | if provider == "openai":
23 | key = openai_key or os.getenv("OPENAI_API_KEY")
24 | if not key:
25 | raise ValueError(
26 | "OpenAI API key must be provided via --openai-key or "
27 | "OPENAI_API_KEY environment variable."
28 | )
29 | return OpenAILLM(api_key=key, model=model_name)
30 | elif provider == "ollama":
31 | return OllamaLLM(model=model_name)
32 | else:
33 | raise ValueError(f"Unsupported provider: {provider}")
34 |
35 |
36 | def parse_common_args(description: str) -> argparse.Namespace:
37 | parser = argparse.ArgumentParser(description=description)
38 | parser.add_argument(
39 | "--provider",
40 | choices=["openai", "ollama"],
41 | default="ollama",
42 | help="LLM provider to use (default: ollama)",
43 | )
44 | parser.add_argument(
45 | "--model-name",
46 | default=None,
47 | help="Name of the model (default: 'gemma3:4b' for ollama, 'gpt-4.1-nano' for openai)",
48 | )
49 | parser.add_argument(
50 | "--openai-key",
51 | default=None,
52 | help="OpenAI API key (reads OPENAI_API_KEY env var if not set)",
53 | )
54 | parser.add_argument(
55 | "--use-async", action="store_true", help="Run the asynchronous version of the example"
56 | )
57 |
58 | args = parser.parse_args()
59 |
60 | if not args.model_name:
61 | args.model_name = "gpt-4.1-nano" if args.provider == "openai" else "gemma3:4b"
62 | logger.info(
63 | f"Model name not specified, using default for {args.provider}: {args.model_name}"
64 | )
65 |
66 | return args
67 |
68 |
69 | def run_main(
70 | main_sync_func: Callable[[argparse.Namespace], None],
71 | main_async_func: Callable[[argparse.Namespace], Coroutine[Any, Any, None]],
72 | description: str,
73 | ) -> None:
74 | args = parse_common_args(description)
75 |
76 | if args.use_async:
77 | asyncio.run(main_async_func(args))
78 | else:
79 | main_sync_func(args)
80 |
--------------------------------------------------------------------------------
/logo.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Cogitator Documentation
2 | site_description: Documentation for the Cogitator toolkit.
3 | repo_url: https://github.com/habedi/cogitator
4 | repo_name: habedi/cogitator
5 |
6 | theme:
7 | name: material
8 | palette:
9 | - media: "(prefers-color-scheme: light)"
10 | scheme: default
11 | toggle:
12 | icon: material/brightness-7
13 | name: Switch to dark mode
14 | - media: "(prefers-color-scheme: dark)"
15 | scheme: slate
16 | toggle:
17 | icon: material/brightness-4
18 | name: Switch to light mode
19 | features:
20 | - content.code.copy
21 | - navigation.tabs
22 | - navigation.top
23 | - navigation.indexes
24 | - navigation.expand
25 | - content.code.select
26 | - content.code.annotate
27 |
28 | plugins:
29 | - search
30 | - mkdocstrings:
31 | handlers:
32 | python:
33 | options:
34 | show_root_heading: true
35 | show_source: true
36 | nav:
37 | - Home: index.md
38 | - API Reference:
39 | - LLM Providers:
40 | - Base: api/model/base.md
41 | - OpenAI: api/model/openai.md
42 | - Ollama: api/model/ollama.md
43 | - Schemas: api/schemas.md
44 | - Embedding: api/embedding.md
45 | - Clustering: api/clustering.md
46 | - Utilities: api/utils.md
47 | - Strategies:
48 | - AutoCoT: api/strategies/auto_cot.md
49 | - CDWCoT: api/strategies/cdw_cot.md
50 | - GraphOfThoughts: api/strategies/graph_of_thoughts.md
51 | - LeastToMost: api/strategies/least_to_most.md
52 | - SelfConsistency: api/strategies/sc_cot.md
53 | - TreeOfThoughts: api/strategies/tree_of_thoughts.md
54 | - Benchmarking: benchmarking.md
55 | - Contributing: contributing.md
56 |
57 | markdown_extensions:
58 | - pymdownx.highlight:
59 | anchor_linenums: true
60 | - pymdownx.inlinehilite
61 | - pymdownx.snippets
62 | - pymdownx.superfences
63 | - admonition
64 | - toc:
65 | permalink: true
66 |
--------------------------------------------------------------------------------
/poetry.toml:
--------------------------------------------------------------------------------
1 | [virtualenvs]
2 | in-project = true
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "cogitator"
3 | version = "0.1.0b1"
4 | description = "A Python toolkit for chain-of-thought prompting"
5 | authors = ["Hassan Abedi "]
6 | license = "MIT"
7 | readme = "README.md"
8 | packages = [{ include = "cogitator" }]
9 | repository = "https://github.com/habedi/cogitator"
10 | documentation = "https://habedi.github.io/cogitator"
11 | classifiers = [
12 | "Development Status :: 4 - Beta",
13 | "Intended Audience :: Developers",
14 | "License :: OSI Approved :: MIT License",
15 | "Programming Language :: Python :: 3",
16 | "Programming Language :: Python :: 3.10",
17 | "Programming Language :: Python :: 3.11",
18 | "Programming Language :: Python :: 3.12",
19 | "Programming Language :: Python :: 3.13",
20 | "Topic :: Software Development :: Libraries",
21 | "Topic :: Software Development :: Libraries :: Python Modules",
22 | ]
23 | keywords = ["prompt engineering", "chain-of-thought", "artificial intelligence", "machine learning"]
24 |
25 | [tool.poetry.dependencies]
26 | python = "^3.10"
27 | scikit-learn = "^1.6.1"
28 | openai = "^1.76.2"
29 | ollama = "^0.4.8"
30 | pydantic = "^2.11.4"
31 | sentence-transformers = "^4.1.0"
32 | tiktoken = "^0.9.0"
33 |
34 | [tool.poetry.group.dev.dependencies]
35 | pytest = "^8.0.1"
36 | pytest-cov = "^6.0.0"
37 | pytest-mock = "^3.14.0"
38 | pytest-asyncio = "^0.26.0"
39 | mypy = "^1.11.1"
40 | ruff = "^0.11.7"
41 | griffe = "^1.7.3"
42 | polars = "^1.28.1"
43 | datasets = "^3.5.1"
44 | pyyaml = "^6.0.2"
45 | pre-commit = "^4.2.0"
46 | mkdocs = "^1.6.1"
47 | mkdocs-material = "^9.6.12"
48 | mkdocstrings-python = "^1.16.10"
49 |
50 | [build-system]
51 | requires = ["poetry-core"]
52 | build-backend = "poetry.core.masonry.api"
53 |
54 | [tool.pytest.ini_options]
55 | pythonpath = ["cogitator"]
56 | testpaths = ["tests"]
57 | addopts = [
58 | "--tb=short",
59 | #"--disable-warnings",
60 | "--cov=cogitator",
61 | "--cov-branch",
62 | "--cov-report=term",
63 | "--cov-report=xml",
64 | "-rs",
65 | ]
66 | asyncio_mode = "auto"
67 | asyncio_default_fixture_loop_scope = "function"
68 | asyncio_default_test_loop_scope = "function"
69 |
70 | [tool.coverage.run]
71 | branch = true
72 | parallel = true
73 | source = ["cogitator"]
74 | omit = [
75 | "tests/*",
76 | "benches/*",
77 | "examples/*",
78 | ]
79 |
80 | [tool.coverage.report]
81 | show_missing = false
82 | skip_empty = true
83 | precision = 2
84 |
85 | [tool.mypy]
86 | python_version = "3.10"
87 | ignore_missing_imports = true
88 | disallow_untyped_defs = true
89 | disallow_untyped_calls = true
90 | disallow_incomplete_defs = true
91 | check_untyped_defs = true
92 | warn_return_any = true
93 | strict_optional = true
94 | warn_redundant_casts = true
95 | exclude = "^(benches/|examples/|tests/)"
96 |
97 | [tool.ruff]
98 | exclude = [
99 | ".bzr",
100 | ".direnv",
101 | ".eggs",
102 | ".git",
103 | ".git-rewrite",
104 | ".hg",
105 | ".mypy_cache",
106 | ".nox",
107 | ".pants.d",
108 | ".pytype",
109 | ".ruff_cache",
110 | ".svn",
111 | ".tox",
112 | ".venv",
113 | "__pypackages__",
114 | "_build",
115 | "buck-out",
116 | "build",
117 | "dist",
118 | "node_modules",
119 | "venv",
120 | # Additional directories to exclude
121 | "tests",
122 | "benches",
123 | "examples",
124 | ]
125 | line-length = 100
126 | indent-width = 4
127 | src = ["cogitator", "examples"]
128 | target-version = "py310"
129 | unsafe-fixes = true
130 |
131 | [tool.ruff.lint]
132 | select = ["ANN", "E", "F", "I", "W", "B", "RUF", "SIM", "C90"]
133 | ignore = [
134 | # Ignore docstring errors
135 | "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107",
136 | ]
137 | fixable = ["ALL"]
138 | unfixable = []
139 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
140 |
141 | [tool.ruff.format]
142 | quote-style = "double"
143 | indent-style = "space"
144 | skip-magic-trailing-comma = false
145 | line-ending = "auto"
146 |
147 | [tool.ruff.lint.pydocstyle]
148 | convention = "google"
149 |
150 | [tool.ruff.lint.per-file-ignores]
151 | "tests/**/*.py" = []
152 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import logging
4 | from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple, Type
5 |
6 | import numpy as np
7 | import pytest
8 | from pydantic import BaseModel
9 |
10 | from cogitator import BaseClusterer
11 | from cogitator import BaseEmbedder
12 | from cogitator import BaseLLM
13 | from cogitator import (
14 | EvaluationResult,
15 | ExtractedAnswer,
16 | LTMDecomposition,
17 | ThoughtExpansion,
18 | )
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | DEFAULT_SYNC_RESPONSE = "SYNC_RESPONSE"
23 | DEFAULT_ASYNC_RESPONSE = "ASYNC_RESPONSE"
24 | DEFAULT_JSON_RESPONSE = ExtractedAnswer(final_answer="json_sync_default")
25 | DEFAULT_ASYNC_JSON_RESPONSE = ExtractedAnswer(final_answer="json_async_default")
26 | DEFAULT_FINAL_ANSWER = "FINAL_ANSWER_DEFAULT"
27 | DEFAULT_SUBANSWER = "SUBANSWER_DEFAULT"
28 | DEFAULT_JSON_STEPS = ThoughtExpansion(thoughts=["step1_default_sync"])
29 | DEFAULT_ASYNC_JSON_STEPS = ThoughtExpansion(thoughts=["step1_default_async"])
30 | DEFAULT_JSON_SUBQUESTIONS = LTMDecomposition(subquestions=["subq1_default_sync"])
31 | DEFAULT_ASYNC_JSON_SUBQUESTIONS = LTMDecomposition(subquestions=["subq1_default_async"])
32 | DEFAULT_JSON_EVAL = EvaluationResult(score=7, justification="Default Eval Sync")
33 | DEFAULT_ASYNC_JSON_EVAL = EvaluationResult(score=8, justification="Default Eval Async")
34 |
35 |
36 | class ConfigurableFakeLLM(BaseLLM):
37 |
38 | def __init__(self, config: Optional[Dict[str, Any]] = None):
39 | self._config = config if config is not None else {}
40 | self.sync_calls: List[Dict[str, Any]] = []
41 | self.async_calls: List[Dict[str, Any]] = []
42 | self.responses_map: Dict[str, Any] = self._config.get("responses_map", {})
43 | self._call_counts = {
44 | "generate": 0,
45 | "generate_async": 0,
46 | "_generate_json_internal": 0,
47 | "_generate_json_internal_async": 0,
48 | "stream": 0,
49 | "async_stream": 0
50 | }
51 |
52 | def _get_next_response(self, key: str, config_lookup_key: str, default: Any) -> Any:
53 | if key not in self._call_counts:
54 | raise KeyError(f"Internal error: key '{key}' not initialized in _call_counts.")
55 |
56 | response_config = self._config.get(config_lookup_key)
57 |
58 | if isinstance(response_config, list):
59 | if not response_config:
60 | self._call_counts[key] += 1
61 | return default
62 | idx = self._call_counts[key] % len(response_config)
63 | self._call_counts[key] += 1
64 | return response_config[idx]
65 | elif response_config is not None:
66 | self._call_counts[key] += 1
67 | return response_config
68 | else:
69 | self._call_counts[key] += 1
70 | return default
71 |
72 | def _get_response_for_prompt(self, prompt: str, method_type: str) -> Any:
73 |
74 | longest_match_key = None
75 | for key_fragment in sorted(self.responses_map.keys(), key=len, reverse=True):
76 | if key_fragment in prompt:
77 | longest_match_key = key_fragment
78 | break
79 |
80 | if longest_match_key is not None:
81 | return self.responses_map[longest_match_key]
82 |
83 | is_json_method = "json" in method_type
84 | if "JSON Output:" in prompt and "thoughts" in prompt:
85 | key = "json_steps"
86 | default = DEFAULT_ASYNC_JSON_STEPS if 'async' in method_type else DEFAULT_JSON_STEPS
87 | elif "JSON list of strings" in prompt:
88 | key = "generate_async" if 'async' in method_type else "generate_sync"
89 | default = json.dumps(
90 | DEFAULT_ASYNC_JSON_STEPS.model_dump()) if 'async' in method_type else json.dumps(
91 | DEFAULT_JSON_STEPS.model_dump())
92 | elif "JSON Output:" in prompt and "subquestions" in prompt:
93 | key = "json_subquestions"
94 | default = DEFAULT_ASYNC_JSON_SUBQUESTIONS if 'async' in method_type else DEFAULT_JSON_SUBQUESTIONS
95 | elif "JSON Evaluation:" in prompt:
96 | key = "json_eval"
97 | default = DEFAULT_ASYNC_JSON_EVAL if 'async' in method_type else DEFAULT_JSON_EVAL
98 | elif "JSON Answer:" in prompt:
99 | key = "json_answer"
100 |
101 | return self._get_next_response(method_type, key,
102 | DEFAULT_ASYNC_JSON_RESPONSE if 'async' in method_type else DEFAULT_JSON_RESPONSE)
103 | elif "Current Subquestion:" in prompt:
104 | key = "sub_answer"
105 | default = DEFAULT_SUBANSWER + ("_async" if 'async' in method_type else "")
106 | elif "Given reasoning steps" in prompt \
107 | or prompt.startswith("Answer the question:") \
108 | or prompt.startswith(
109 | "Based on the following sequential subquestions"):
110 | key = "final_answer"
111 | default = DEFAULT_FINAL_ANSWER + ("_async" if 'async' in method_type else "")
112 | else:
113 | if method_type == "generate":
114 | key, default = "generate_sync", DEFAULT_SYNC_RESPONSE
115 | elif method_type == "generate_async":
116 | key, default = "generate_async", DEFAULT_ASYNC_RESPONSE
117 | elif method_type == "_generate_json_internal":
118 | key, default = "generate_json", DEFAULT_JSON_RESPONSE
119 | elif method_type == "_generate_json_internal_async":
120 | key, default = "generate_json_async", DEFAULT_ASYNC_JSON_RESPONSE
121 | elif method_type == "stream":
122 | key, default = "generate_sync", DEFAULT_SYNC_RESPONSE
123 | elif method_type == "async_stream":
124 | key, default = "generate_async", DEFAULT_ASYNC_RESPONSE
125 | else:
126 | key, default = "unhandled", "UNHANDLED_FAKE_RESPONSE"
127 |
128 | return self._get_next_response(method_type, key, default)
129 |
130 | def generate(self, prompt: str, **kwargs: Any) -> str:
131 | self.sync_calls.append({"type": "generate", "prompt": prompt, "kwargs": kwargs})
132 | response = self._get_response_for_prompt(prompt, "generate")
133 |
134 | if not isinstance(response, str):
135 | try:
136 |
137 | if isinstance(response, BaseModel): return response.model_dump_json()
138 | if isinstance(response, (dict, list)): return json.dumps(response)
139 | except Exception:
140 | pass
141 | return str(response)
142 | return response
143 |
144 | async def generate_async(self, prompt: str, **kwargs: Any) -> str:
145 | self.async_calls.append({"type": "generate_async", "prompt": prompt, "kwargs": kwargs})
146 | await asyncio.sleep(0.001)
147 | response = self._get_response_for_prompt(prompt, "generate_async")
148 |
149 | if not isinstance(response, str):
150 | try:
151 | if isinstance(response, BaseModel): return response.model_dump_json()
152 | if isinstance(response, (dict, list)): return json.dumps(response)
153 | except Exception:
154 | pass
155 | return str(response)
156 | return response
157 |
158 | def generate_stream(self, prompt: str, **kwargs: Any) -> Iterator[str]:
159 | self.sync_calls.append({"type": "stream", "prompt": prompt, "kwargs": kwargs})
160 | response = self._get_response_for_prompt(prompt, "stream")
161 | yield str(response) + "_stream"
162 |
163 | async def generate_stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]:
164 | self.async_calls.append({"type": "async_stream", "prompt": prompt, "kwargs": kwargs})
165 | await asyncio.sleep(0.001)
166 | response = self._get_response_for_prompt(prompt, "async_stream")
167 | yield str(response) + "_async_stream"
168 |
169 | def _generate_json_internal(self, prompt: str, response_model: Type[BaseModel],
170 | **kwargs: Any) -> Tuple[str, Optional[str]]:
171 | self.sync_calls.append({
172 | "type": "_generate_json_internal",
173 | "prompt": prompt,
174 | "response_model": response_model.__name__,
175 | "kwargs": kwargs
176 | })
177 | response_obj = self._get_response_for_prompt(prompt, "_generate_json_internal")
178 | json_string = ""
179 | if isinstance(response_obj, BaseModel):
180 | json_string = response_obj.model_dump_json()
181 | elif isinstance(response_obj, dict):
182 | json_string = json.dumps(response_obj)
183 | elif isinstance(response_obj, str):
184 | try:
185 | json.loads(response_obj)
186 | json_string = response_obj
187 | except json.JSONDecodeError:
188 | logger.warning(
189 | f"Mock Configured string response for JSON prompt is not valid JSON: {response_obj}")
190 | json_string = "{}"
191 | else:
192 | try:
193 | json_string = json.dumps(response_obj)
194 | except TypeError:
195 | logger.warning(
196 | f"Mock cannot dump configured response to JSON: {type(response_obj)}")
197 | json_string = "{}"
198 |
199 | mode_used = "mock_json_mode"
200 | return json_string, mode_used
201 |
202 | async def _generate_json_internal_async(self, prompt: str, response_model: Type[BaseModel],
203 | **kwargs: Any) -> Tuple[
204 | str, Optional[str]]:
205 | self.async_calls.append({
206 | "type": "_generate_json_internal_async",
207 | "prompt": prompt,
208 | "response_model": response_model.__name__,
209 | "kwargs": kwargs
210 | })
211 | await asyncio.sleep(0.001)
212 | response_obj = self._get_response_for_prompt(prompt, "_generate_json_internal_async")
213 | json_string = ""
214 | if isinstance(response_obj, BaseModel):
215 | json_string = response_obj.model_dump_json()
216 | elif isinstance(response_obj, dict):
217 | json_string = json.dumps(response_obj)
218 | elif isinstance(response_obj, str):
219 | try:
220 | json.loads(response_obj)
221 | json_string = response_obj
222 | except json.JSONDecodeError:
223 | logger.warning(
224 | f"Mock Configured string response for async JSON prompt is not valid JSON: {response_obj}")
225 | json_string = "{}"
226 | else:
227 | try:
228 | json_string = json.dumps(response_obj)
229 | except TypeError:
230 | logger.warning(
231 | f"Mock cannot dump configured async response to JSON: {type(response_obj)}")
232 | json_string = "{}"
233 |
234 | mode_used = "mock_json_mode_async"
235 | return json_string, mode_used
236 |
237 |
238 | @pytest.fixture
239 | def fake_llm_factory() -> Callable[[Optional[Dict[str, Any]]], ConfigurableFakeLLM]:
240 | def _create_llm(config: Optional[Dict[str, Any]] = None) -> ConfigurableFakeLLM:
241 | return ConfigurableFakeLLM(config)
242 |
243 | return _create_llm
244 |
245 |
246 | class MockEmbedder(BaseEmbedder):
247 | def encode(self, texts: List[str]) -> List[np.ndarray]:
248 | logger.debug(f"Mock encoding texts: {texts}")
249 | return [np.array([float(i), float(i + 1)], dtype=float) for i in range(len(texts))]
250 |
251 |
252 | class MockClusterer(BaseClusterer):
253 | def cluster(self, embeddings: np.ndarray, n_clusters: int, **kwargs) -> Tuple[
254 | np.ndarray, np.ndarray]:
255 | logger.debug(
256 | f"Mock clustering embeddings (shape {embeddings.shape}) into {n_clusters} clusters")
257 | if embeddings.shape[0] == 0 or n_clusters <= 0:
258 | output_dim = embeddings.shape[1] if len(
259 | embeddings.shape) > 1 and embeddings.shape[1] > 0 else 1
260 | labels = np.array([], dtype=int)
261 | centers = np.array([], dtype=float).reshape(0, output_dim)
262 | else:
263 | output_dim = embeddings.shape[1]
264 | n_clusters = min(n_clusters, embeddings.shape[0])
265 | labels = (embeddings[:, 0] % n_clusters).astype(int)
266 | centers = np.array(
267 | [embeddings[labels == i].mean(axis=0) if np.any(labels == i) else np.zeros(
268 | output_dim) for i in range(n_clusters)])
269 | if centers.ndim == 1 and output_dim > 0:
270 | centers = centers.reshape(-1, output_dim)
271 | elif centers.ndim == 0 and output_dim == 0:
272 | centers = centers.reshape(n_clusters, 1)
273 | logger.debug(f"Generated labels: {labels}")
274 | logger.debug(f"Generated centers shape: {centers.shape}")
275 | return labels, centers
276 |
277 |
278 | @pytest.fixture
279 | def patch_embedding_clustering(monkeypatch):
280 | logger.debug("Patching embedding and clustering classes")
281 | monkeypatch.setattr("cogitator.strategies.auto_cot.SentenceTransformerEmbedder", MockEmbedder)
282 | monkeypatch.setattr("cogitator.strategies.auto_cot.KMeansClusterer", MockClusterer)
283 | monkeypatch.setattr("cogitator.strategies.cdw_cot.SentenceTransformerEmbedder", MockEmbedder)
284 | monkeypatch.setattr("cogitator.strategies.cdw_cot.KMeansClusterer", MockClusterer)
285 | monkeypatch.setattr("cogitator.strategies.graph_of_thoughts.SentenceTransformerEmbedder",
286 | MockEmbedder)
287 |
--------------------------------------------------------------------------------
/tests/extra_tests/test_examples.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 | import logging
4 | import sys
5 | from pathlib import Path
6 | from unittest.mock import AsyncMock, MagicMock
7 |
8 | import pytest
9 |
10 | # --- Test Setup ---
11 | # Add the repository root to the path so the example modules can be imported
12 | examples_dir = Path(__file__).parent.parent.parent / "examples"
13 | sys.path.insert(0, str(examples_dir.parent))
14 |
15 | # List of example script filenames
16 | example_scripts = [f.stem for f in examples_dir.glob("run_*.py")]
17 |
18 |
19 | @pytest.fixture
20 | def mock_get_llm_examples(mocker):
21 | """ Mocks get_llm specifically for the examples """
22 |     mock_llm_instance = MagicMock()
23 |     mock_llm_instance.generate.return_value = "Example Mock LLM Response"
24 |     mock_llm_instance.generate_async = AsyncMock(return_value="Example Mock Async LLM Response")
25 |     # Add mocks for other methods used by examples if necessary
26 |     # e.g., fit, fit_async, run, run_async for specific strategies
27 |     mock_llm_instance.fit = MagicMock()
28 |     mock_llm_instance.fit_async = AsyncMock(return_value=None)  # Must be awaitable
29 |     mock_llm_instance.init_pool = MagicMock()
30 |     mock_llm_instance.init_pool_async = AsyncMock(return_value=None)
31 |     mock_llm_instance.train = MagicMock()
32 |     mock_llm_instance.train_async = AsyncMock(return_value=None)
33 |     mock_llm_instance.run = MagicMock(return_value="Example Mock Run")
34 |     mock_llm_instance.run_async = AsyncMock(
35 |         return_value="Example Mock Async Run")  # Must be awaitable
36 |     mock_llm_instance.generate_json = MagicMock(return_value={"final_answer": "Mock JSON"})
37 |     mock_llm_instance.generate_json_async = AsyncMock(
38 |         return_value={"final_answer": "Mock Async JSON"})  # Must be awaitable
39 |
40 | # Mock the get_llm function within the examples.shared module
41 | return mocker.patch("examples.shared.get_llm", return_value=mock_llm_instance)
42 |
43 |
44 | @pytest.fixture
45 | def mock_embedding_clustering_examples(mocker):
46 | """ Mocks embedding and clustering for examples that might use them """
47 | mocker.patch("cogitator.embedding.SentenceTransformerEmbedder.encode",
48 | return_value=[MagicMock()])
49 | mocker.patch("cogitator.clustering.KMeansClusterer.cluster",
50 | return_value=(MagicMock(), MagicMock()))
51 |
52 |
53 | # --- Test Runner ---
54 |
55 | @pytest.mark.parametrize("script_name", example_scripts)
56 | def test_example_script_sync(script_name, mock_get_llm_examples, mock_embedding_clustering_examples,
57 | mocker, capsys):
58 | """ Tests if example scripts run synchronously without errors """
59 | logging.info(f"Testing example (sync): {script_name}")
60 | # Mock sys.argv to prevent interference and set provider
61 | mocker.patch("sys.argv", ["examples/" + script_name + ".py", "--provider",
62 | "ollama"]) # Use ollama as it requires no key
63 |
64 | try:
65 | # Import the module
66 | module = importlib.import_module(f"examples.{script_name}")
67 | # Check if main_sync exists and run it
68 | if hasattr(module, "main_sync"):
69 | module.main_sync(
70 | argparse.Namespace(provider="ollama", model_name="mock_model", openai_key=None,
71 | use_async=False))
72 | # Capture stdout to check for basic execution (optional)
73 | captured = capsys.readouterr()
74 | assert "Error" not in captured.err # Basic check for errors in output
75 | logging.info(f"Sync run for {script_name} completed.")
76 | else:
77 | pytest.skip(f"No main_sync function found in {script_name}")
78 | except Exception as e:
79 | pytest.fail(f"Example script {script_name} (sync) failed: {e}")
80 |
81 |
82 | @pytest.mark.asyncio
83 | @pytest.mark.parametrize("script_name", example_scripts)
84 | async def test_example_script_async(script_name, mock_get_llm_examples,
85 | mock_embedding_clustering_examples, mocker, capsys):
86 | """ Tests if example scripts run asynchronously without errors """
87 | logging.info(f"Testing example (async): {script_name}")
88 | # Mock sys.argv to prevent interference and set provider + async flag
89 | mocker.patch("sys.argv",
90 | ["examples/" + script_name + ".py", "--provider", "ollama", "--use-async"])
91 |
92 | try:
93 | module = importlib.import_module(f"examples.{script_name}")
94 | if hasattr(module, "main_async"):
95 | # Mock asyncio.gather if necessary, especially for complex async flows
96 | # mocker.patch('asyncio.gather', return_value=["Mock Async Result"])
97 |
98 | # Mock the specific CoT methods used in async examples if get_llm mock isn't enough
99 | # Example: mocker.patch('cogitator.AutoCoT.fit_async', return_value=None)
100 | # Example: mocker.patch('cogitator.AutoCoT.run_async', return_value="Mock Async Output")
101 | # (Covered by mock_get_llm_examples for now, but might need refinement)
102 |
103 | await module.main_async(
104 | argparse.Namespace(provider="ollama", model_name="mock_model", openai_key=None,
105 | use_async=True))
106 | captured = capsys.readouterr()
107 | assert "Error" not in captured.err
108 | logging.info(f"Async run for {script_name} completed.")
109 | else:
110 | pytest.skip(f"No main_async function found in {script_name}")
111 | except Exception as e:
112 | pytest.fail(f"Example script {script_name} (async) failed: {e}")
113 |
--------------------------------------------------------------------------------
/tests/test_auto_cot.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cogitator import AutoCoT
4 |
5 |
6 | def test_fit_builds_expected_number_of_demos(fake_llm_factory, patch_embedding_clustering):
7 | questions = [f"q{i}" for i in range(8)]
8 | llm = fake_llm_factory({
9 | "generate_sync": "Step 1\nStep 2"
10 | })
11 | ac = AutoCoT(llm, n_demos=2, max_q_tokens=100, max_steps=5)
12 | ac.fit(questions)
13 | assert ac.demos is not None
14 | assert len(ac.demos) == 2
15 | for demo in ac.demos:
16 | assert demo.startswith("Q: ")
17 | assert "Step 1" in demo
18 |
19 |
20 | @pytest.mark.asyncio
21 | async def test_fit_async_builds_expected_number_of_demos(fake_llm_factory,
22 | patch_embedding_clustering):
23 | questions = [f"q{i}" for i in range(8)]
24 | llm = fake_llm_factory({
25 | "generate_async": "Async Step 1\nAsync Step 2"
26 | })
27 | ac = AutoCoT(llm, n_demos=2, max_q_tokens=100, max_steps=5)
28 | await ac.fit_async(questions)
29 | assert ac.demos is not None
30 | assert len(ac.demos) == 2
31 | for demo in ac.demos:
32 | assert demo.startswith("Q: ")
33 | assert "Async Step 1" in demo
34 |
35 |
36 | def test_run_uses_cached_demos_and_constructs_payload(fake_llm_factory,
37 | patch_embedding_clustering):
38 | questions = [f"q{i}" for i in range(8)]
39 | llm = fake_llm_factory({
40 | "generate_sync": "Sync Final Answer"
41 | })
42 | ac = AutoCoT(llm, n_demos=2)
43 | ac.fit(questions)
44 | assert ac.demos is not None
45 | out = ac.run("test question")
46 | assert out == "Sync Final Answer"
47 | assert "test question" in llm.sync_calls[-1]["prompt"]
48 |
49 |
50 | @pytest.mark.asyncio
51 | async def test_run_async_uses_cached_demos(fake_llm_factory, patch_embedding_clustering):
52 | questions = [f"q{i}" for i in range(8)]
53 | llm = fake_llm_factory({
54 | "generate_async": "Async Final Answer"
55 | })
56 | ac = AutoCoT(llm, n_demos=2)
57 | await ac.fit_async(questions)
58 | assert ac.demos is not None
59 | out = await ac.run_async("test question async")
60 | assert out == "Async Final Answer"
61 | assert "test question async" in llm.async_calls[-1]["prompt"]
62 |
63 |
64 | def test_fit_raises_with_insufficient_questions(fake_llm_factory):
65 | llm = fake_llm_factory()
66 | ac = AutoCoT(llm, n_demos=3)
67 | with pytest.raises(ValueError):
68 | ac.fit(["only one"])
69 |
70 |
71 | @pytest.mark.asyncio
72 | async def test_fit_async_raises_with_insufficient_questions(fake_llm_factory):
73 | llm = fake_llm_factory()
74 | ac = AutoCoT(llm, n_demos=3)
75 | with pytest.raises(ValueError):
76 | await ac.fit_async(["only one"])
77 |
--------------------------------------------------------------------------------
/tests/test_cdw_cot.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cogitator import CDWCoT
4 |
5 |
6 | def test_init_pool_builds_candidate_pool(fake_llm_factory, patch_embedding_clustering):
7 | questions = [f"q{i}" for i in range(8)]
8 | answers = [f"a{i}" for i in range(8)]
9 | llm = fake_llm_factory({"generate_sync": "Pool CoT Sync"})
10 | cdw = CDWCoT(llm, pool_size=4, n_clusters=2, sample_size=2)
11 | cdw.init_pool(questions, answers)
12 | assert len(cdw.PC) <= 4
13 | assert cdw.cluster_centers is not None
14 | assert len(cdw.p_cluster) == cdw.cluster_centers.shape[0]
15 | if cdw.PC:
16 | assert "Pool CoT Sync" in cdw.PC[0]
17 | for p in cdw.p_cluster:
18 | assert pytest.approx(1.0) == p.sum()
19 |
20 |
21 | @pytest.mark.asyncio
22 | async def test_init_pool_async_builds_candidate_pool(fake_llm_factory, patch_embedding_clustering):
23 | questions = [f"q{i}" for i in range(8)]
24 | answers = [f"a{i}" for i in range(8)]
25 | llm = fake_llm_factory({"generate_async": "Pool CoT Async"})
26 | cdw = CDWCoT(llm, pool_size=4, n_clusters=2, sample_size=2)
27 | await cdw.init_pool_async(questions, answers)
28 | assert len(cdw.PC) <= 4
29 | assert cdw.cluster_centers is not None
30 | assert len(cdw.p_cluster) == cdw.cluster_centers.shape[0]
31 | if cdw.PC:
32 | assert "Pool CoT Async" in cdw.PC[0]
33 | for p in cdw.p_cluster:
34 | assert pytest.approx(1.0) == p.sum()
35 |
36 |
37 | def test_train_and_run_flow_runs_without_error(fake_llm_factory, patch_embedding_clustering):
38 | questions = [f"q{i}{i}" for i in range(10)]
39 | answers = [f"a{i}" for i in range(10)]
40 | llm = fake_llm_factory({"generate_sync": "Train/Run Sync"})
41 | # Removed temp=1.0
42 | cdw = CDWCoT(llm, pool_size=5, n_clusters=3, sample_size=2, lr=0.1)
43 | cdw.init_pool(questions, answers)
44 | cdw.train(val_split=0.3, epochs=1, patience=1)
45 | if not cdw.PC or cdw.cluster_centers is None: pytest.skip("Pool/Clusters empty")
46 | out = cdw.run("some sync test") # Changed from answer to run
47 | assert out == "Train/Run Sync"
48 | assert "some sync test" in llm.sync_calls[-1]["prompt"]
49 |
50 |
51 | @pytest.mark.asyncio
52 | async def test_train_async_and_run_async_flow(fake_llm_factory, patch_embedding_clustering):
53 | questions = [f"q{i}{i}" for i in range(10)]
54 | answers = [f"a{i}" for i in range(10)]
55 | llm = fake_llm_factory({"generate_async": "Train/Run Async"})
56 | # Removed temp=1.0
57 | cdw = CDWCoT(llm, pool_size=5, n_clusters=3, sample_size=2, lr=0.1, seed=42)
58 | await cdw.init_pool_async(questions, answers)
59 | await cdw.train_async(val_split=0.3, epochs=1, patience=1)
60 | if not cdw.PC or cdw.cluster_centers is None: pytest.skip("Pool/Clusters empty")
61 | out = await cdw.run_async("some async test") # Changed from answer_async to run_async
62 | assert out == "Train/Run Async"
63 | assert "some async test" in llm.async_calls[-1]["prompt"]
64 |
65 |
66 | def test_init_pool_raises_on_length_mismatch(fake_llm_factory):
67 | llm = fake_llm_factory()
68 | cdw = CDWCoT(llm)
69 | with pytest.raises(ValueError):
70 | cdw.init_pool(["q1"], ["a1", "a2"])
71 |
72 |
73 | @pytest.mark.asyncio
74 | async def test_init_pool_async_raises_on_length_mismatch(fake_llm_factory):
75 | llm = fake_llm_factory()
76 | cdw = CDWCoT(llm)
77 | with pytest.raises(ValueError):
78 | await cdw.init_pool_async(["q1"], ["a1", "a2"])
79 |
--------------------------------------------------------------------------------
/tests/test_graph_of_thoughts.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Tuple
2 |
3 | import pytest
4 |
5 | from cogitator import EvaluationResult, ExtractedAnswer, ThoughtExpansion
6 | from cogitator import GraphOfThoughts
7 |
8 | EXAMPLE_BASIC_GOO: List[Tuple[str, Dict]] = [
9 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated_step1', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}),
10 | ('Score', {'target_set': 'generated_step1', 'prompt_key': 'evaluate'}),
11 | ('KeepBest', {'N': 1, 'target_set': 'generated_step1', 'output_set': 'best_final_node'})
12 | ]
13 |
14 | @pytest.mark.asyncio
15 | async def test_run_async_returns_result_and_calls_prompts_text_format(
16 | fake_llm_factory, patch_embedding_clustering):
17 | fake_expansion_config_async = ThoughtExpansion(thoughts=["stepA_async"])
18 | fake_eval_config_async = EvaluationResult(score=9, justification="Good_async")
19 | llm = fake_llm_factory({
20 | "json_steps": fake_expansion_config_async,
21 | "json_eval": fake_eval_config_async,
22 | "responses_map": {
23 | "Based on the final reasoning": "RESULT_async_text"
24 | }
25 | })
26 |
27 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="text")
28 |
29 | test_goo: List[Tuple[str, Dict]] = [
30 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}),
31 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}),
32 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'})
33 | ]
34 |
35 | out = await got_instance.run_async("start_async?", graph_of_operations=test_goo)
36 | assert out == "RESULT_async_text"
37 |
38 | gen_op_call = next((c for c in llm.async_calls if
39 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "ThoughtExpansion"),
40 | None)
41 | score_op_call = next((c for c in llm.async_calls if
42 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "EvaluationResult"),
43 | None)
44 | final_answer_call = next((c for c in llm.async_calls if
45 | c["type"] == "generate_async" and c["prompt"].startswith("Based on the final reasoning")),
46 | None)
47 |
48 | assert gen_op_call is not None, "Async GenerateOp LLM call not found"
49 |     # The expansion prompt is expected to end with a JSON output instruction
50 | assert "JSON Output:" in gen_op_call["prompt"], "GenerateOp did not use expected prompt content"
51 | assert score_op_call is not None, "Async ScoreOp LLM call not found"
52 | assert "JSON Evaluation:" in score_op_call["prompt"], "ScoreOp did not use expected prompt content"
53 | assert final_answer_call is not None, "Async final answer generation call (text) not found"
54 |
55 | @pytest.mark.asyncio
56 | async def test_run_async_returns_result_and_calls_prompts_json_format(
57 | fake_llm_factory, patch_embedding_clustering):
58 | fake_expansion_config_async = ThoughtExpansion(thoughts=["stepA_async_json"])
59 | fake_eval_config_async = EvaluationResult(score=9, justification="Good_async_json")
60 | fake_final_answer_obj_async = ExtractedAnswer(final_answer="RESULT_async_json")
61 | llm = fake_llm_factory({
62 | "json_steps": fake_expansion_config_async,
63 | "json_eval": fake_eval_config_async,
64 | "json_answer": fake_final_answer_obj_async
65 | })
66 |
67 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="json")
68 |
69 | test_goo: List[Tuple[str, Dict]] = [
70 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}),
71 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}),
72 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'})
73 | ]
74 |
75 | out = await got_instance.run_async("start_async_json?", graph_of_operations=test_goo)
76 | assert out == "RESULT_async_json"
77 |
78 | gen_op_call = next((c for c in llm.async_calls if
79 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "ThoughtExpansion"),
80 | None)
81 | score_op_call = next((c for c in llm.async_calls if
82 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "EvaluationResult"),
83 | None)
84 | final_json_call = next((c for c in llm.async_calls if
85 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "ExtractedAnswer"),
86 | None)
87 |
88 | assert gen_op_call is not None, "Async GenerateOp LLM call not found"
89 | assert score_op_call is not None, "Async ScoreOp LLM call not found"
90 | assert final_json_call is not None, "Async final answer generation call (JSON) not found"
91 | assert "Based on the final reasoning" in final_json_call["prompt"], "Final prompt content mismatch"
92 |
93 | def test_run_returns_result_and_calls_prompts_text_format(
94 | fake_llm_factory, patch_embedding_clustering):
95 | fake_expansion_config = ThoughtExpansion(thoughts=["stepA_sync"])
96 | fake_eval_config = EvaluationResult(score=9, justification="Good_sync")
97 | llm = fake_llm_factory({
98 | "json_steps": fake_expansion_config,
99 | "json_eval": fake_eval_config,
100 | "responses_map": {
101 | "Based on the final reasoning": "RESULT_sync_text"
102 | }
103 | })
104 |
105 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="text")
106 |
107 | test_goo: List[Tuple[str, Dict]] = [
108 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}),
109 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}),
110 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'})
111 | ]
112 |
113 | try:
114 |         # The sync 'run' wraps the async implementation; the except clauses below
115 |         # skip this test when no usable event loop is available.
117 | out = got_instance.run("start?", graph_of_operations=test_goo)
118 | assert out == "RESULT_sync_text"
119 | assert len(llm.async_calls) > 0, "Expected async calls even in sync test due to wrapper"
120 |
121 | except NotImplementedError:
122 | pytest.skip("Synchronous 'run' not implemented for GraphOfThoughts.")
123 | except RuntimeError as e:
124 | if "event loop" in str(e).lower():
125 | pytest.skip(f"Skipping sync test due to asyncio event loop issue: {e}")
126 | else:
127 | raise
128 |
129 |
130 | def test_run_returns_result_and_calls_prompts_json_format(
131 | fake_llm_factory, patch_embedding_clustering):
132 | fake_expansion_config = ThoughtExpansion(thoughts=["stepA_sync_json"])
133 | fake_eval_config = EvaluationResult(score=9, justification="Good_sync_json")
134 | fake_final_answer_obj = ExtractedAnswer(final_answer="RESULT_sync_json")
135 | llm = fake_llm_factory({
136 | "json_steps": fake_expansion_config,
137 | "json_eval": fake_eval_config,
138 | "json_answer": fake_final_answer_obj
139 | })
140 |
141 | got_instance = GraphOfThoughts(llm=llm, final_answer_format="json")
142 |
143 | test_goo: List[Tuple[str, Dict]] = [
144 | ('Generate', {'k': 1, 'target_set': 'frontier', 'output_set': 'generated', 'prompt_key': 'expand', 'response_schema': ThoughtExpansion}),
145 | ('Score', {'target_set': 'generated', 'prompt_key': 'evaluate'}),
146 | ('KeepBest', {'N': 1, 'target_set': 'generated', 'output_set': 'best_node'})
147 | ]
148 |
149 | try:
150 |         # As above, the sync 'run' wraps the async implementation; skip below if
151 |         # no usable event loop is available.
153 | out = got_instance.run("start_json?", graph_of_operations=test_goo)
154 | assert out == "RESULT_sync_json"
155 | assert len(llm.async_calls) > 0, "Expected async calls even in sync test due to wrapper"
156 | assert any(c["type"] == "_generate_json_internal_async" and c["response_model"] == "ExtractedAnswer" for c in llm.async_calls), "Final async JSON call missing"
157 |
158 | except NotImplementedError:
159 | pytest.skip("Synchronous 'run' not implemented for GraphOfThoughts.")
160 | except RuntimeError as e:
161 | if "event loop" in str(e).lower():
162 | pytest.skip(f"Skipping sync test due to asyncio event loop issue: {e}")
163 | else:
164 | raise
165 |
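166 |
167 | # Sketch test that reuses the module-level EXAMPLE_BASIC_GOO above. It mirrors
168 | # test_run_async_returns_result_and_calls_prompts_text_format and assumes the final
169 | # answer is built from the last operation's output regardless of the output_set
170 | # names used in the graph of operations.
171 | @pytest.mark.asyncio
172 | async def test_run_async_with_example_basic_goo(fake_llm_factory, patch_embedding_clustering):
173 |     llm = fake_llm_factory({
174 |         "json_steps": ThoughtExpansion(thoughts=["stepA_basic"]),
175 |         "json_eval": EvaluationResult(score=8, justification="Okay_basic"),
176 |         "responses_map": {"Based on the final reasoning": "RESULT_basic_text"}
177 |     })
178 |     got_instance = GraphOfThoughts(llm=llm, final_answer_format="text")
179 |     out = await got_instance.run_async("start_basic?", graph_of_operations=EXAMPLE_BASIC_GOO)
180 |     assert out == "RESULT_basic_text"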
--------------------------------------------------------------------------------
/tests/test_least_to_most.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cogitator import LTMDecomposition, ExtractedAnswer
4 | from cogitator import LeastToMost
5 |
6 |
7 | def test_decompose_numbered(fake_llm_factory):
8 | llm = fake_llm_factory({
9 | "json_subquestions": LTMDecomposition(subquestions=["sub1", "sub2", "sub3"])
10 | })
11 | ltm = LeastToMost(llm, intermediate_output_format="json")
12 | subs = ltm.decompose("anything")
13 | assert subs == ["sub1", "sub2", "sub3"]
14 | assert len(llm.sync_calls) == 1
15 | assert llm.sync_calls[0]["type"] == "_generate_json_internal"
16 | assert llm.sync_calls[0]["response_model"] == "LTMDecomposition"
17 |
18 |
19 | def test_decompose_bullets(fake_llm_factory):
20 | llm = fake_llm_factory({
21 | "json_subquestions": LTMDecomposition(subquestions=["subA", "subB"])
22 | })
23 | ltm = LeastToMost(llm, intermediate_output_format="json")
24 | subs = ltm.decompose("anything")
25 | assert subs == ["subA", "subB"]
26 |
27 |
28 | def test_decompose_fallback_simulated(fake_llm_factory):
29 | llm = fake_llm_factory({
30 | "json_subquestions": LTMDecomposition(
31 | subquestions=["This is a sentence", "Another sentence"])
32 | })
33 | ltm = LeastToMost(llm, intermediate_output_format="json")
34 | subs = ltm.decompose("anything")
35 | assert subs == ["This is a sentence", "Another sentence"]
36 |
37 |
38 | def test_decompose_empty_list_raises(fake_llm_factory):
39 | llm = fake_llm_factory({
40 | "json_subquestions": LTMDecomposition(subquestions=[])
41 | })
42 | ltm = LeastToMost(llm, intermediate_output_format="json")
43 | with pytest.raises(ValueError, match="LLM returned empty subquestions list after validation."):
44 | ltm.decompose("anything")
45 |
46 |
47 | def test_decompose_invalid_json_raises(fake_llm_factory):
48 | llm = fake_llm_factory()
49 | llm.generate_json = lambda *args, **kwargs: (_ for _ in ()).throw(
50 | RuntimeError("generate_json failed... Last error: JSONDecodeError..."))
51 | ltm = LeastToMost(llm, intermediate_output_format="json")
52 | with pytest.raises(ValueError,
53 | match=r"Failed to decompose question due to LLM error: RuntimeError"):
54 | ltm.decompose("anything")
55 |
56 |
57 | def test_decompose_invalid_schema_raises(fake_llm_factory):
58 | llm = fake_llm_factory()
59 | llm.generate_json = lambda *args, **kwargs: (_ for _ in ()).throw(
60 | RuntimeError("generate_json failed... Last error: ValidationError..."))
61 | ltm = LeastToMost(llm, intermediate_output_format="json")
62 | with pytest.raises(ValueError,
63 | match=r"Failed to decompose question due to LLM error: RuntimeError"):
64 | ltm.decompose("anything")
65 |
66 |
67 | def test_max_subqs_trims(fake_llm_factory):
68 | many_subs = [f"sub{i + 1}" for i in range(20)]
69 | llm = fake_llm_factory({
70 | "json_subquestions": LTMDecomposition(subquestions=many_subs)
71 | })
72 | ltm = LeastToMost(llm, max_subqs=5, intermediate_output_format="json")
73 | subs = ltm.decompose("anything")
74 | assert len(subs) == 5
75 | assert subs == ["sub1", "sub2", "sub3", "sub4", "sub5"]
76 |
77 |
78 | @pytest.mark.asyncio
79 | async def test_decompose_async_numbered(fake_llm_factory):
80 | llm = fake_llm_factory({
81 | "json_subquestions": LTMDecomposition(subquestions=["sub1", "sub2", "sub3"])
82 | })
83 | ltm = LeastToMost(llm, intermediate_output_format="json")
84 | subs = await ltm.decompose_async("anything async")
85 | assert subs == ["sub1", "sub2", "sub3"]
86 | assert len(llm.async_calls) == 1
87 | assert llm.async_calls[0]["type"] == "_generate_json_internal_async"
88 | assert llm.async_calls[0]["response_model"] == "LTMDecomposition"
89 |
90 |
91 | @pytest.mark.asyncio
92 | async def test_decompose_async_empty_list_raises(fake_llm_factory):
93 | llm = fake_llm_factory({
94 | "json_subquestions": LTMDecomposition(subquestions=[])
95 | })
96 | ltm = LeastToMost(llm, intermediate_output_format="json")
97 | with pytest.raises(ValueError,
98 | match="Async LLM returned empty subquestions list after validation."):
99 | await ltm.decompose_async("anything async")
100 |
101 |
102 | @pytest.mark.asyncio
103 | async def test_decompose_async_invalid_json_raises(fake_llm_factory):
104 | llm = fake_llm_factory()
105 |
106 | async def mock_fail(*args, **kwargs):
107 | raise RuntimeError("generate_json_async failed... Last error: JSONDecodeError...")
108 |
109 | llm.generate_json_async = mock_fail
110 | ltm = LeastToMost(llm, intermediate_output_format="json")
111 | with pytest.raises(ValueError,
112 | match=r"Async decomposition failed due to LLM error: RuntimeError"):
113 | await ltm.decompose_async("anything async")
114 |
115 |
116 | @pytest.mark.asyncio
117 | async def test_decompose_async_invalid_schema_raises(fake_llm_factory):
118 | llm = fake_llm_factory()
119 |
120 | async def mock_fail(*args, **kwargs):
121 | raise RuntimeError("generate_json_async failed... Last error: ValidationError...")
122 |
123 | llm.generate_json_async = mock_fail
124 | ltm = LeastToMost(llm, intermediate_output_format="json")
125 | with pytest.raises(ValueError,
126 | match=r"Async decomposition failed due to LLM error: RuntimeError"):
127 | await ltm.decompose_async("anything async")
128 |
129 |
130 | # --- Run tests ---
131 |
132 | def test_run_integration_text_mode(fake_llm_factory):
133 | llm = fake_llm_factory({
134 | "json_subquestions": LTMDecomposition(subquestions=["dummy sub1", "dummy sub2"]),
135 | "sub_answer": "sub_answer_sync_text",
136 | "final_answer": "ANS_sync_text"
137 | })
138 | ltm = LeastToMost(llm, intermediate_output_format="text")
139 | out = ltm.run("question?")
140 | assert out == "ANS_sync_text"
141 | assert any(
142 | c["type"] == "_generate_json_internal" and c["response_model"] == "LTMDecomposition" for c
143 | in llm.sync_calls)
144 | solve_calls = [c for c in llm.sync_calls if
145 | c["type"] == "generate" and "Current Subquestion:" in c["prompt"]]
146 | final_call = next((c for c in llm.sync_calls if
147 | c["type"] == "generate" and c["prompt"].startswith(
148 | "Based on the following")), None)
149 | assert len(solve_calls) == 2
150 | assert final_call is not None
151 |
152 |
153 | def test_run_integration_json_mode(fake_llm_factory):
154 | subquestions = ["dummy sub1", "dummy sub2"]
155 | question = "question?"
156 | fake_sub_answer_obj = ExtractedAnswer(final_answer="sub_answer_sync_json")
157 | fake_final_answer_obj = ExtractedAnswer(final_answer="ANS_sync_json")
158 |
159 | solve_prompt_key_1 = "Current Subquestion: dummy sub1"
160 | solve_prompt_key_2 = "Current Subquestion: dummy sub2"
161 | # The final prompt starts with "Based on..." and contains the original question
162 | final_prompt_key = f"Original Main Question: {question}"
163 |
164 | llm = fake_llm_factory({
165 | "json_subquestions": LTMDecomposition(subquestions=subquestions),
166 | "responses_map": {
168 | solve_prompt_key_1: fake_sub_answer_obj,
169 | solve_prompt_key_2: fake_sub_answer_obj,
170 | final_prompt_key: fake_final_answer_obj
171 | }
172 | })
173 | ltm = LeastToMost(llm, intermediate_output_format="json")
174 | out = ltm.run(question)
175 | assert out == "ANS_sync_json"
176 |
177 | assert any(
178 | c["type"] == "_generate_json_internal" and c["response_model"] == "LTMDecomposition" for c
179 | in llm.sync_calls)
180 | solve_json_calls = [c for c in llm.sync_calls if
181 | c["type"] == "_generate_json_internal" and "Current Subquestion:" in c[
182 | "prompt"]]
183 | final_json_call = next((c for c in llm.sync_calls if
184 | c["type"] == "_generate_json_internal" and c["prompt"].startswith(
185 | "Based on the following")), None)
186 | assert len(solve_json_calls) == 2
187 | assert final_json_call is not None
188 |
189 |
190 | @pytest.mark.asyncio
191 | async def test_run_async_integration_text_mode(fake_llm_factory):
192 | llm = fake_llm_factory({
193 | "json_subquestions": LTMDecomposition(
194 | subquestions=["dummy sub1 async", "dummy sub2 async"]),
195 | "sub_answer": "sub_answer_async_text",
196 | "final_answer": "ANS_async_text"
197 | })
198 | ltm = LeastToMost(llm, intermediate_output_format="text")
199 | out = await ltm.run_async("question async?")
200 | assert out == "ANS_async_text"
201 | assert any(
202 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "LTMDecomposition"
203 | for c in llm.async_calls)
204 | solve_calls = [c for c in llm.async_calls if
205 | c["type"] == "generate_async" and "Current Subquestion:" in c["prompt"]]
206 | final_call = next((c for c in llm.async_calls if
207 | c["type"] == "generate_async" and c["prompt"].startswith(
208 | "Based on the following")), None)
209 | assert len(solve_calls) == 2
210 | assert final_call is not None
211 |
212 |
213 | @pytest.mark.asyncio
214 | async def test_run_async_integration_json_mode(fake_llm_factory):
215 | subquestions = ["dummy sub1 async", "dummy sub2 async"]
216 | question = "question async?"
217 | fake_sub_answer_obj_async = ExtractedAnswer(final_answer="sub_answer_async_json")
218 | fake_final_answer_obj_async = ExtractedAnswer(final_answer="ANS_async_json")
219 |
220 | solve_prompt_key_1 = "Current Subquestion: dummy sub1 async"
221 | solve_prompt_key_2 = "Current Subquestion: dummy sub2 async"
222 | final_prompt_key = f"Original Main Question: {question}" # Include the specific question
223 |
224 | llm = fake_llm_factory({
225 | "json_subquestions": LTMDecomposition(subquestions=subquestions),
226 | "responses_map": {
227 | solve_prompt_key_1: fake_sub_answer_obj_async,
228 | solve_prompt_key_2: fake_sub_answer_obj_async,
229 | final_prompt_key: fake_final_answer_obj_async
230 | }
231 | })
232 | ltm = LeastToMost(llm, intermediate_output_format="json")
233 | out = await ltm.run_async(question)
234 | assert out == "ANS_async_json"
235 |
236 | assert any(
237 | c["type"] == "_generate_json_internal_async" and c["response_model"] == "LTMDecomposition"
238 | for c in llm.async_calls)
239 | solve_json_calls = [c for c in llm.async_calls if
240 | c["type"] == "_generate_json_internal_async" and "Current Subquestion:" in
241 | c["prompt"]]
242 | final_json_call = next((c for c in llm.async_calls if
243 | c["type"] == "_generate_json_internal_async" and c["prompt"].startswith(
244 | "Based on the following")), None)
245 | assert len(solve_json_calls) == 2
246 | assert final_json_call is not None
247 |
248 |
249 | # --- Solve tests ---
250 | @pytest.mark.asyncio
251 | async def test_solve_async_calls_generate_async_text(fake_llm_factory):
252 | expected_sub_answer = "async_sub_answer_test_text"
253 | llm = fake_llm_factory({"sub_answer": expected_sub_answer})
254 | ltm = LeastToMost(llm, intermediate_output_format="text")
255 | solved = await ltm.solve_async("main q", ["sub1", "sub2"])
256 | assert solved == [("sub1", expected_sub_answer), ("sub2", expected_sub_answer)]
257 | assert len(llm.async_calls) == 2
258 | assert all(c["type"] == "generate_async" for c in llm.async_calls)
259 |
260 |
261 | @pytest.mark.asyncio
262 | async def test_solve_async_calls_generate_json_async(fake_llm_factory):
263 | expected_sub_answer_obj = ExtractedAnswer(final_answer="async_sub_answer_test_json")
264 | llm = fake_llm_factory({"json_answer": expected_sub_answer_obj})
265 | ltm = LeastToMost(llm, intermediate_output_format="json")
266 | solved = await ltm.solve_async("main q", ["sub1", "sub2"])
267 | assert solved == [("sub1", "async_sub_answer_test_json"),
268 | ("sub2", "async_sub_answer_test_json")]
269 | assert len(llm.async_calls) == 2
270 | assert all(c["type"] == "_generate_json_internal_async" for c in llm.async_calls)
271 | assert all(c["response_model"] == "ExtractedAnswer" for c in llm.async_calls)
272 |
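273 |
274 | # Sketch test mirroring test_max_subqs_trims for the async path; it assumes
275 | # decompose_async applies the same max_subqs cap as the sync decompose (this is
276 | # not asserted elsewhere in this file).
277 | @pytest.mark.asyncio
278 | async def test_decompose_async_trims_to_max_subqs(fake_llm_factory):
279 |     many_subs = [f"sub{i + 1}" for i in range(20)]
280 |     llm = fake_llm_factory({
281 |         "json_subquestions": LTMDecomposition(subquestions=many_subs)
282 |     })
283 |     ltm = LeastToMost(llm, max_subqs=5, intermediate_output_format="json")
284 |     subs = await ltm.decompose_async("anything async")
285 |     assert len(subs) == 5
286 |     assert subs[:3] == ["sub1", "sub2", "sub3"]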
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock, AsyncMock
2 |
3 | import pytest
4 | from pydantic import BaseModel, Field
5 |
6 | from cogitator import BaseLLM
7 | from cogitator import OllamaLLM
8 | from cogitator import OpenAILLM
9 |
10 |
11 | class DummySchema(BaseModel):
12 | name: str = Field(...)
13 | value: int = Field(..., gt=0)
14 |
15 |
16 | class SimpleSchema(BaseModel):
17 | answer: str
18 |
19 |
20 | @pytest.fixture
21 | def mock_openai_clients(mocker):
22 | mock_sync_client = MagicMock()
23 | mock_async_client = AsyncMock()
24 | mock_sync_client.chat.completions.create = MagicMock()
25 | mock_async_client.chat.completions.create = AsyncMock()
26 | mocker.patch("cogitator.model.openai.SyncOpenAI", return_value=mock_sync_client)
27 | mocker.patch("cogitator.model.openai.AsyncOpenAI", return_value=mock_async_client)
28 | return mock_sync_client, mock_async_client
29 |
30 |
31 | @pytest.fixture
32 | def mock_ollama_clients(mocker):
33 | mock_sync_client = MagicMock()
34 | mock_async_client = AsyncMock()
35 | mock_sync_client.chat = MagicMock()
36 | mock_async_client.chat = AsyncMock()
37 | mock_sync_client.list = MagicMock(return_value={'models': []})
38 | mocker.patch("cogitator.model.ollama.Client", new_callable=MagicMock, return_value=mock_sync_client)
39 | mocker.patch("cogitator.model.ollama.AsyncClient", return_value=mock_async_client)
40 | return mock_sync_client, mock_async_client
41 |
42 |
43 | def test_base_llm_abstract_methods():
44 | class ConcreteLLM(BaseLLM):
45 | def generate(self, prompt: str, **kwargs) -> str: return ""
46 |
47 | async def generate_async(self, prompt: str, **kwargs) -> str: return ""
48 |
49 | def generate_stream(self, prompt: str, **kwargs): yield ""
50 |
51 | async def generate_stream_async(self, prompt: str, **kwargs): yield ""
52 |
53 | def _generate_json_internal(self, prompt: str, response_model, **kwargs): return "{}", None
54 |
55 | async def _generate_json_internal_async(self, prompt: str, response_model,
56 | **kwargs): return "{}", None
57 |
58 | instance = ConcreteLLM()
59 | assert hasattr(instance, "generate")
60 |
61 |
62 | def test_extract_json_block(mock_ollama_clients):
63 | llm = OllamaLLM(model="dummy")
64 |
65 | text_fence = "```json\n{\"key\": \"value\"}\n```"
66 | text_fence_no_lang = "```\n{\"key\": \"value2\"}\n```"
67 | text_braces = "Some text {\"key\": \"value3\"} more text"
68 | text_brackets = "Some text [{\"key\": \"value4\"}] more text"
69 | text_both = "Text {\"a\":1} then [\"b\"] end"
70 | text_nested = "Text {\"a\": {\"b\": 1}} end"
71 | text_no_json = "Just plain text"
72 | text_empty = ""
73 |
74 | assert llm._extract_json_block(text_fence) == "{\"key\": \"value\"}"
75 | assert llm._extract_json_block(text_fence_no_lang) == "{\"key\": \"value2\"}"
76 | assert llm._extract_json_block(text_braces) == "{\"key\": \"value3\"}"
77 | assert llm._extract_json_block(text_brackets) == "[{\"key\": \"value4\"}]"
78 | assert llm._extract_json_block(text_both) == "{\"a\":1}"
79 | assert llm._extract_json_block(text_nested) == "{\"a\": {\"b\": 1}}"
80 | assert llm._extract_json_block(text_no_json) == "Just plain text"
81 | assert llm._extract_json_block(text_empty) == ""
82 |
83 |
84 | @pytest.mark.parametrize(
85 | "model_name, expected_mode, expect_format_present, expect_additional_props", [
86 | ("gpt-4o", "json_schema", True, False),
87 | ("gpt-4o-mini", "json_schema", True, False),
88 | ("gpt-4-turbo", "json_object", True, None),
89 | ("gpt-3.5-turbo-1106", "json_object", True, None),
90 | ("gpt-3.5-turbo-0613", "json_schema", True, False),
91 | ("unknown-model", "json_schema", True, False),
92 | ])
93 | def test_openai_prepare_api_params_json_modes(mock_openai_clients, model_name, expected_mode,
94 | expect_format_present, expect_additional_props):
95 | llm = OpenAILLM(api_key="dummy_key", model=model_name)
96 | params, mode = llm._prepare_api_params(is_json_mode=True, response_schema=DummySchema)
97 |
98 | if expect_format_present:
99 | assert "response_format" in params
100 | rf = params["response_format"]
101 | assert rf["type"] == expected_mode
102 | if expected_mode == "json_schema":
103 | assert "json_schema" in rf
104 | assert rf["json_schema"]["name"] == "DummySchema"
105 | schema = rf["json_schema"]["schema"]
106 | assert schema.get("additionalProperties") is expect_additional_props
107 | elif expected_mode == "json_object":
108 | assert "json_schema" not in rf
109 | assert expect_additional_props is None
110 | else:
111 | assert "response_format" not in params
112 | assert expect_additional_props is None
113 | if expect_format_present:
114 | assert mode == expected_mode
115 | else:
116 | assert mode is None
117 |
118 |
119 | def test_openai_prepare_api_params_no_schema_json_mode(mock_openai_clients):
120 | llm_json = OpenAILLM(api_key="d", model="gpt-4-turbo")
121 | params, mode = llm_json._prepare_api_params(is_json_mode=True, response_schema=None)
122 | assert params["response_format"]["type"] == "json_object"
123 | assert mode == "json_object"
124 |
125 | llm_no_json = OpenAILLM(api_key="d", model="gpt-3.5-turbo-0613")
126 | params, mode = llm_no_json._prepare_api_params(is_json_mode=True, response_schema=None)
127 | assert "response_format" not in params
128 | assert mode is None
129 |
130 |
131 | def test_openai_prepare_api_params_schema_generation_fails(mock_openai_clients, mocker):
132 | llm = OpenAILLM(api_key="d", model="gpt-4o")
133 | mocker.patch.object(DummySchema, "model_json_schema", side_effect=TypeError("Schema fail"))
134 | params, mode = llm._prepare_api_params(is_json_mode=True, response_schema=DummySchema)
135 | assert params["response_format"]["type"] == "json_object"
136 | assert mode == "json_object"
137 |
138 | llm_no_fallback = OpenAILLM(api_key="d", model="gpt-3.5-turbo-0613")
139 | mocker.patch.object(DummySchema, "model_json_schema",
140 | side_effect=TypeError("Schema fail again"))
141 | params_no_fallback, mode_no_fallback = llm_no_fallback._prepare_api_params(is_json_mode=True,
142 | response_schema=DummySchema)
143 | assert "response_format" not in params_no_fallback
144 | assert mode_no_fallback is None
145 |
146 |
147 | def test_ollama_init_success(mock_ollama_clients):
148 | mock_sync, mock_async = mock_ollama_clients
149 | llm = OllamaLLM(model="ollama-test", ollama_host="http://testhost:11434")
150 | assert llm.model == "ollama-test"
151 | assert llm.host == "http://testhost:11434"
152 | assert llm._client == mock_sync
153 | assert llm._async_client == mock_async
154 |
155 |     # Verify the patched Client/AsyncClient constructors received the configured host
157 | from cogitator.model.ollama import Client, AsyncClient
158 | Client.assert_called_once_with(host="http://testhost:11434")
159 | AsyncClient.assert_called_once_with(host="http://testhost:11434")
160 |
161 |
162 | def test_ollama_strip_content(mock_ollama_clients):
163 | llm = OllamaLLM(model="dummy")
164 |
165 | response_dict = {"message": {"content": " Strip Me! "}}
166 | response_obj = MagicMock(message=MagicMock(content=" Strip Me Too! "))
167 | response_bad_obj = MagicMock(message=None)
168 | response_bad_dict = {"message": None}
169 | response_no_content = {"message": {"role": "assistant"}}
170 |
171 | assert llm._strip_content(response_dict) == "Strip Me!"
172 | assert llm._strip_content(response_obj) == "Strip Me Too!"
173 | assert llm._strip_content(response_bad_obj) == ""
174 | assert llm._strip_content(response_bad_dict) == ""
175 | assert llm._strip_content(response_no_content) == ""
176 | assert llm._strip_content(None) == ""
177 | assert llm._strip_content("string") == ""
178 |
179 |
180 | def test_ollama_prepare_options(mock_ollama_clients):
181 | llm = OllamaLLM(model="d", temperature=0.5, max_tokens=100, stop=["\n"], seed=1)
182 |
183 | opts = llm._prepare_options(temperature=0.8, seed=None, stop=["stop"], extra_param=True)
184 | assert opts == {"temperature": 0.8, "num_predict": 100, "stop": ["stop"], "extra_param": True}
185 |
186 | opts_defaults = llm._prepare_options()
187 | assert opts_defaults == {"temperature": 0.5, "num_predict": 100, "seed": 1, "stop": ["\n"]}
188 |
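189 |
190 | # Sketch test extending test_extract_json_block above; it assumes a fenced JSON
191 | # array is returned verbatim, mirroring the fenced-object cases already covered.
192 | def test_extract_json_block_fenced_array(mock_ollama_clients):
193 |     llm = OllamaLLM(model="dummy")
194 |     text_fenced_array = "```json\n[{\"key\": \"v1\"}, {\"key\": \"v2\"}]\n```"
195 |     assert llm._extract_json_block(text_fenced_array) == "[{\"key\": \"v1\"}, {\"key\": \"v2\"}]"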
--------------------------------------------------------------------------------
/tests/test_sc_cot.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import pytest
4 |
5 | from cogitator import ExtractedAnswer
6 | from cogitator import SelfConsistency
7 |
8 |
9 | def test_extract_answer_heuristic_with_equal_sign():
10 | sc = SelfConsistency(llm=None)
11 | assert sc._extract_answer_heuristic("…\n4+1 = 5") == "5"
12 | assert sc._extract_answer_heuristic("Result=42.") == "42"
13 |
14 |
15 | def test_extract_answer_heuristic_with_prefixes():
16 | sc = SelfConsistency(llm=None)
17 | assert sc._extract_answer_heuristic("Therefore, the answer is 7.") == "7"
18 | assert sc._extract_answer_heuristic("Answer: 99") == "99"
19 | assert sc._extract_answer_heuristic("Ans 13") == "13"
20 | assert sc._extract_answer_heuristic("### 123") == "123"
21 | assert sc._extract_answer_heuristic("Final Answer: foo bar") == "foo bar"
22 |
23 |
24 | def test_extract_answer_heuristic_last_line_fallback():
25 | sc = SelfConsistency(llm=None)
26 | cot = "Step1\nSome thought\nFinal numerical answer 100"
27 | assert sc._extract_answer_heuristic(cot) == "Final numerical answer 100"
28 |
29 |
30 | def test_extract_answer_heuristic_numeric_last_line():
31 | sc = SelfConsistency(llm=None)
32 | cot = "Step1\nSome thought\n12345"
33 | assert sc._extract_answer_heuristic(cot) == "12345"
34 | cot_dollar = "Step1\nSome thought\n$56.78"
35 | assert sc._extract_answer_heuristic(cot_dollar) == "56.78"
36 |
37 |
38 | def test_run_majority_vote_heuristic(fake_llm_factory):
39 | responses = ["...\nAns 10", "...\nFinal Answer: 10", "...\nResult=5"]
40 | llm = fake_llm_factory({"generate_sync": responses})
41 | sc = SelfConsistency(llm=llm, n_samples=3) # Defaults to heuristic
42 | out = sc.run("dummy prompt")
43 | assert out == "10"
44 | assert len(llm.sync_calls) == 3
45 |
46 |
47 | def test_extract_answer_json_path(fake_llm_factory):
48 | expected_answer = "AnswerFromJSON"
49 | llm = fake_llm_factory({
50 | "json_answer": ExtractedAnswer(final_answer=expected_answer)
51 | })
52 | sc = SelfConsistency(llm=llm, internal_extraction_format="json")
53 | extracted = sc.extract_answer("Some reasoning text...")
54 | assert extracted == expected_answer
55 | assert len(llm.sync_calls) == 1
56 | assert llm.sync_calls[0]["type"] == "_generate_json_internal"
57 | assert llm.sync_calls[0]["response_model"] == "ExtractedAnswer"
58 |
59 |
60 | def test_run_majority_vote_json(fake_llm_factory):
61 | cot_responses = ["CoT leading to 10", "Another CoT leading to 10", "CoT leading to 5"]
62 | extraction_responses = [
63 | ExtractedAnswer(final_answer="10"),
64 | ExtractedAnswer(final_answer="10"),
65 | ExtractedAnswer(final_answer="5")
66 | ]
67 | # Configure mock: generate CoTs, then respond to extraction prompts
68 | llm = fake_llm_factory({
69 | "generate_sync": cot_responses,
70 | "json_answer": extraction_responses # Mock should pick this for JSON extraction calls
71 | })
72 | sc = SelfConsistency(llm=llm, n_samples=3, internal_extraction_format="json")
73 | out = sc.run("dummy prompt for json run")
74 | assert out == "10" # Verify majority vote result
75 | assert len(llm.sync_calls) == 6 # 3 generate + 3 json_internal
76 | assert sum(1 for c in llm.sync_calls if c["type"] == "generate") == 3
77 | assert sum(1 for c in llm.sync_calls if c["type"] == "_generate_json_internal") == 3
78 |
79 |
80 | def test_run_stream_not_implemented(fake_llm_factory):
81 | llm = fake_llm_factory()
82 | sc = SelfConsistency(llm=llm)
83 | with pytest.raises(NotImplementedError):
84 | next(sc.run_stream("anything"))
85 |
86 |
87 | @pytest.mark.asyncio
88 | async def test_extract_answer_async_heuristic():
89 | sc = SelfConsistency(llm=None)
90 | assert await sc.extract_answer_async("...\nFinal Answer: ABC_async") == "ABC_async"
91 | assert await sc.extract_answer_async("...\nResult=55_async") == "55_async"
92 |
93 |
94 | @pytest.mark.asyncio
95 | async def test_run_async_majority_vote_heuristic(fake_llm_factory):
96 | responses = ["...\nAns Async10", "...\nFinal Answer: Async10", "...\nResult=Async5"]
97 | llm = fake_llm_factory({"generate_async": responses})
98 | sc = SelfConsistency(llm=llm, n_samples=3)
99 | out = await sc.run_async("dummy async prompt")
100 | assert out == "Async10"
101 | assert len(llm.async_calls) == 3
102 |
103 |
104 | @pytest.mark.asyncio
105 | async def test_extract_answer_json_async_path(fake_llm_factory):
106 | expected_answer = "AsyncAnswerFromJSON"
107 | llm = fake_llm_factory({
108 | "json_answer": ExtractedAnswer(final_answer=expected_answer)
109 | })
110 | sc = SelfConsistency(llm=llm, internal_extraction_format="json")
111 | extracted = await sc.extract_answer_async("Some async reasoning text...")
112 | assert extracted == expected_answer
113 | assert len(llm.async_calls) == 1
114 | assert llm.async_calls[0]["type"] == "_generate_json_internal_async"
115 | assert llm.async_calls[0]["response_model"] == "ExtractedAnswer"
116 |
117 |
118 | @pytest.mark.asyncio
119 | async def test_run_async_majority_vote_json(fake_llm_factory):
120 | cot_responses_async = ["Async CoT 10", "Async CoT 10 again", "Async CoT 5"]
121 | extraction_responses_async = [
122 | ExtractedAnswer(final_answer="10"),
123 | ExtractedAnswer(final_answer="10"),
124 | ExtractedAnswer(final_answer="5")
125 | ]
126 | llm = fake_llm_factory({
127 | "generate_async": cot_responses_async,
128 | "json_answer": extraction_responses_async
129 | })
130 | sc = SelfConsistency(llm=llm, n_samples=3, internal_extraction_format="json")
131 | out = await sc.run_async("dummy async json run prompt")
132 | assert out == "10" # Verify majority vote result
133 | assert len(llm.async_calls) == 6 # 3 generate_async + 3 json_internal_async
134 | assert sum(1 for c in llm.async_calls if c["type"] == "generate_async") == 3
135 | assert sum(1 for c in llm.async_calls if c["type"] == "_generate_json_internal_async") == 3
136 |
137 |
138 | @pytest.mark.asyncio
139 | async def test_run_async_with_semaphore(fake_llm_factory):
140 | responses = ["…\nAnswer: S1", "…\nAnswer: S2", "…\nAnswer: S1"]
141 | llm = fake_llm_factory({"generate_async": responses})
142 | sc = SelfConsistency(llm=llm, n_samples=3)
143 | semaphore = asyncio.Semaphore(2)
144 | out = await sc.run_async("dummy async prompt semaphore", semaphore=semaphore)
145 | assert out == "S1"
146 | assert len(llm.async_calls) == 3
147 |
148 |
149 | @pytest.mark.asyncio
150 | async def test_run_stream_async_not_implemented(fake_llm_factory):
151 | llm = fake_llm_factory()
152 | sc = SelfConsistency(llm=llm)
153 | with pytest.raises(NotImplementedError):
155 |         await sc.run_stream_async("anything")
156 |
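157 |
158 | # Sketch test for the public extract_answer entry point; it assumes the default
159 | # internal_extraction_format is the heuristic path exercised by the
160 | # _extract_answer_heuristic tests above, so no LLM instance is required.
161 | def test_extract_answer_default_heuristic_path():
162 |     sc = SelfConsistency(llm=None)
163 |     assert sc.extract_answer("...\nFinal Answer: 7") == "7"
164 |     assert sc.extract_answer("...\nResult=42.") == "42"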
--------------------------------------------------------------------------------
/tests/test_tree_of_thoughts.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cogitator import ThoughtExpansion, EvaluationResult
4 | from cogitator import TreeOfThoughts
5 |
6 |
7 | def test_run_returns_final_and_calls_prompts(fake_llm_factory):
8 | fake_expansion = ThoughtExpansion(thoughts=["step1_sync"])
9 | fake_eval = EvaluationResult(score=8, justification="Okay_sync")
10 | llm = fake_llm_factory({
11 | "json_steps": fake_expansion,
12 | "json_eval": fake_eval,
13 | "final_answer": "FINAL_sync"
14 | })
15 | tot = TreeOfThoughts(llm, max_depth=1, num_branches=1, sims=1, c_puct=1.0)
16 | out = tot.run("test?")
17 |
18 | assert out == "FINAL_sync"
19 |
20 | expand_call = next((c for c in llm.sync_calls if
21 | c["type"] == "_generate_json_internal" and "JSON Output:" in c[
22 | "prompt"] and "thoughts" in c["prompt"]),
23 | None)
24 | eval_call = next((c for c in llm.sync_calls if
25 | c["type"] == "_generate_json_internal" and "JSON Evaluation:" in c["prompt"]),
26 | None)
27 | final_call = next((c for c in llm.sync_calls if
28 | c["type"] == "generate" and (
29 | "Given reasoning steps" in c["prompt"] or c["prompt"].startswith(
30 | "Answer the question:"))),
31 | None)
32 |
33 | assert expand_call is not None, "Expansion call not found"
34 | assert expand_call["response_model"] == "ThoughtExpansion"
35 | assert eval_call is not None, "Evaluation call not found"
36 | assert eval_call["response_model"] == "EvaluationResult"
37 | assert final_call is not None, "Final answer generation call not found"
38 |
39 |
40 | @pytest.mark.asyncio
41 | async def test_run_async_returns_final_and_calls_prompts(fake_llm_factory):
42 | fake_expansion_async = ThoughtExpansion(thoughts=["step1_async"])
43 | fake_eval_async = EvaluationResult(score=8, justification="Okay_async")
44 | llm = fake_llm_factory({
45 | "json_steps": fake_expansion_async,
46 | "json_eval": fake_eval_async,
47 | "final_answer": "FINAL_async"
48 | })
49 | tot = TreeOfThoughts(llm, max_depth=1, num_branches=1, sims=1, c_puct=1.0)
50 | out = await tot.run_async("test_async?")
51 |
52 | assert out == "FINAL_async"
53 |
54 | expand_call = next((c for c in llm.async_calls if
55 | c["type"] == "_generate_json_internal_async" and "JSON Output:" in c[
56 | "prompt"] and "thoughts" in c["prompt"]),
57 | None)
58 | eval_call = next((c for c in llm.async_calls if
59 | c["type"] == "_generate_json_internal_async" and "JSON Evaluation:" in c[
60 | "prompt"]),
61 | None)
62 | final_call = next((c for c in llm.async_calls if
63 | c["type"] == "generate_async" and (
64 | "Given reasoning steps" in c["prompt"] or c["prompt"].startswith(
65 | "Answer the question:"))),
66 | None)
67 |
68 | assert expand_call is not None, "Async expansion call not found"
69 | assert expand_call["response_model"] == "ThoughtExpansion"
70 | assert eval_call is not None, "Async evaluation call not found"
71 | assert eval_call["response_model"] == "EvaluationResult"
72 | assert final_call is not None, "Async final answer generation call not found"
73 |
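74 |
75 | # Illustrative usage sketch (kept as a comment, not executed): wiring TreeOfThoughts
76 | # to a real backend. The OpenAILLM backend and model name are assumptions for the
77 | # example; only the constructor arguments and run() exercised above are relied on.
78 | #
79 | #   llm = OpenAILLM(api_key="sk-...", model="gpt-4o-mini")
80 | #   tot = TreeOfThoughts(llm, max_depth=2, num_branches=3, sims=4, c_puct=1.0)
81 | #   print(tot.run("If a train travels 60 km in 45 minutes, what is its average speed?"))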
--------------------------------------------------------------------------------