├── .cursorrules ├── .github └── workflows │ └── codeql.yml ├── .gitignore ├── .replit ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── build-dev.sh ├── build.sh ├── build_e2b_image.sh ├── e2b.Dockerfile ├── eval_requirements.txt ├── poetry.toml ├── pyproject.toml ├── replit.nix ├── requirements.txt ├── run.sh ├── src ├── __init__.py └── wandbot │ ├── __init__.py │ ├── api │ ├── __init__.py │ ├── app.py │ ├── client.py │ └── routers │ │ ├── __init__.py │ │ ├── chat.py │ │ ├── database.py │ │ └── retrieve.py │ ├── apps │ ├── __init__.py │ ├── discord │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── config.py │ ├── slack │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── config.py │ │ └── formatter.py │ └── utils.py │ ├── chat │ ├── __init__.py │ ├── chat.py │ ├── rag.py │ ├── schemas.py │ └── utils.py │ ├── configs │ ├── __init__.py │ ├── app_config.py │ ├── chat_config.py │ ├── database_config.py │ ├── ingestion_config.py │ └── vector_store_config.py │ ├── database │ ├── __init__.py │ ├── client.py │ ├── config.py │ ├── database.py │ ├── models.py │ └── schemas.py │ ├── evaluation │ ├── README.md │ ├── __init__.py │ ├── eval.py │ ├── eval_config.py │ ├── eval_metrics │ │ └── correctness.py │ └── utils │ │ ├── jp_evaluation_dataprep.py │ │ ├── jp_evaluation_dataupload.py │ │ ├── log_data.py │ │ └── utils.py │ ├── ingestion │ ├── __init__.py │ ├── __main__.py │ ├── prepare_data.py │ ├── preprocess_data.py │ ├── preprocessors │ │ ├── markdown.py │ │ └── source_code.py │ ├── run_ingestion_config.py │ ├── utils.py │ └── vectorstore_and_report.py │ ├── models │ ├── embedding.py │ └── llm.py │ ├── rag │ ├── __init__.py │ ├── query_handler.py │ ├── response_synthesis.py │ ├── retrieval.py │ └── utils.py │ ├── retriever │ ├── __init__.py │ ├── base.py │ ├── chroma.py │ ├── mmr.py │ ├── utils.py │ └── web_search.py │ ├── schema │ ├── api_status.py │ ├── document.py │ └── retrieval.py │ └── utils.py ├── tests ├── conftest.py ├── evaluation │ ├── test_correctness.py │ ├── test_eval.py │ └── test_eval_config.py ├── test_config.py ├── test_embedding.py ├── test_error_propagation.py ├── test_llm.py ├── test_model_config.py ├── test_query_handler.py ├── test_response_synthesis.py └── test_retrieval.py └── uv.lock /.cursorrules: -------------------------------------------------------------------------------- 1 | When running tests in this codebase: 2 | 3 | # Testing 4 | 1. Use the following pytest flags to prevent early exit issues and ensure complete test output: 5 | ```bash 6 | python -m pytest tests/ -v --tb=short --capture=tee-sys 7 | ``` 8 | 9 | 2. These flags help in the following ways: 10 | - `-v`: Verbose output 11 | - `--tb=short`: Short traceback format 12 | - `--capture=tee-sys`: Proper output capture that prevents early termination 13 | 14 | 3. This is particularly important for async tests and tests involving API calls or event loops. 15 | 16 | 4. If you need to debug a specific test, you can run it in isolation: 17 | ```bash 18 | python -m pytest tests/path_to_test.py::test_name -v --tb=short --capture=tee-sys 19 | ``` 20 | 21 | Remember to use these flags when running tests to ensure reliable test execution and complete output. 22 | 23 | # Using Weave to analyze logged inputs and outputs 24 | 25 | The Weave api can be used to analyze logged inputs and outputs. Here is an example of iterating over the 26 | input documents to this call and extracting the ids. 27 | 28 | Search the Weave documentation for more information on how to use the Weave api. 
29 | 30 | ```python 31 | import weave 32 | client = weave.init("wandbot/wandbot-dev") 33 | candidate_call = client.get_call("0194b427-ba78-77f3-9989-222419262817") 34 | final_candidate_ids = [doc.metadata["id"] for doc in candidate_call.inputs["inputs"].documents] 35 | ``` -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '29 23 * * 6' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Use only 'java' to analyze code written in Java, Kotlin or both 38 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 39 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 40 | 41 | steps: 42 | - name: Checkout repository 43 | uses: actions/checkout@v3 44 | 45 | # Initializes the CodeQL tools for scanning. 46 | - name: Initialize CodeQL 47 | uses: github/codeql-action/init@v2 48 | with: 49 | languages: ${{ matrix.language }} 50 | # If you wish to specify custom queries, you can do so here or in a config file. 51 | # By default, queries listed here will override any specified in a config file. 52 | # Prefix the list here with "+" to use these queries and those in the config file. 53 | 54 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 55 | # queries: security-extended,security-and-quality 56 | 57 | 58 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 59 | # If this step fails, then you should remove it and run the build manually (see below) 60 | - name: Autobuild 61 | uses: github/codeql-action/autobuild@v2 62 | 63 | # ℹ️ Command-line programs to run using the OS shell. 64 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 65 | 66 | # If the Autobuild fails above, remove it and uncomment the following three lines. 67 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
68 | 69 | # - run: | 70 | # echo "Run, Build Application using script" 71 | # ./location_of_script_within_repo/buildscript.sh 72 | 73 | - name: Perform CodeQL Analysis 74 | uses: github/codeql-action/analyze@v2 75 | with: 76 | category: "/language:${{matrix.language}}" 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *. 4 | *.py[cod] 5 | *$py.class 6 | 7 | temp_index/ 8 | e2b* 9 | 10 | # C extensions 11 | *.so 12 | 13 | testing_*.py 14 | testing_*.ipynb 15 | temp_* 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | .pythonlibs/ 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | wandbot_venv/ 117 | 3-10_env/ 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | wandb/ 143 | artifacts/ 144 | data/ 145 | .idea/ 146 | .aider -------------------------------------------------------------------------------- /.replit: -------------------------------------------------------------------------------- 1 | entrypoint = "main.py" 2 | modules = ["python-3.12"] 3 | 4 | [nix] 5 | channel = "stable-24_05" 6 | 7 | [unitTest] 8 | language = "python3" 9 | 10 | [gitHubImport] 11 | requiredFiles = [".replit", "replit.nix"] 12 | 13 | [deployment] 14 | run = ["python3", "main.py"] 15 | deploymentTarget = "cloudrun" 16 | 17 | [[ports]] 18 | localPort = 8000 19 | externalPort = 80 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to the W&B Q&A Chatbot 2 | 3 | We're excited that you're considering contributing to the W&B Q&A Chatbot! Your contributions will help improve the chatbot's performance and user experience. Please read the following guidelines to ensure a smooth contribution process. 4 | 5 | ## Getting Started 6 | 7 | 1. Fork the repository on GitHub. 8 | 2. Clone the forked repository to your local machine. 9 | 3. Set up a virtual environment and install the required dependencies. 10 | 11 | ## How to Contribute 12 | 13 | 1. **Bug Reports**: If you find a bug, please open an issue on GitHub with a clear description of the problem, and include any relevant logs, screenshots, or code samples. 14 | 15 | 2. **Feature Requests**: If you have an idea for a new feature or improvement, please open an issue on GitHub with a detailed explanation of the feature, its benefits, and any proposed implementation details. 16 | 17 | 3. **Code Contributions**: If you'd like to contribute code directly, follow these steps: 18 | 19 | - Make sure you've set up the development environment as described above. 20 | - Create a new branch for your feature or bugfix. Use a descriptive branch name, such as `feature/new-feature` or `bugfix/issue-123`. 21 | - Make your changes, following the existing code style and conventions. 22 | - Add tests for your changes to ensure they work correctly and maintain compatibility with existing code. 23 | - Run tests and ensure they pass. 24 | - Update the documentation as necessary to reflect your changes. 25 | - Commit your changes with a clear and concise commit message. 26 | - Push your changes to your fork on GitHub. 27 | - Create a pull request from your fork to the main repository. In the pull request description, provide an overview of your changes, any relevant issue numbers, and a summary of the testing you've performed. 28 | - Address any feedback or requested changes from the project maintainers. 29 | 30 | ## Code of Conduct 31 | 32 | Please be respectful and considerate of other contributors. We are committed to fostering a welcoming and inclusive community. 
Harassment, discrimination, and offensive behavior will not be tolerated. By participating in this project, you agree to adhere to these principles. 33 | 34 | ## Contact 35 | 36 | If you have any questions, concerns, or need assistance, please reach out to the project maintainers through GitHub or the official communication channels. 37 | 38 | Thank you for your interest in contributing to the W&B Q&A Chatbot! 39 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Reporting a Vulnerability 4 | 5 | Please report all vulnerabilities to security@wandb.com. 6 | -------------------------------------------------------------------------------- /build-dev.sh: -------------------------------------------------------------------------------- 1 | pip install fasttext && \ 2 | poetry install --all-extras && \ 3 | pip install protobuf==3.19.6 && \ 4 | poetry build && \ 5 | mkdir -p ./data/cache -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | echo "Running build.sh" 2 | set -x # Enable command echo 3 | set -e # Exit on error 4 | 5 | # # Debug disk usage 6 | # du -sh . 7 | # top_usage=$(du -ah . | sort -rh | head -n 20) 8 | # current_dir_usage=$(du -sm . | awk '{print $1}') 9 | # echo -e "Current directory usage: ${current_dir_usage}M" 10 | # echo -e "Top files/dirs usage: ${top_usage}\n" 11 | 12 | # Find libstdc++ to use 13 | for dir in /nix/store/*-gcc-*/lib64 /nix/store/*-stdenv-*/lib /nix/store/*-libstdc++*/lib; do 14 | echo "Checking directory: $dir" # Add this line for debugging 15 | if [ -f "$dir/libstdc++.so.6" ]; then 16 | export LD_LIBRARY_PATH="$dir:$LD_LIBRARY_PATH" 17 | echo "Found libstdc++.so.6 in $dir" 18 | break 19 | fi 20 | done 21 | 22 | # Create virtualenv & set up 23 | rm -rf .venv 24 | 25 | # Use uv for faster installs 26 | 27 | pip install --user pip uv --upgrade 28 | # pip install --no-user pip uv --upgrade 29 | 30 | # python3.12 -m venv wandbot_venv --clear 31 | # source wandbot_venv/bin/activate 32 | uv venv --python python3.12 33 | source .pythonlibs/bin/activate 34 | 35 | # export VIRTUAL_ENV=wandbot_venv 36 | # export PATH="$VIRTUAL_ENV/bin:$PATH" 37 | # export PYTHONPATH="$(pwd)/src:$PYTHONPATH" 38 | # Only set a narrow python path, excludes numpy 1.24 39 | # export PYTHONPATH=/home/runner/workspace/src:/home/runner/workspace/wandbot_venv/lib/python3.12/site-packages 40 | export PYTHONPATH=/home/runner/workspace/src:/home/runner/workspace/.pythonlibs/lib/python3.12/site-packages 41 | 42 | uv pip install "numpy>=2.2.0" --force-reinstall 43 | 44 | # Clear any existing installations that might conflict 45 | rm -rf $VIRTUAL_ENV/lib/python*/site-packages/typing_extensions* 46 | rm -rf $VIRTUAL_ENV/lib/python*/site-packages/pydantic* 47 | rm -rf $VIRTUAL_ENV/lib/python*/site-packages/fastapi* 48 | 49 | # Install dependencies 50 | uv pip install -r requirements.txt --no-cache 51 | 52 | # Re-install problematic package 53 | uv pip install --no-cache-dir --force-reinstall typing_extensions==4.12.2 54 | 55 | # Install app 56 | uv pip install . 
--no-deps 57 | 58 | # Check if the package is installed correctly 59 | python -c "import wandbot; print('Wandbot package installed successfully')" 60 | 61 | # Free up disk space 62 | pip cache purge 63 | 64 | mkdir -p ./data/cache 65 | 66 | # Debug information 67 | echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" 68 | ls -la $LIBSTDCXX_DIR/libstdc++.so* || true 69 | ldd $VIRTUAL_ENV/lib/python*/site-packages/pandas/_libs/*.so || true 70 | 71 | # Debug disk usage 72 | du -sh . 73 | top_usage=$(du -ah . | sort -rh | head -n 20) 74 | current_disk_usage=$(du -sm . | awk '{print $1}') 75 | echo -e "Current directory usage: ${current_disk_usage}M" 76 | echo -e "Top files/dirs usage: ${top_usage}\n" 77 | increment=$((current_disk_usage - ${initial_disk_usage:-0})) 78 | echo -e "Disk usage increment: ${increment}M\n" -------------------------------------------------------------------------------- /build_e2b_image.sh: -------------------------------------------------------------------------------- 1 | # Check if argument is provided 2 | if [ $# -eq 0 ]; then 3 | echo 'Error: Missing the WANDBOT_COMMIT argument' 4 | echo "Error: Please pass the wandbot commit hash or branch name to check out." 5 | echo "Usage: ./build_e2b_image.sh <WANDBOT_COMMIT>" 6 | exit 1 7 | fi 8 | 9 | WANDBOT_COMMIT=$1 10 | 11 | # Set wandb key from .env file for wandb artifacts download 12 | WANDB_API_KEY=$(grep WANDB_API_KEY .env | cut -d= -f2) 13 | export WANDB_API_KEY 14 | 15 | # Download index; the index used will be whatever is set in src/wandbot/configs/vector_store_config.py 16 | rm -rf temp_index 17 | mkdir -p temp_index 18 | python download_vectordb_index.py --vectordb_index_dir=temp_index # Save index to a new temp dir to avoid mistaken index uploads 19 | 20 | # Build image, docker will copy temp_index dir into the image 21 | e2b template build --build-arg WANDBOT_COMMIT="${WANDBOT_COMMIT}" -n "wandbot_$WANDBOT_COMMIT" -c "/root/.jupyter/start-up.sh" 22 | rm -rf temp_index -------------------------------------------------------------------------------- /e2b.Dockerfile: -------------------------------------------------------------------------------- 1 | # Install dependencies and customize sandbox 2 | FROM e2bdev/code-interpreter:python-3.12.8 3 | 4 | # Set working directory 5 | WORKDIR /home/user 6 | 7 | # Install Python 3.12 and set it as default 8 | RUN apt-get update && apt-get install -y \ 9 | wget \ 10 | gpg \ 11 | libstdc++6 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # Clone into workspace directory 16 | # Pass a new value for CACHE_BUST in the docker --build-args 17 | # to invalidate the cache from here and trigger a fresh git pull and build from here 18 | ARG WANDBOT_COMMIT 19 | ARG CACHE_BUST=1 20 | RUN git clone https://github.com/wandb/wandbot.git /home/user/wandbot && \ 21 | cd /home/user/wandbot && \ 22 | git checkout $WANDBOT_COMMIT 23 | 24 | RUN pip install uv 25 | 26 | # Set working directory 27 | WORKDIR /home/user/wandbot 28 | 29 | # Set LD_LIBRARY_PATH before running build.sh 30 | RUN export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH && \ 31 | (bash build.sh || true) 32 | 33 | RUN uv pip install --system . 
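# Note: poetry.toml sets virtualenvs.create = false, so the "poetry install" below installs any remaining project dependencies into the system environment rather than creating a new virtualenv.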
34 | RUN poetry install 35 | 36 | # Copy in the vector index 37 | COPY temp_index/* /home/user/temp_index/ 38 | 39 | RUN export INDEX_DIR=$(python -c 'from wandbot.configs.vector_store_config import VectorStoreConfig; \ 40 | index_dir = VectorStoreConfig().index_dir; \ 41 | print(index_dir, end="")') && \ 42 | mkdir -p $INDEX_DIR && \ 43 | cp -r /home/user/temp_index/* $INDEX_DIR/ && \ 44 | rm -rf /home/user/temp_index 45 | -------------------------------------------------------------------------------- /eval_requirements.txt: -------------------------------------------------------------------------------- 1 | dataclasses-json >= 0.6.4 2 | simple-parsing >= 0.1.6 3 | gunicorn >= 23.0.0 4 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = false -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "wandbot" 3 | version = "1.3.2" 4 | description = "WandBot is a Q&A bot for Weights & Biases Models and Weave documentation" 5 | readme = "README.md" 6 | requires-python = ">=3.11,<3.13" 7 | license = {text = "Apache-2.0"} 8 | authors = [ 9 | {name = "parambharat", email = "bharat.ramanathan@wandb.com"}, 10 | {name = "morganmcg1", email = "morganmcg1@users.noreply.github.com"}, 11 | {name = "ayulockin", email = "ayusht@wandb.com"}, 12 | {name = "ash0ts", email = "anish@wandb.com"} 13 | ] 14 | dependencies = [ 15 | "pandas>=2.1.2", 16 | "pydantic>=2.10.0", 17 | "pydantic-settings>=2.5.1", 18 | "gitpython>=3.1.40", 19 | "giturlparse>=0.12.0", 20 | "scikit-learn>=1.3.2", 21 | "python-dotenv>=1.0.0", 22 | "slack-bolt>=1.18.0", 23 | "slack-sdk>=3.21.3", 24 | "discord>=2.3.2", 25 | "markdown>=3.5.1", 26 | "fastapi>=0.115.0", 27 | "uvicorn>=0.24.0", 28 | "openai>=1.71.0", 29 | "anthropic>=0.49.0", 30 | "google-genai>=1.9.0", 31 | "cohere>=5.13.0", 32 | "langchain>=0.2.2", 33 | "langchain-core>=0.2.2", 34 | "chromadb==1.0.3", 35 | "weave>=0.51.47", 36 | "wandb[workspaces]>=0.19.9", 37 | "tiktoken", 38 | "fasttext-wheel", 39 | "tree-sitter-languages", 40 | "markdownify>=0.11.6", 41 | "colorlog>=6.8.0", 42 | "google-cloud-bigquery>=3.14.1", 43 | "python-frontmatter>=1.1.0", 44 | "nbformat>=5.10.4", 45 | "nbconvert>=7.16.4", 46 | "langchain-community>=0.3.24", 47 | "simple-parsing>=0.1.7", 48 | "pymdown-extensions>=10.15", 49 | "tree-sitter==0.21.3", 50 | ] 51 | 52 | [tool.isort] 53 | profile = "black" 54 | line_length = 80 55 | skip = [".gitignore", "data", "examples", "notebooks", "artifacts", ".vscode", ".github", ".idea", ".replit", "*.md", "wandb", ".env", ".git", ] 56 | 57 | [tool.black] 58 | line-length = 80 59 | skip = [".gitignore", "data", "examples", "notebooks", "artifacts", ".vscode", ".github", ".idea", ".replit", "*.md", "wandb", ".env", ".git", ] 60 | 61 | [tool.pytest.ini_options] 62 | asyncio_mode = "strict" 63 | asyncio_default_fixture_loop_scope = "function" 64 | markers = [ 65 | "integration: marks tests that make real API calls to external services", 66 | ] 67 | 68 | [tool.ruff] 69 | # Line length setting 70 | line-length = 120 71 | 72 | # Enable specific rule categories 73 | select = [ 74 | "E", # pycodestyle errors 75 | "F", # pyflakes 76 | "I", # isort 77 | "B", # flake8-bugbear 78 | "C4", # flake8-comprehensions 79 | ] 80 | 81 | # Ignore specific rules 82 | ignore = [ 83 | "E501", 
# Line too long (handled by formatter) 84 | ] 85 | 86 | # Exclude directories/files 87 | exclude = [ 88 | ".git", 89 | ".venv", 90 | "venv", 91 | "__pycache__", 92 | "build", 93 | "dist", 94 | ] 95 | 96 | # Additional configurations 97 | unfixable = ["F401"] # Don't auto-fix unused imports 98 | 99 | # Allow unused variables when they start with an underscore 100 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 101 | 102 | # Specific rule configurations 103 | [tool.ruff.mccabe] 104 | max-complexity = 10 105 | 106 | [tool.ruff.isort] 107 | known-first-party = ["mypackage"] # Replace with your package name 108 | 109 | # Per-file ignores 110 | [tool.ruff.per-file-ignores] 111 | "__init__.py" = ["F401"] # Ignore unused imports in __init__.py files 112 | "tests/*.py" = ["E501"] # Ignore line length in test files 113 | -------------------------------------------------------------------------------- /replit.nix: -------------------------------------------------------------------------------- 1 | { pkgs }: { 2 | deps = with pkgs; [ 3 | # Core system libraries 4 | stdenv.cc.cc.lib 5 | libstdcxx5 6 | 7 | # Make sure gcc and its runtime are available, needed for fasttext 8 | gcc 9 | gcc.cc.lib 10 | 11 | # Other dependencies 12 | bash 13 | hydrus 14 | gitFull 15 | glibcLocales 16 | python311 17 | ]; 18 | 19 | env = { 20 | APPEND_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib"; 21 | NIXPKGS_ALLOW_UNFREE = "1"; 22 | LD_LIBRARY_PATH = let 23 | libraries = [ 24 | pkgs.stdenv.cc.cc.lib 25 | pkgs.libstdcxx5 26 | pkgs.gcc.cc.lib 27 | ]; 28 | in "${pkgs.lib.makeLibraryPath libraries}"; 29 | }; 30 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # numpy>=2.2.0 2 | pandas>=2.1.2 3 | pydantic>=2.10.0 4 | pydantic-settings>=2.5.1 5 | gitpython>=3.1.40 6 | giturlparse>=0.12.0 7 | scikit-learn>=1.3.2 8 | python-dotenv>=1.0.0 9 | slack-bolt>=1.18.0 10 | slack-sdk>=3.21.3 11 | discord>=2.3.2 12 | markdown>=3.5.1 13 | fastapi>=0.115.0 14 | uvicorn>=0.24.0 15 | openai>=1.71.0 16 | anthropic>=0.49.0 17 | google-genai>=1.9.0 18 | cohere>=5.13.0 19 | langchain >= 0.2.2 20 | langchain-core >= 0.2.2 21 | 22 | # chromadb>=0.6.0 23 | chromadb == 1.0.3 24 | weave>=0.51.47 25 | wandb[workspaces]>=0.19.9 26 | tiktoken 27 | fasttext-wheel 28 | tree-sitter-languages 29 | markdownify>=0.11.6 30 | colorlog>=6.8.0 31 | google-cloud-bigquery>=3.14.1 32 | # db-dtypes>=1.2.0 33 | python-frontmatter>=1.1.0 34 | # pymdown-extensions>=10.5 35 | # # simsimd==3.7.7 36 | nbformat>=5.10.4 37 | nbconvert>=7.16.4 38 | tree-sitter 39 | # typing_extensions>=4.12.0 40 | # nest-asyncio==1.6.0 41 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | set -e # Exit on error 2 | 3 | # Ensure we're using the virtual environment from build.sh 4 | # export VIRTUAL_ENV=wandbot_venv 5 | # export PATH="$VIRTUAL_ENV/bin:$PATH" 6 | # export PYTHONPATH="$(pwd)/src:$PYTHONPATH" 7 | export PYTHONPATH=/home/runner/workspace/src:/home/runner/workspace/.pythonlibs/lib/python3.12/site-packages 8 | 9 | # source wandbot_venv/bin/activate 10 | source .pythonlibs/bin/activate 11 | 12 | echo "Starting Wandbot application..." 
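# Note: the hard-coded /home/runner/workspace paths above assume the Replit deployment layout used by this repo (see .replit and build.sh); adjust PYTHONPATH when running outside Replit.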
13 | 14 | # Function to start a service with logging to stdout 15 | start_service() { 16 | echo "Starting service: $*" 17 | "$@" || { 18 | echo "Failed to start service: $*" >&2 19 | return 1 20 | } 21 | } 22 | 23 | # Print all python prints 24 | export PYTHONUNBUFFERED=1 25 | 26 | # Start all services 27 | (uv run uvicorn wandbot.api.app:app --host 0.0.0.0 --port 8000 --workers 2) & \ 28 | ($VIRTUAL_ENV/bin/python -m wandbot.apps.slack -l en) & \ 29 | ($VIRTUAL_ENV/bin/python -m wandbot.apps.slack -l ja) & \ 30 | ($VIRTUAL_ENV/bin/python -m wandbot.apps.discord) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/__init__.py -------------------------------------------------------------------------------- /src/wandbot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/__init__.py -------------------------------------------------------------------------------- /src/wandbot/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/api/__init__.py -------------------------------------------------------------------------------- /src/wandbot/api/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/api/routers/__init__.py -------------------------------------------------------------------------------- /src/wandbot/api/routers/chat.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException 2 | from starlette import status 3 | 4 | from wandbot.chat.schemas import ChatRequest, ChatResponse 5 | from wandbot.utils import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | class APIQueryRequest(ChatRequest): 11 | pass 12 | 13 | 14 | class APIQueryResponse(ChatResponse): 15 | pass 16 | 17 | 18 | # Store initialization components 19 | chat_components = { 20 | "vector_store": None, 21 | "chat_config": None, 22 | "chat": None, # We'll store the actual Chat instance here 23 | } 24 | 25 | router = APIRouter(prefix="/chat", tags=["chat"]) 26 | 27 | 28 | @router.post( 29 | "/query", response_model=APIQueryResponse, status_code=status.HTTP_200_OK 30 | ) 31 | async def query(request: APIQueryRequest) -> APIQueryResponse: 32 | if not chat_components.get("chat"): 33 | raise HTTPException( 34 | status_code=503, detail="Chat service is not yet initialized" 35 | ) 36 | 37 | try: 38 | chat_instance = chat_components["chat"] 39 | result = await chat_instance.__acall__( 40 | ChatRequest( 41 | question=request.question, 42 | chat_history=request.chat_history, 43 | language=request.language, 44 | application=request.application, 45 | ), 46 | ) 47 | return APIQueryResponse(**result.model_dump()) 48 | except Exception as e: 49 | logger.error(f"Error processing chat query: {e}") 50 | raise HTTPException( 51 | status_code=500, detail="Error processing chat query" 52 | ) 53 | -------------------------------------------------------------------------------- /src/wandbot/api/routers/database.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from starlette import status 3 | from starlette.responses import Response 4 | 5 | import wandb 6 | from wandbot.database.client import DatabaseClient 7 | from wandbot.database.schemas import ( 8 | ChatThread, 9 | ChatThreadCreate, 10 | Feedback, 11 | FeedbackCreate, 12 | QuestionAnswer, 13 | QuestionAnswerCreate, 14 | ) 15 | from wandbot.utils import get_logger 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | db_client: DatabaseClient | None = None 21 | 22 | router = APIRouter( 23 | prefix="/data", 24 | tags=["database", "crud"], 25 | ) 26 | 27 | 28 | class APIQuestionAnswerRequest(QuestionAnswerCreate): 29 | pass 30 | 31 | 32 | class APIQuestionAnswerResponse(QuestionAnswer): 33 | pass 34 | 35 | 36 | @router.post( 37 | "/question_answer", 38 | response_model=APIQuestionAnswerResponse, 39 | status_code=status.HTTP_201_CREATED, 40 | ) 41 | def create_question_answer( 42 | request: APIQuestionAnswerRequest, response: Response 43 | ) -> APIQuestionAnswerResponse | None: 44 | """Creates a question answer. 45 | 46 | Args: 47 | request: The request object containing the question answer data. 48 | response: The response object to update with the result. 49 | 50 | Returns: 51 | The created question answer or None if creation failed. 52 | """ 53 | question_answer = db_client.create_question_answer(request) 54 | if question_answer is None: 55 | response.status_code = status.HTTP_400_BAD_REQUEST 56 | return question_answer 57 | 58 | 59 | class APIGetChatThreadResponse(ChatThread): 60 | pass 61 | 62 | 63 | class APIGetChatThreadRequest(ChatThreadCreate): 64 | pass 65 | 66 | 67 | class APICreateChatThreadRequest(ChatThreadCreate): 68 | pass 69 | 70 | 71 | @router.get( 72 | "/chat_thread/{application}/{thread_id}", 73 | response_model=APIGetChatThreadResponse | None, 74 | status_code=status.HTTP_200_OK, 75 | ) 76 | def get_chat_thread( 77 | application: str, thread_id: str, response: Response 78 | ) -> APIGetChatThreadResponse: 79 | """Retrieves a chat thread from the database. 80 | 81 | If the chat thread does not exist, it creates a new chat thread. 82 | 83 | Args: 84 | application: The application name. 85 | thread_id: The ID of the chat thread. 86 | response: The HTTP response object. 87 | 88 | Returns: 89 | The retrieved or created chat thread. 90 | """ 91 | chat_thread = db_client.get_chat_thread( 92 | application=application, 93 | thread_id=thread_id, 94 | ) 95 | if chat_thread is None: 96 | chat_thread = db_client.create_chat_thread( 97 | APICreateChatThreadRequest( 98 | application=application, 99 | thread_id=thread_id, 100 | ) 101 | ) 102 | response.status_code = status.HTTP_201_CREATED 103 | if chat_thread is None: 104 | response.status_code = status.HTTP_400_BAD_REQUEST 105 | return chat_thread 106 | 107 | 108 | class APIFeedbackRequest(FeedbackCreate): 109 | pass 110 | 111 | 112 | class APIFeedbackResponse(Feedback): 113 | pass 114 | 115 | 116 | @router.post( 117 | "/feedback", 118 | response_model=APIFeedbackResponse | None, 119 | status_code=status.HTTP_201_CREATED, 120 | ) 121 | def feedback( 122 | request: APIFeedbackRequest, response: Response 123 | ) -> APIFeedbackResponse: 124 | pass 125 | # """Handles the feedback request and logs the feedback data. 126 | 127 | # Args: 128 | # request: The feedback request object. 129 | # response: The response object. 130 | 131 | # Returns: 132 | # The feedback response object. 
133 | # """ 134 | # feedback_response = db_client.create_feedback(request) 135 | # if feedback_response is not None: 136 | # wandb.log( 137 | # { 138 | # "feedback": wandb.Table( 139 | # columns=list(request.model_dump().keys()), 140 | # data=[list(request.model_dump().values())], 141 | # ) 142 | # } 143 | # ) 144 | # else: 145 | # response.status_code = status.HTTP_400_BAD_REQUEST 146 | # return feedback_response 147 | -------------------------------------------------------------------------------- /src/wandbot/api/routers/retrieve.py: -------------------------------------------------------------------------------- 1 | # from typing import Any, Dict, List 2 | 3 | # from fastapi import APIRouter 4 | # from pydantic import BaseModel 5 | # from starlette import status 6 | 7 | # from wandbot.retriever.base import SimpleRetrievalEngine 8 | 9 | # router = APIRouter( 10 | # prefix="/retrieve", 11 | # tags=["retrievers"], 12 | # ) 13 | 14 | # retriever: SimpleRetrievalEngine | None = None 15 | 16 | 17 | # class APIRetrievalResult(BaseModel): 18 | # text: str 19 | # score: float 20 | # metadata: Dict[str, Any] 21 | 22 | 23 | # class APIRetrievalResponse(BaseModel): 24 | # query: str 25 | # top_k: List[APIRetrievalResult] 26 | 27 | 28 | # class APIRetrievalRequest(BaseModel): 29 | # query: str 30 | # language: str = "en" 31 | # top_k: int = 5 32 | # sources: List[str] | None = None 33 | 34 | 35 | # @router.post( 36 | # "", 37 | # response_model=APIRetrievalResponse, 38 | # status_code=status.HTTP_200_OK, 39 | # ) 40 | # @router.post( 41 | # "/", 42 | # response_model=APIRetrievalResponse, 43 | # status_code=status.HTTP_200_OK, 44 | # ) 45 | # def retrieve(request: APIRetrievalRequest) -> APIRetrievalResponse: 46 | # """Retrieves the top k results for a given query. 47 | 48 | # Args: 49 | # request: The APIRetrievalRequest object containing the query and other parameters. 50 | 51 | # Returns: 52 | # The APIRetrievalResponse object containing the query and top k results. 53 | # """ 54 | # results = retriever( 55 | # question=request.query, 56 | # language=request.language, 57 | # top_k=request.top_k, 58 | # sources=request.sources, 59 | # ) 60 | 61 | # return APIRetrievalResponse( 62 | # query=request.query, 63 | # top_k=[ 64 | # APIRetrievalResult( 65 | # text=result["text"], 66 | # score=result["score"], 67 | # metadata=result["metadata"], 68 | # ) 69 | # for result in results 70 | # ], 71 | # ) 72 | -------------------------------------------------------------------------------- /src/wandbot/apps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/apps/__init__.py -------------------------------------------------------------------------------- /src/wandbot/apps/discord/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/apps/discord/__init__.py -------------------------------------------------------------------------------- /src/wandbot/apps/discord/__main__.py: -------------------------------------------------------------------------------- 1 | """Discord bot for handling user queries and interacting with an API. 
2 | 3 | This module contains the main functionality for a Discord bot that listens to user messages, 4 | detects the language of the message, creates threads for user queries, interacts with an API to get responses, 5 | formats the responses, and sends them back to the user. It also handles user feedback on the bot's responses. 6 | 7 | """ 8 | 9 | import asyncio 10 | import logging 11 | import uuid 12 | 13 | import discord 14 | from discord.ext import commands 15 | 16 | from wandbot.api.client import AsyncAPIClient 17 | from wandbot.apps.discord.config import DiscordAppConfig 18 | from wandbot.apps.utils import format_response 19 | 20 | from dotenv import load_dotenv 21 | import pathlib 22 | # Determine the project root directory (assuming it's 3 levels up from this script) 23 | project_root = pathlib.Path(__file__).resolve().parents[3] 24 | dotenv_path = project_root / ".env" 25 | 26 | load_dotenv(dotenv_path=dotenv_path) 27 | 28 | logger = logging.getLogger(__name__) 29 | logger.setLevel(logging.INFO) 30 | 31 | intents = discord.Intents.all() 32 | intents.typing = False 33 | intents.presences = False 34 | intents.messages = True 35 | intents.reactions = True 36 | 37 | bot = commands.Bot(command_prefix="!", intents=intents) 38 | config = DiscordAppConfig() 39 | api_client = AsyncAPIClient(url=config.WANDBOT_API_URL) 40 | 41 | 42 | @bot.event 43 | async def on_message(message: discord.Message): 44 | """Handles the on_message event in Discord. 45 | 46 | Args: 47 | message: The message object received. 48 | 49 | Returns: 50 | None 51 | """ 52 | if message.author == bot.user: 53 | return 54 | if bot.user is not None and bot.user.mentioned_in(message): 55 | mention = f"<@{message.author.id}>" 56 | thread = None 57 | is_following = None 58 | if isinstance(message.channel, discord.Thread): 59 | if ( 60 | message.channel.parent.id == config.PROD_DISCORD_CHANNEL_ID 61 | or message.channel.parent.id == config.TEST_DISCORD_CHANNEL_ID 62 | ): 63 | thread = message.channel 64 | is_following = True 65 | else: 66 | if ( 67 | message.channel.id == config.PROD_DISCORD_CHANNEL_ID 68 | or message.channel.id == config.TEST_DISCORD_CHANNEL_ID 69 | ): 70 | thread = await message.channel.create_thread( 71 | name="Thread", type=discord.ChannelType.public_thread 72 | ) # currently calling it "Thread" because W&B Support makes it sound too official. 
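# is_following tracks whether this mention continues an existing thread (its chat history is fetched below) or starts a new one (the intro message is sent first).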
73 | is_following = False 74 | if thread is not None: 75 | if is_following: 76 | chat_history = await api_client.get_chat_history( 77 | application=config.APPLICATION, thread_id=str(thread.id) 78 | ) 79 | else: 80 | chat_history = None 81 | if not chat_history: 82 | await thread.send( 83 | config.INTRO_MESSAGE.format(mention=mention), 84 | mention_author=True, 85 | ) 86 | 87 | response = await api_client.query( 88 | question=str(message.clean_content), 89 | chat_history=chat_history, 90 | language=config.bot_language, 91 | application=config.APPLICATION, 92 | ) 93 | if response is None: 94 | await thread.send( 95 | config.ERROR_MESSAGE.format(mention=mention), 96 | mention_author=True, 97 | ) 98 | return 99 | outro_message = config.OUTRO_MESSAGE 100 | sent_message = None 101 | if len(response.answer) > 1200: 102 | answer_chunks = [] 103 | for i in range(0, len(response.answer), 1200): 104 | answer_chunks.append(response.answer[i : i + 1200]) 105 | for i, answer_chunk in enumerate(answer_chunks): 106 | response_copy = response.model_copy() 107 | response_copy.answer = answer_chunk 108 | if i == len(answer_chunks) - 1: 109 | sent_message = await thread.send( 110 | format_response( 111 | config, 112 | response_copy, 113 | outro_message, 114 | ), 115 | ) 116 | else: 117 | sent_message = await thread.send( 118 | format_response( 119 | config, 120 | response_copy, 121 | "", 122 | is_last=False, 123 | ), 124 | ) 125 | else: 126 | sent_message = await thread.send( 127 | format_response( 128 | config, 129 | response, 130 | outro_message, 131 | ), 132 | ) 133 | if sent_message is not None: 134 | await api_client.create_question_answer( 135 | thread_id=str(thread.id), 136 | question_answer_id=str(sent_message.id), 137 | question=response.question, 138 | answer=response.answer, 139 | system_prompt=response.system_prompt, 140 | model=response.model, 141 | sources=response.sources, 142 | source_documents=response.source_documents, 143 | total_tokens=response.total_tokens, 144 | prompt_tokens=response.prompt_tokens, 145 | completion_tokens=response.completion_tokens, 146 | time_taken=response.time_taken, 147 | start_time=response.start_time, 148 | end_time=response.end_time, 149 | api_call_statuses=response.api_call_statuses, 150 | language=config.bot_language, 151 | ) 152 | # # Add reactions for feedback 153 | await sent_message.add_reaction("👍") 154 | await sent_message.add_reaction("👎") 155 | 156 | # # Wait for reactions 157 | def check(user_reaction, author): 158 | return author == message.author and str( 159 | user_reaction.emoji 160 | ) in [ 161 | "👍", 162 | "👎", 163 | ] 164 | 165 | try: 166 | reaction, user = await bot.wait_for( 167 | "reaction_add", timeout=config.WAIT_TIME, check=check 168 | ) 169 | 170 | except asyncio.TimeoutError: 171 | # await thread.send("🤖") 172 | rating = 0 173 | 174 | else: 175 | # Get the feedback value 176 | if str(reaction.emoji) == "👍": 177 | rating = 1 178 | elif str(reaction.emoji) == "👎": 179 | rating = -1 180 | else: 181 | rating = 0 182 | 183 | # # Send feedback to API 184 | # await api_client.create_feedback( 185 | # feedback_id=str(uuid.uuid4()), 186 | # question_answer_id=str(sent_message.id), 187 | # rating=rating, 188 | # ) 189 | 190 | await bot.process_commands(message) 191 | 192 | 193 | if __name__ == "__main__": 194 | bot.run(config.DISCORD_BOT_TOKEN) 195 | -------------------------------------------------------------------------------- /src/wandbot/apps/discord/config.py: -------------------------------------------------------------------------------- 1 
| """Discord bot configuration module. 2 | 3 | This module contains the configuration settings for the Discord bot application. 4 | It includes settings for the application name, wait time, channel IDs, bot token, 5 | API keys, messages in English and Japanese, API URL, and a flag to include sources. 6 | 7 | The settings are defined in the DiscordAppConfig class, which inherits from the 8 | BaseSettings class provided by the pydantic_settings package. The settings values 9 | are either hardcoded or fetched from environment variables. 10 | 11 | Typical usage example: 12 | 13 | config = DiscordAppConfig() 14 | wait_time = config.WAIT_TIME 15 | bot_token = config.DISCORD_BOT_TOKEN 16 | """ 17 | 18 | from pydantic import AnyHttpUrl, Field 19 | from pydantic_settings import BaseSettings, SettingsConfigDict 20 | 21 | EN_INTRO_MESSAGE = ( 22 | "🤖 Hi {mention}: \n\n" 23 | "Please do not share any private or sensitive information in your query at this time.\n\n" 24 | "Please note that overly long messages (>1024 words) will be truncated!\n\nGenerating response...\n\n" 25 | ) 26 | 27 | EN_OUTRO_MESSAGE = ( 28 | "🤖 If you still need help please try re-phrase your question, " 29 | "or alternatively reach out to the Weights & Biases Support Team at support@wandb.com \n\n" 30 | " Was this response helpful? Please react below to let us know" 31 | ) 32 | 33 | EN_ERROR_MESSAGE = "Oops!, Sorry 🤖 {mention}: Something went wrong. Please retry again in some time" 34 | 35 | 36 | class DiscordAppConfig(BaseSettings): 37 | APPLICATION: str = "Discord" 38 | WAIT_TIME: float = 300.0 39 | PROD_DISCORD_CHANNEL_ID: int = 1090739438310654023 40 | TEST_DISCORD_CHANNEL_ID: int = 1088892013321142484 41 | DISCORD_BOT_TOKEN: str = Field(..., validation_alias="DISCORD_BOT_TOKEN") 42 | INTRO_MESSAGE: str = Field(EN_INTRO_MESSAGE) 43 | OUTRO_MESSAGE: str = Field(EN_OUTRO_MESSAGE) 44 | ERROR_MESSAGE: str = Field(EN_ERROR_MESSAGE) 45 | WANDBOT_API_URL: AnyHttpUrl = Field(..., validation_alias="WANDBOT_API_URL") 46 | include_sources: bool = True 47 | bot_language: str = "en" 48 | 49 | model_config = SettingsConfigDict( 50 | env_file=".env", env_file_encoding="utf-8", extra="allow" 51 | ) 52 | -------------------------------------------------------------------------------- /src/wandbot/apps/slack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/apps/slack/__init__.py -------------------------------------------------------------------------------- /src/wandbot/apps/slack/__main__.py: -------------------------------------------------------------------------------- 1 | """A Slack bot that interacts with users and processes their queries. 2 | 3 | This module contains the main functionality of the Slack bot. It listens for mentions of the bot in messages, 4 | processes the text of the message, and sends a response. It also handles reactions added to messages and 5 | saves them as feedback. The bot supports both English and Japanese languages. 6 | 7 | The bot uses the Slack Bolt framework for handling events and the langdetect library for language detection. 8 | It also communicates with an external API for processing queries and storing chat history and feedback. 
9 | 10 | """ 11 | 12 | import argparse 13 | import asyncio 14 | import logging 15 | import pathlib 16 | from functools import partial 17 | 18 | from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler 19 | from slack_bolt.async_app import AsyncApp 20 | from slack_sdk.web import SlackResponse 21 | 22 | from wandbot.api.client import AsyncAPIClient 23 | from wandbot.apps.slack.config import SlackAppEnConfig, SlackAppJaConfig 24 | from wandbot.apps.slack.formatter import MrkdwnFormatter 25 | from wandbot.apps.utils import format_response 26 | from wandbot.utils import get_logger 27 | 28 | logger = get_logger(__name__) 29 | 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument( 33 | "-l", 34 | "--language", 35 | default="en", 36 | help="Language of the bot", 37 | type=str, 38 | choices=["en", "ja"], 39 | ) 40 | 41 | args = parser.parse_args() 42 | 43 | if args.language == "ja": 44 | config = SlackAppJaConfig() 45 | else: 46 | config = SlackAppEnConfig() 47 | 48 | 49 | app = AsyncApp(token=config.SLACK_APP_TOKEN) 50 | api_client = AsyncAPIClient(url=config.WANDBOT_API_URL) 51 | 52 | 53 | async def send_message(say: callable, message: str, thread: str = None) -> SlackResponse: 54 | message = MrkdwnFormatter()(message) 55 | if thread is not None: 56 | return await say(text=message, thread_ts=thread) 57 | else: 58 | return await say(text=message) 59 | 60 | 61 | @app.event("app_mention") 62 | async def command_handler(body: dict, say: callable, logger: logging.Logger) -> None: 63 | """ 64 | Handles the command when the app is mentioned in a message. 65 | 66 | Args: 67 | body (dict): The event body containing the message details. 68 | say (function): The function to send a message. 69 | logger (Logger): The logger instance for logging errors. 70 | 71 | Raises: 72 | Exception: If there is an error posting the message. 
73 | """ 74 | try: 75 | query = body["event"].get("text") 76 | user = body["event"].get("user") 77 | thread_id = body["event"].get("thread_ts", None) or body["event"].get("ts", None) 78 | say = partial(say, token=config.SLACK_BOT_TOKEN) 79 | 80 | chat_history = await api_client.get_chat_history(application=config.APPLICATION, thread_id=thread_id) 81 | 82 | if not chat_history: 83 | # send out the intro message 84 | await send_message( 85 | say=say, 86 | message=config.INTRO_MESSAGE.format(user=user), 87 | thread=thread_id, 88 | ) 89 | # process the query through the api 90 | api_response = await api_client.query( 91 | question=query, 92 | chat_history=chat_history, 93 | language=config.bot_language, 94 | application=config.APPLICATION, 95 | ) 96 | response = format_response( 97 | config, 98 | api_response, 99 | config.OUTRO_MESSAGE, 100 | ) 101 | 102 | # send the response 103 | sent_message = await send_message(say=say, message=response, thread=thread_id) 104 | 105 | await app.client.reactions_add( 106 | channel=body["event"]["channel"], 107 | timestamp=sent_message["ts"], 108 | name="thumbsup", 109 | token=config.SLACK_BOT_TOKEN, 110 | ) 111 | await app.client.reactions_add( 112 | channel=body["event"]["channel"], 113 | timestamp=sent_message["ts"], 114 | name="thumbsdown", 115 | token=config.SLACK_BOT_TOKEN, 116 | ) 117 | 118 | # save the question answer to the database 119 | await api_client.create_question_answer( 120 | thread_id=thread_id, 121 | question_answer_id=sent_message["ts"], 122 | question=api_response.question, 123 | answer=api_response.answer, 124 | system_prompt=api_response.system_prompt, 125 | model=api_response.model, 126 | sources=api_response.sources, 127 | source_documents=api_response.source_documents, 128 | total_tokens=api_response.total_tokens, 129 | prompt_tokens=api_response.prompt_tokens, 130 | completion_tokens=api_response.completion_tokens, 131 | time_taken=api_response.time_taken, 132 | start_time=api_response.start_time, 133 | end_time=api_response.end_time, 134 | api_call_statuses=api_response.api_call_statuses, 135 | language=config.bot_language, 136 | ) 137 | 138 | except Exception as e: 139 | logger.error(f"Error posting message: {e}") 140 | 141 | 142 | def parse_reaction(reaction: str) -> int: 143 | """ 144 | Parses the reaction and returns the corresponding rating value. 145 | 146 | Args: 147 | reaction (str): The reaction emoji. 148 | 149 | Returns: 150 | int: The rating value (-1, 0, or 1). 151 | """ 152 | if reaction == "+1": 153 | return 1 154 | elif reaction == "-1": 155 | return -1 156 | else: 157 | return 0 158 | 159 | 160 | @app.event("reaction_added") 161 | async def handle_reaction_added(event: dict) -> None: 162 | """ 163 | Handles the event when a reaction is added to a message. 164 | 165 | Args: 166 | event (dict): The event details. 
167 | 168 | """ 169 | channel_id = event["item"]["channel"] 170 | message_ts = event["item"]["ts"] 171 | 172 | conversation = await app.client.conversations_replies( 173 | channel=channel_id, 174 | ts=message_ts, 175 | inclusive=True, 176 | limit=1, 177 | token=config.SLACK_BOT_TOKEN, 178 | ) 179 | messages = conversation.get( 180 | "messages", 181 | ) 182 | if messages and len(messages): 183 | thread_ts = messages[0].get("thread_ts") 184 | if thread_ts: 185 | rating = parse_reaction(event["reaction"]) 186 | # await api_client.create_feedback( 187 | # feedback_id=event["event_ts"], 188 | # question_answer_id=message_ts, 189 | # rating=rating, 190 | # ) 191 | 192 | 193 | async def main(): 194 | handler = AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN) 195 | await handler.start_async() 196 | 197 | 198 | if __name__ == "__main__": 199 | asyncio.run(main()) 200 | -------------------------------------------------------------------------------- /src/wandbot/apps/slack/config.py: -------------------------------------------------------------------------------- 1 | """This module contains the configuration settings for the Slack application. 2 | 3 | This module uses the Pydantic library to define the configuration settings for the Slack application. 4 | These settings include tokens, secrets, API keys, and messages for the application. 5 | The settings are loaded from an environment file and can be accessed as properties of the `SlackAppEnConfig` class. 6 | 7 | Typical usage example: 8 | 9 | from .config import SlackAppEnConfig 10 | 11 | config = SlackAppEnConfig() 12 | token = config.SLACK_APP_TOKEN 13 | """ 14 | 15 | from pydantic import AnyHttpUrl, Field 16 | from pydantic_settings import BaseSettings, SettingsConfigDict 17 | 18 | EN_INTRO_MESSAGE = ( 19 | "Hi <@{user}>:\n\n" 20 | "Please do not share any private or sensitive information in your query.\n\n" 21 | "Please note that overly long messages (>1024 words) will be truncated.\n\n" 22 | "Generating response...\n\n" 23 | ) 24 | 25 | EN_OUTRO_MESSAGE = ( 26 | "🤖 If you still need help, please try re-phrasing your question, " 27 | "or alternatively reach out to the Weights & Biases Support Team at support@wandb.com \n\n" 28 | " Was this response helpful? Please react below to let us know" 29 | ) 30 | 31 | EN_ERROR_MESSAGE = ( 32 | "Oops, something went wrong. 
Please try again later" 33 | ) 34 | 35 | # EN_FALLBACK_WARNING_MESSAGE = ( 36 | # "**Warning: Falling back to {model}**\n\n" 37 | # ) 38 | 39 | JA_INTRO_MESSAGE = ( 40 | "こんにちは <@{user}>:\n\n" 41 | "Wandbotは現在アルファテスト中ですので、頻繁にアップデートされます。" 42 | "ご利用の際にはプライバシーに関わる情報は入力されないようお願いします。返答を生成しています・・・" 43 | ) 44 | 45 | JA_OUTRO_MESSAGE = ( 46 | ":robot_face: この答えが十分でなかった場合には、質問を少し変えて試してみると結果が良くなることがあるので、お試しください。もしくは、" 47 | "#support チャンネルにいるwandbチームに質問してください。この答えは役に立ったでしょうか?下のボタンでお知らせ下さい。" 48 | ) 49 | 50 | JA_ERROR_MESSAGE = "「おっと、問題が発生しました。しばらくしてからもう一度お試しください。」" 51 | 52 | # JA_FALLBACK_WARNING_MESSAGE = ( 53 | # "**警告: {model}** にフォールバックします。これらの結果は **gpt-4** ほど良くない可能性があります\n\n" 54 | # ) 55 | 56 | 57 | class SlackAppEnConfig(BaseSettings): 58 | model_config = SettingsConfigDict( 59 | env_file=".env", 60 | env_file_encoding="utf-8", 61 | extra="allow" 62 | ) 63 | 64 | APPLICATION: str = Field("Slack_EN") 65 | SLACK_APP_TOKEN: str = Field( 66 | ..., 67 | validation_alias="SLACK_EN_APP_TOKEN" 68 | ) 69 | SLACK_BOT_TOKEN: str = Field( 70 | ..., 71 | validation_alias="SLACK_EN_BOT_TOKEN" 72 | ) 73 | SLACK_SIGNING_SECRET: str = Field( 74 | ..., 75 | validation_alias="SLACK_EN_SIGNING_SECRET" 76 | ) 77 | INTRO_MESSAGE: str = Field(EN_INTRO_MESSAGE) 78 | OUTRO_MESSAGE: str = Field(EN_OUTRO_MESSAGE) 79 | ERROR_MESSAGE: str = Field(EN_ERROR_MESSAGE) 80 | # WARNING_MESSAGE: str = Field(EN_FALLBACK_WARNING_MESSAGE) 81 | WANDBOT_API_URL: AnyHttpUrl = Field( 82 | ..., 83 | validation_alias="WANDBOT_API_URL" 84 | ) 85 | include_sources: bool = True 86 | bot_language: str = "en" 87 | 88 | 89 | 90 | class SlackAppJaConfig(BaseSettings): 91 | model_config = SettingsConfigDict( 92 | env_file=".env", 93 | env_file_encoding="utf-8", 94 | extra="allow" 95 | ) 96 | APPLICATION: str = Field("Slack_JA") 97 | SLACK_APP_TOKEN: str = Field( 98 | ..., 99 | validation_alias="SLACK_JA_APP_TOKEN" 100 | ) 101 | SLACK_BOT_TOKEN: str = Field( 102 | ..., 103 | validation_alias="SLACK_JA_BOT_TOKEN" 104 | ) 105 | SLACK_SIGNING_SECRET: str = Field( 106 | ..., 107 | validation_alias="SLACK_JA_SIGNING_SECRET" 108 | ) 109 | INTRO_MESSAGE: str = Field(JA_INTRO_MESSAGE) 110 | OUTRO_MESSAGE: str = Field(JA_OUTRO_MESSAGE) 111 | ERROR_MESSAGE: str = Field(JA_ERROR_MESSAGE) 112 | # WARNING_MESSAGE: str = Field(JA_FALLBACK_WARNING_MESSAGE) 113 | WANDBOT_API_URL: AnyHttpUrl = Field( 114 | ..., 115 | validation_alias="WANDBOT_API_URL" 116 | ) 117 | include_sources: bool = True 118 | bot_language: str = "ja" 119 | 120 | -------------------------------------------------------------------------------- /src/wandbot/apps/slack/formatter.py: -------------------------------------------------------------------------------- 1 | import regex as re 2 | 3 | 4 | class MrkdwnFormatter: 5 | def __init__(self): 6 | self.code_block_pattern = re.compile(r"(```.*?```)", re.DOTALL) 7 | self.language_spec_pattern = re.compile( 8 | r"^```[a-zA-Z]+\n", re.MULTILINE 9 | ) 10 | self.markdown_link_pattern = re.compile( 11 | r"\[([^\[]+)\]\((.*?)\)", re.MULTILINE 12 | ) 13 | self.bold_pattern = re.compile(r"\*\*([^*]+)\*\*", re.MULTILINE) 14 | self.strike_pattern = re.compile(r"~~([^~]+)~~", re.MULTILINE) 15 | self.header_pattern = re.compile(r"^#+\s*(.*?)\n", re.MULTILINE) 16 | 17 | @staticmethod 18 | def replace_markdown_link(match): 19 | text = match.group(1) 20 | url = match.group(2) 21 | return f"<{url}|{text}>" 22 | 23 | @staticmethod 24 | def replace_bold(match): 25 | return f"*{match.group(1)}*" 26 | 27 | @staticmethod 28 | def replace_strike(match): 29 | return 
f"~{match.group(1)}~" 30 | 31 | @staticmethod 32 | def replace_headers(match): 33 | header_text = match.group(1) 34 | return f"\n*{header_text}*\n" 35 | 36 | def __call__(self, text): 37 | try: 38 | segments = self.code_block_pattern.split(text) 39 | 40 | for i, segment in enumerate(segments): 41 | if segment.startswith("```") and segment.endswith("```"): 42 | segment = self.language_spec_pattern.sub("```\n", segment) 43 | segments[i] = segment 44 | else: 45 | segment = self.markdown_link_pattern.sub( 46 | self.replace_markdown_link, segment 47 | ) 48 | segment = self.bold_pattern.sub(self.replace_bold, segment) 49 | segment = self.strike_pattern.sub( 50 | self.replace_strike, segment 51 | ) 52 | segment = self.header_pattern.sub( 53 | self.replace_headers, segment 54 | ) 55 | segments[i] = segment 56 | 57 | return "".join(segments) 58 | except Exception: 59 | return text 60 | -------------------------------------------------------------------------------- /src/wandbot/apps/utils.py: -------------------------------------------------------------------------------- 1 | """This module contains utility functions for the Wandbot application. 2 | 3 | This module provides two main functions: `deduplicate` and `format_response`. 4 | The `deduplicate` function is used to remove duplicates from a list while preserving the order. 5 | The `format_response` function is used to format the response from the API query for the application 6 | 7 | Typical usage example: 8 | 9 | from .utils import deduplicate, format_response 10 | 11 | unique_list = deduplicate(input_list) 12 | formatted_response = format_response(config, response, outro_message, lang, is_last) 13 | """ 14 | 15 | from collections import OrderedDict 16 | from typing import Any, List 17 | 18 | from pydantic_settings import BaseSettings 19 | 20 | from wandbot.api.routers.chat import APIQueryResponse 21 | 22 | 23 | def deduplicate(input_list: List[Any]) -> List[Any]: 24 | """Remove duplicates from a list while preserving order. 25 | 26 | Args: 27 | input_list: The list to remove duplicates from. 28 | 29 | Returns: 30 | A new list with duplicates removed while preserving the original order. 31 | """ 32 | return list(OrderedDict.fromkeys(input_list)) 33 | 34 | 35 | def format_response( 36 | config: BaseSettings, 37 | response: APIQueryResponse | None, 38 | outro_message: str = "", 39 | is_last: bool = True, 40 | ) -> str: 41 | """Formats the response from the API query. 42 | 43 | Args: 44 | config: The config object for the app. 45 | response: The response from the API query. 46 | outro_message: The outro message to append to the formatted response. 47 | is_last: Whether the response is the last in a series. 48 | 49 | Returns: 50 | The formatted response as a string. 
51 | 52 | """ 53 | if response is not None: 54 | result = response.answer 55 | 56 | if config.include_sources and response.sources and is_last: 57 | sources_list = deduplicate( 58 | [ 59 | item 60 | for item in response.sources.split("\n") 61 | if item.strip().startswith("http") 62 | ] 63 | ) 64 | if len(sources_list) > 0: 65 | items = min(len(sources_list), 3) 66 | if config.bot_language == "ja": 67 | result = ( 68 | f"{result}\n\n*参考文献*\n\n>" 69 | + "\n> ".join(sources_list[:items]) 70 | + "\n" 71 | ) 72 | else: 73 | result = ( 74 | f"{result}\n\n*References*\n\n>" 75 | + "\n> ".join(sources_list[:items]) 76 | + "\n" 77 | ) 78 | if outro_message: 79 | result = f"{result}\n\n{outro_message}" 80 | 81 | else: 82 | result = config.ERROR_MESSAGE 83 | return result 84 | -------------------------------------------------------------------------------- /src/wandbot/chat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/chat/__init__.py -------------------------------------------------------------------------------- /src/wandbot/chat/chat.py: -------------------------------------------------------------------------------- 1 | """Handles chat interactions for WandBot. 2 | 3 | This module contains the Chat class which is responsible for handling chat interactions with the WandBot system. 4 | It provides both synchronous and asynchronous interfaces for chat operations, manages translations between 5 | languages, and coordinates the RAG (Retrieval Augmented Generation) pipeline. 6 | 7 | The Chat class handles: 8 | - Initialization of the vector store and RAG pipeline 9 | - Translation between Japanese and English (when needed) 10 | - Error handling and status tracking 11 | - Timing of operations 12 | - Response generation and formatting 13 | 14 | Typical usage example: 15 | 16 | from wandbot.configs.chat_config import ChatConfig 17 | from wandbot.configs.vector_store_config import VectorStoreConfig 18 | from wandbot.chat.schemas import ChatRequest 19 | 20 | # Initialize with both required configs 21 | vector_store_config = VectorStoreConfig() 22 | chat_config = ChatConfig() 23 | chat = Chat(vector_store_config=vector_store_config, chat_config=chat_config) 24 | 25 | # Async usage 26 | async def chat_example(): 27 | response = await chat.__acall__( 28 | ChatRequest( 29 | question="How do I use wandb?", 30 | chat_history=[], 31 | language="en" 32 | ) 33 | ) 34 | print(f"Answer: {response.answer}") 35 | print(f"Time taken: {response.time_taken}") 36 | 37 | # Sync usage 38 | response = chat( 39 | ChatRequest( 40 | question="How do I use wandb?", 41 | chat_history=[], 42 | language="en" 43 | ) 44 | ) 45 | print(f"Answer: {response.answer}") 46 | print(f"Time taken: {response.time_taken}") 47 | """ 48 | import sys 49 | import traceback 50 | from typing import List 51 | 52 | import weave 53 | 54 | from wandbot.chat.rag import RAGPipeline, RAGPipelineOutput 55 | from wandbot.chat.schemas import ChatRequest, ChatResponse 56 | from wandbot.chat.utils import translate_en_to_ja, translate_ja_to_en 57 | from wandbot.configs.chat_config import ChatConfig 58 | from wandbot.configs.vector_store_config import VectorStoreConfig 59 | from wandbot.database.schemas import QuestionAnswer 60 | from wandbot.retriever import VectorStore 61 | from wandbot.utils import ErrorInfo, Timer, get_error_file_path, get_logger, run_sync 62 | 63 | logger = get_logger(__name__) 64 | 65 | class 
Chat: 66 | """Class for handling chat interactions and managing the RAG system components.""" 67 | 68 | def __init__(self, vector_store_config: VectorStoreConfig, chat_config: ChatConfig): 69 | """Initializes the Chat instance with all necessary RAG components. 70 | 71 | Args: 72 | vector_store_config: Configuration for vector store setup 73 | chat_config: Configuration for chat and RAG behavior 74 | """ 75 | self.chat_config = chat_config 76 | 77 | # Initialize vector store internally 78 | self.vector_store = VectorStore.from_config( 79 | vector_store_config=vector_store_config, 80 | chat_config=chat_config 81 | ) 82 | 83 | # Initialize RAG pipeline with internal vector store 84 | self.rag_pipeline = RAGPipeline( 85 | vector_store=self.vector_store, 86 | chat_config=chat_config, 87 | ) 88 | 89 | @weave.op 90 | async def _aget_answer( 91 | self, question: str, chat_history: List[QuestionAnswer] 92 | ) -> RAGPipelineOutput: 93 | history = [] 94 | for item in chat_history: 95 | history.append(("user", item.question)) 96 | history.append(("assistant", item.answer)) 97 | result = await self.rag_pipeline.__acall__(question, history) 98 | return result 99 | 100 | @weave.op 101 | async def __acall__(self, chat_request: ChatRequest) -> ChatResponse: 102 | """Async method for chat interactions.""" 103 | original_language = chat_request.language 104 | api_call_statuses = {} 105 | 106 | # Initialize working request with original request 107 | working_request = chat_request 108 | 109 | with Timer() as timer: 110 | try: 111 | # Handle Japanese translation 112 | if original_language == "ja": 113 | try: 114 | translated_question = translate_ja_to_en( 115 | chat_request.question, 116 | self.chat_config.ja_translation_model_name 117 | ) 118 | working_request = ChatRequest( 119 | question=translated_question, 120 | chat_history=chat_request.chat_history, 121 | application=chat_request.application, 122 | language="en", 123 | ) 124 | except Exception as e: 125 | error_info = ErrorInfo( 126 | has_error=True, 127 | error_message=str(e), 128 | error_type=type(e).__name__, 129 | stacktrace=''.join(traceback.format_exc()), 130 | file_path=get_error_file_path(sys.exc_info()[2]), 131 | component="translation" 132 | ) 133 | api_call_statuses["chat_success"] = False 134 | api_call_statuses["chat_error_info"] = error_info.model_dump() 135 | # Create error response preserving translation error context 136 | return ChatResponse( 137 | system_prompt="", 138 | question=chat_request.question, # Original question 139 | answer=f"Translation error: {str(e)}", 140 | response_synthesis_llm_messages=[], 141 | model="", 142 | sources="", 143 | source_documents="", 144 | total_tokens=0, 145 | prompt_tokens=0, 146 | completion_tokens=0, 147 | time_taken=timer.elapsed, 148 | start_time=timer.start, 149 | end_time=timer.stop, 150 | application=chat_request.application, 151 | api_call_statuses=api_call_statuses 152 | ) 153 | 154 | # Get answer using working request 155 | result = await self._aget_answer( 156 | working_request.question, working_request.chat_history or [] 157 | ) 158 | 159 | result_dict = result.model_dump() 160 | api_call_statuses.update(result_dict.get("api_call_statuses", {})) 161 | 162 | # Handle Japanese translation of response 163 | if original_language == "ja": 164 | try: 165 | result_dict["answer"] = translate_en_to_ja( 166 | result_dict["answer"], 167 | self.chat_config.ja_translation_model_name 168 | ) 169 | except Exception as e: 170 | error_info = ErrorInfo( 171 | has_error=True, 172 | error_message=str(e), 
173 | error_type=type(e).__name__, 174 | stacktrace=''.join(traceback.format_exc()), 175 | file_path=get_error_file_path(sys.exc_info()[2]), 176 | component="translation" 177 | ) 178 | api_call_statuses["chat_success"] = False 179 | api_call_statuses["chat_error_info"] = error_info.model_dump() 180 | # Continue with the translation error noted, preserving the original answer 181 | result_dict["answer"] = f"Translation error: {str(e)}\nOriginal answer: {result_dict['answer']}" 182 | 183 | # Update with final metadata; setdefault preserves any translation error recorded above 184 | api_call_statuses.setdefault("chat_success", True) 185 | api_call_statuses.setdefault("chat_error_info", ErrorInfo( 186 | has_error=False, 187 | error_message="", 188 | error_type="", 189 | stacktrace="", 190 | file_path="", 191 | component="chat" 192 | ).model_dump()) 193 | result_dict.update({ 194 | "application": chat_request.application, 195 | "api_call_statuses": api_call_statuses, 196 | "time_taken": timer.elapsed, 197 | "start_time": timer.start, 198 | "end_time": timer.stop, 199 | }) 200 | 201 | return ChatResponse(**result_dict) 202 | 203 | except Exception as e: 204 | error_info = ErrorInfo( 205 | has_error=True, 206 | error_message=str(e), 207 | error_type=type(e).__name__, 208 | stacktrace=''.join(traceback.format_exc()), 209 | file_path=get_error_file_path(sys.exc_info()[2]), 210 | component="chat" 211 | ) 212 | api_call_statuses["chat_success"] = False 213 | api_call_statuses["chat_error_info"] = error_info.model_dump() 214 | 215 | return ChatResponse( 216 | system_prompt="", 217 | question=chat_request.question, 218 | answer=error_info.error_message, 219 | response_synthesis_llm_messages=[], 220 | model="", 221 | sources="", 222 | source_documents="", 223 | total_tokens=0, 224 | prompt_tokens=0, 225 | completion_tokens=0, 226 | time_taken=timer.elapsed, 227 | start_time=timer.start, 228 | end_time=timer.stop, 229 | application=chat_request.application, 230 | api_call_statuses=api_call_statuses 231 | ) 232 | 233 | @weave.op 234 | def __call__(self, chat_request: ChatRequest) -> ChatResponse: 235 | return run_sync(self.__acall__(chat_request)) 236 | -------------------------------------------------------------------------------- /src/wandbot/chat/rag.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Dict, List, Tuple 3 | 4 | import weave 5 | from pydantic import BaseModel 6 | 7 | from wandbot.configs.chat_config import ChatConfig 8 | from wandbot.rag.query_handler import QueryEnhancer 9 | from wandbot.rag.response_synthesis import ResponseSynthesizer 10 | from wandbot.rag.retrieval import FusionRetrievalEngine 11 | from wandbot.retriever import VectorStore 12 | from wandbot.utils import ErrorInfo, Timer, get_logger, run_sync 13 | 14 | logger = get_logger(__name__) 15 | chat_config = ChatConfig() 16 | 17 | def get_stats_dict_from_token_callback(token_callback): 18 | return { 19 | "total_tokens": token_callback.total_tokens, 20 | "prompt_tokens": token_callback.prompt_tokens, 21 | "completion_tokens": token_callback.completion_tokens, 22 | "successful_requests": token_callback.successful_requests, 23 | } 24 | 25 | 26 | def get_stats_dict_from_timer(timer): 27 | return { 28 | "start_time": timer.start, 29 | "end_time": timer.stop, 30 | "time_taken": timer.elapsed, 31 | } 32 | 33 | 34 | class RAGPipelineOutput(BaseModel): 35 | question: str 36 | answer: str 37 | sources: str 38 | source_documents: str 39 | system_prompt: str 40 | model: str 41 | total_tokens: int 42 | prompt_tokens: int 43 | completion_tokens:
int 44 | time_taken: float 45 | start_time: datetime.datetime 46 | end_time: datetime.datetime 47 | api_call_statuses: dict = {} 48 | response_synthesis_llm_messages: List[Dict[str, str]] | None = None 49 | 50 | 51 | class RAGPipeline: 52 | 53 | def __init__( 54 | self, 55 | vector_store: VectorStore, 56 | chat_config: ChatConfig, 57 | ): 58 | self.vector_store = vector_store 59 | self.query_enhancer = QueryEnhancer( 60 | model_provider = chat_config.query_enhancer_provider, 61 | model_name = chat_config.query_enhancer_model, 62 | temperature = chat_config.query_enhancer_temperature, 63 | fallback_model_provider = chat_config.query_enhancer_fallback_provider, 64 | fallback_model_name = chat_config.query_enhancer_fallback_model, 65 | fallback_temperature = chat_config.query_enhancer_fallback_temperature, 66 | max_retries=chat_config.llm_max_retries 67 | ) 68 | self.retrieval_engine = FusionRetrievalEngine( 69 | vector_store=vector_store, 70 | chat_config=chat_config, 71 | ) 72 | self.response_synthesizer = ResponseSynthesizer( 73 | primary_provider=chat_config.response_synthesizer_provider, 74 | primary_model_name=chat_config.response_synthesizer_model, 75 | primary_temperature=chat_config.response_synthesizer_temperature, 76 | fallback_provider=chat_config.response_synthesizer_fallback_provider, 77 | fallback_model_name=chat_config.response_synthesizer_fallback_model, 78 | fallback_temperature=chat_config.response_synthesizer_fallback_temperature, 79 | max_retries=chat_config.llm_max_retries 80 | ) 81 | 82 | async def __acall__( 83 | self, question: str, chat_history: List[Tuple[str, str]] | None = None 84 | ) -> RAGPipelineOutput: 85 | """ 86 | Async version of the RAG pipeline. 87 | 1) query enhancement 88 | 2) retrieval 89 | 3) response synthesis 90 | """ 91 | if chat_history is None: 92 | chat_history = [] 93 | 94 | enhanced_query = await self.query_enhancer({"query": question, "chat_history": chat_history}) 95 | 96 | # with Timer() as retrieval_tb: 97 | # If retrieval_engine is async, do: 98 | retrieval_result = await self.retrieval_engine.__acall__(enhanced_query) 99 | 100 | # with get_openai_callback() as response_cb, Timer() as response_tb: 101 | response = await self.response_synthesizer(retrieval_result) 102 | # or if it is truly async, do: 103 | # response = await self.response_synthesizer.__acall__(retrieval_results) 104 | 105 | # Build final output 106 | output = RAGPipelineOutput( 107 | question=enhanced_query["standalone_query"], 108 | answer=response["response"], 109 | sources="\n".join( 110 | [doc.metadata["source"] for doc in retrieval_result.documents] 111 | ), 112 | source_documents=response["context_str"], 113 | system_prompt=response["response_prompt"], 114 | model=response["response_model"], 115 | total_tokens=0, 116 | prompt_tokens=0, 117 | completion_tokens=0, 118 | time_taken=0, 119 | start_time=datetime.datetime.now(), 120 | end_time=datetime.datetime.now(), 121 | # total_tokens=query_enhancer_cb.total_tokens + response_cb.total_tokens, 122 | # prompt_tokens=query_enhancer_cb.prompt_tokens + response_cb.prompt_tokens, 123 | # completion_tokens=query_enhancer_cb.completion_tokens + response_cb.completion_tokens, 124 | # time_taken=query_enhancer_tb.elapsed + retrieval_tb.elapsed + response_tb.elapsed, 125 | # start_time=query_enhancer_tb.start, 126 | # end_time=response_tb.stop, 127 | api_call_statuses={ 128 | "web_search_success": retrieval_result.retrieval_info["api_statuses"]["web_search_api"].success, 129 | "reranker_api_error_info": 
retrieval_result.retrieval_info["api_statuses"]["reranker_api"].error_info, 130 | "reranker_api_success": retrieval_result.retrieval_info["api_statuses"]["reranker_api"].success, 131 | "query_enhancer_llm_api_error_info": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).error_info if enhanced_query.get("api_statuses") else None, 132 | "query_enhancer_llm_api_success": enhanced_query.get("api_statuses", {}).get("query_enhancer_llm_api", {}).success if enhanced_query.get("api_statuses") else False, 133 | "embedding_api_error_info": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].error_info, 134 | "embedding_api_success": retrieval_result.retrieval_info["api_statuses"]["embedding_api"].success, 135 | }, 136 | response_synthesis_llm_messages=response.get("response_synthesis_llm_messages") 137 | ) 138 | return output 139 | 140 | @weave.op 141 | def __call__( 142 | self, question: str, chat_history: List[Tuple[str, str]] | None = None 143 | ) -> RAGPipelineOutput: 144 | return run_sync(self.__acall__(question, chat_history)) -------------------------------------------------------------------------------- /src/wandbot/chat/schemas.py: -------------------------------------------------------------------------------- 1 | """This module defines the Pydantic models for the chat system. 2 | 3 | This module contains the Pydantic models that are used to validate the data 4 | for the chat system. It includes models for chat threads, chat requests, and 5 | chat responses. The models are used to ensure that the data sent to and received 6 | from the chat system is in the correct format. 7 | 8 | Typical usage example: 9 | 10 | chat_thread = ChatThread(thread_id="123", application="app1") 11 | chat_request = ChatRequest(question="What is the weather?", chat_history=None) 12 | chat_response = ChatResponse(system_prompt="Weather is sunny", question="What is the weather?", 13 | answer="It's sunny", model="model1", sources="source1", 14 | source_documents="doc1", total_tokens=10, prompt_tokens=2, 15 | completion_tokens=8, time_taken=1.0, 16 | start_time=datetime.now(), end_time=datetime.now()) 17 | """ 18 | 19 | from datetime import datetime 20 | from typing import Dict, List 21 | 22 | from pydantic import BaseModel 23 | 24 | from wandbot.database.schemas import QuestionAnswer 25 | 26 | 27 | class ChatThreadBase(BaseModel): 28 | question_answers: list[QuestionAnswer] | None = [] 29 | 30 | 31 | class ChatThreadCreate(ChatThreadBase): 32 | thread_id: str 33 | application: str 34 | 35 | class Config: 36 | use_enum_values = True 37 | 38 | 39 | class ChatThread(ChatThreadCreate): 40 | class Config: 41 | from_attributes = True 42 | 43 | 44 | class ChatRequest(BaseModel): 45 | question: str 46 | chat_history: List[QuestionAnswer] | None = None 47 | application: str | None = None 48 | language: str = "en" 49 | 50 | 51 | class ChatResponse(BaseModel): 52 | system_prompt: str 53 | question: str 54 | answer: str 55 | model: str 56 | sources: str 57 | source_documents: str 58 | total_tokens: int 59 | prompt_tokens: int 60 | completion_tokens: int 61 | time_taken: float 62 | start_time: datetime 63 | end_time: datetime 64 | api_call_statuses: dict = {} 65 | response_synthesis_llm_messages: List[Dict[str, str]] | None = None 66 | -------------------------------------------------------------------------------- /src/wandbot/chat/utils.py: -------------------------------------------------------------------------------- 1 | import weave 2 | from openai import OpenAI 3 | 4 | 
EN_TO_JA_SYSTEM_PROMPT = "You are a professional translator. \n\n\ 5 | Translate the user's text into Japanese according to the specified rules. \n\ 6 | Rules of translation: \n\ 7 | - Maintain the original nuance\n\ 8 | - Use 'run' in English where appropriate, as it's a term used in Wandb.\n\ 9 | - Translate the terms 'reference artifacts' and 'lineage' into Katakana. \n\ 10 | - Include specific terms in English or Katakana where appropriate\n\ 11 | - Keep code unchanged.\n\ 12 | - Only return the Japanese translation without any additional explanation" 13 | 14 | JA_TO_EN_SYSTEM_PROMPT = "You are a professional translator. \n\n\ 15 | Translate the user's question about Weights & Biases into English according to the specified rules. \n\ 16 | Rules of translation: \n\ 17 | - Maintain the original nuance\n\ 18 | - Keep code unchanged.\n\ 19 | - Only return the English translation without any additional explanation" 20 | 21 | 22 | @weave.op 23 | def translate_ja_to_en(text: str, model_name: str) -> str: 24 | """ 25 | Translates Japanese text to English using an OpenAI chat model. 26 | 27 | Args: 28 | text: The Japanese text to be translated. 29 | model_name: The name of the OpenAI model to use. 30 | Returns: 31 | The translated text in English. 32 | """ 33 | client = OpenAI() 34 | response = client.chat.completions.create( 35 | model=model_name, 36 | messages=[ 37 | { 38 | "role": "system", 39 | "content": JA_TO_EN_SYSTEM_PROMPT, 40 | }, 41 | {"role": "user", "content": text}, 42 | ], 43 | temperature=0, 44 | max_tokens=1000, 45 | top_p=1, 46 | ) 47 | return response.choices[0].message.content 48 | 49 | @weave.op 50 | def translate_en_to_ja(text: str, model_name: str) -> str: 51 | """ 52 | Translates English text to Japanese using an OpenAI chat model. 53 | 54 | Args: 55 | text: The English text to be translated. 56 | model_name: The name of the OpenAI model to use. 57 | Returns: 58 | The translated text in Japanese. 59 | """ 60 | client = OpenAI() 61 | response = client.chat.completions.create( 62 | model=model_name, 63 | messages=[ 64 | { 65 | "role": "system", 66 | "content": EN_TO_JA_SYSTEM_PROMPT 67 | }, 68 | {"role": "user", "content": text}, 69 | ], 70 | temperature=0, 71 | max_tokens=1000, 72 | top_p=1, 73 | ) 74 | return response.choices[0].message.content -------------------------------------------------------------------------------- /src/wandbot/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/configs/__init__.py -------------------------------------------------------------------------------- /src/wandbot/configs/app_config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | from pydantic_settings import BaseSettings, SettingsConfigDict 3 | 4 | 5 | class AppConfig(BaseSettings): 6 | model_config = SettingsConfigDict( 7 | env_prefix="", 8 | env_file=".env", 9 | env_file_encoding="utf-8", 10 | extra="ignore" 11 | ) 12 | 13 | wandb_project: str | None = Field("wandbot-dev") 14 | wandb_entity: str | None = Field("wandbot") 15 | log_level: str | None = Field("INFO") -------------------------------------------------------------------------------- /src/wandbot/configs/chat_config.py: -------------------------------------------------------------------------------- 1 | """This module contains the configuration settings for wandbot.
2 | 3 | The `ChatConfig` class in this module is used to define various settings for wandbot, such as the model name, 4 | maximum retries, fallback model name, chat temperature, chat prompt, index artifact, embeddings cache, verbosity, 5 | wandb project and entity, inclusion of sources, and query tokens threshold. These settings are used throughout the 6 | chatbot's operation to control its behavior. 7 | 8 | Typical usage example: 9 | 10 | from wandbot.configs.chat_config import ChatConfig 11 | config = ChatConfig() 12 | print(config.chat_model_name) 13 | """ 14 | 15 | from typing import Literal 16 | 17 | from pydantic_settings import BaseSettings 18 | 19 | 20 | class ChatConfig(BaseSettings): 21 | # Retrieval settings 22 | top_k: int = 15 23 | top_k_per_query: int = 15 24 | search_type: Literal["mmr", "similarity"] = "mmr" 25 | do_web_search: bool = False 26 | redundant_similarity_threshold: float = 0.95 # used to remove very similar retrieved documents 27 | 28 | # Retrieval settings: MMR settings 29 | fetch_k: int = 20 # Used in mmr retrieval. Typically set as top_k * 4 30 | mmr_lambda_mult: float = 0.5 # used in mmr retrieval 31 | 32 | # Reranker models 33 | rereanker_provider: str = "cohere" 34 | english_reranker_model: str = "rerank-v3.5" 35 | multilingual_reranker_model: str = "rerank-v3.5" 36 | 37 | # Query enhancer settings 38 | query_enhancer_provider: str = "google" 39 | query_enhancer_model: str = "gemini-2.0-flash-001" 40 | query_enhancer_temperature: float = 1.0 41 | query_enhancer_fallback_provider: str = "google" 42 | query_enhancer_fallback_model: str = "gemini-2.0-flash-001" 43 | query_enhancer_fallback_temperature: float = 1.0 44 | 45 | # Response synthesis model settings 46 | response_synthesizer_provider: str = "anthropic" 47 | response_synthesizer_model: str = "claude-3-7-sonnet-20250219" 48 | response_synthesizer_temperature: float = 0.1 49 | response_synthesizer_fallback_provider: str = "anthropic" 50 | response_synthesizer_fallback_model: str = "claude-3-7-sonnet-20250219" 51 | response_synthesizer_fallback_temperature: float = 0.1 52 | 53 | # Translation models settings 54 | ja_translation_model_name: str = "gpt-4o-2024-08-06" 55 | 56 | # LLM Model retry settings 57 | llm_max_retries: int = 3 58 | llm_retry_min_wait: float = 4 # minimum seconds to wait between retries 59 | llm_retry_max_wait: float = 60 # maximum seconds to wait between retries 60 | llm_retry_multiplier: float = 1 # multiplier for exponential backoff 61 | 62 | # Embedding Model retry settings 63 | embedding_max_retries: int = 3 64 | embedding_retry_min_wait: float = 4 65 | embedding_retry_max_wait: float = 60 66 | embedding_retry_multiplier: float = 1 67 | 68 | # Reranker retry settings 69 | reranker_max_retries: int = 5 70 | reranker_retry_min_wait: float = 2.0 71 | reranker_retry_max_wait: float = 180 72 | reranker_retry_multiplier: float = 2.5 73 | 74 | # Vector Store retry settings 75 | vector_store_max_retries: int = 3 76 | vector_store_retry_min_wait: float = 1.0 # Start with a short wait 77 | vector_store_retry_max_wait: float = 10.0 # Cap the wait time 78 | vector_store_retry_multiplier: float = 2.0 # Double wait time each retry 79 | -------------------------------------------------------------------------------- /src/wandbot/configs/database_config.py: -------------------------------------------------------------------------------- 1 | """This module provides a DataBaseConfig class for managing database configuration. 
2 | 3 | The DataBaseConfig class uses the BaseSettings class from pydantic_settings to define and manage the database 4 | configuration settings. It includes the SQLAlchemy database URL and connection arguments. 5 | 6 | Typical usage example: 7 | 8 | db_config = DataBaseConfig() 9 | database_url = db_config.SQLALCHEMY_DATABASE_URL 10 | connect_args = db_config.connect_args 11 | """ 12 | 13 | from typing import Any 14 | 15 | from pydantic import Field 16 | from pydantic_settings import BaseSettings, SettingsConfigDict 17 | 18 | 19 | class DataBaseConfig(BaseSettings): 20 | model_config = SettingsConfigDict( 21 | env_prefix="", 22 | env_file=".env", 23 | env_file_encoding="utf-8", 24 | extra="ignore" 25 | ) 26 | 27 | SQLALCHEMY_DATABASE_URL: str = Field( 28 | "sqlite:///./data/cache/app.db" 29 | ) 30 | connect_args: dict[str, Any] = Field({"check_same_thread": False}) 31 | -------------------------------------------------------------------------------- /src/wandbot/configs/vector_store_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | from typing import List, Literal, Optional 4 | 5 | from pydantic import Field, model_validator 6 | from pydantic_settings import BaseSettings, SettingsConfigDict 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class VectorStoreConfig(BaseSettings): 11 | model_config = SettingsConfigDict( 12 | env_prefix="", 13 | env_file=".env", 14 | env_file_encoding="utf-8", 15 | extra="ignore" 16 | ) 17 | 18 | # Vector Store 19 | vectordb_collection_name: str = "chroma_index-v54" #"vectorstore", vectorstore-chroma_index-v54 20 | vectordb_index_dir: pathlib.Path = Field( 21 | pathlib.Path("artifacts/vector_stores"), env="VECTORDB_INDEX_DIR" 22 | ) 23 | vectordb_index_artifact_url: str = "wandbot/wandbot-dev/chroma_index:v54" 24 | distance: str = "l2" # used in retrieval from vectordb 25 | distance_key: str = "hnsw:space" # used in retrieval from vectordb 26 | 27 | # ChromaDB Client Mode 28 | vector_store_mode: Literal["local", "hosted"] = "hosted" 29 | # Settings for hosted mode (using direct HttpClient parameters) 30 | vector_store_host: Optional[str] = "api.trychroma.com" # e.g., 'api.trychroma.com' 31 | vector_store_tenant: Optional[str] = '3c66fbfc-98ce-41ff-92ec-ef16e71c8c0a' # Tenant ID for hosted Chroma 32 | vector_store_database: Optional[str] = 'wandbot-prod' # Database name for hosted Chroma 33 | vector_store_api_key: Optional[str] = None # Pulled as env variable from .env file 34 | 35 | # Embeddings settings 36 | embeddings_provider:str = "openai" 37 | embeddings_model_name: str = "text-embedding-3-small" 38 | embeddings_dimensions: int = 512 # needed when using OpenAI embeddings 39 | 40 | # Embedding input types, e.g. 
"search_query" or "search_document" 41 | embeddings_query_input_type: str = "search_query" # needed when using Cohere embeddings 42 | embeddings_document_input_type: str = "search_document" # needed when using Cohere embeddings 43 | 44 | # Embedding encoding format 45 | embeddings_encoding_format: Literal["float", "base64"] = "base64" 46 | 47 | # Ingestion settings 48 | batch_size: int = 256 49 | persist_directory: Optional[pathlib.Path] = None 50 | 51 | # Remote ChromaDB Upload Configuration - Keys for transformation 52 | remote_chroma_keys_to_prepend: List[str] = Field( 53 | default_factory=lambda: [ 54 | "parent_id", 55 | "id", 56 | "source", 57 | "file_type", 58 | "has_code", 59 | "source_type", 60 | "tags", 61 | "description", 62 | ] 63 | ) 64 | remote_chroma_keys_to_remove: List[str] = Field( 65 | default_factory=lambda: [ 66 | "source_content", 67 | "file_type", 68 | "tags", 69 | "description", 70 | ] 71 | ) 72 | 73 | @model_validator(mode="after") 74 | def _adjust_paths_for_dimension(cls, values: "VectorStoreConfig") -> "VectorStoreConfig": 75 | """Adjusts index directory path based on embedding dimension.""" 76 | if values.vector_store_mode == "local": # Only adjust for local mode 77 | base_dir = values.vectordb_index_dir 78 | dimension = values.embeddings_dimensions 79 | # Ensure we don't append dimension multiple times if already present 80 | if not base_dir.name.endswith(f"_{dimension}"): 81 | values.vectordb_index_dir = base_dir.parent / f"{base_dir.name}_{dimension}" 82 | logger.info(f"Adjusted vectordb_index_dir for dimension {dimension}: {values.vectordb_index_dir}") 83 | return values -------------------------------------------------------------------------------- /src/wandbot/database/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/database/__init__.py -------------------------------------------------------------------------------- /src/wandbot/database/client.py: -------------------------------------------------------------------------------- 1 | """This module provides a Database and DatabaseClient class for managing database operations. 2 | 3 | The Database class provides a connection to the database and manages the session. It also provides methods for 4 | getting and setting the current session object and the name of the database. 5 | 6 | The DatabaseClient class uses an instance of the Database class to perform operations such as getting and creating 7 | chat threads, question answers, and feedback from the database. 
8 | 9 | Typical usage example: 10 | 11 | db_client = DatabaseClient() 12 | chat_thread = db_client.get_chat_thread(application='app1', thread_id='123') 13 | question_answer = db_client.create_question_answer(question_answer=QuestionAnswerCreateSchema()) 14 | """ 15 | 16 | import json 17 | from typing import Any, List 18 | 19 | from sqlalchemy.future import create_engine 20 | from sqlalchemy.orm import sessionmaker 21 | 22 | from wandbot.configs.database_config import DataBaseConfig 23 | from wandbot.database.models import ChatThread as ChatThreadModel 24 | from wandbot.database.models import FeedBack as FeedBackModel 25 | from wandbot.database.models import QuestionAnswer as QuestionAnswerModel 26 | from wandbot.database.schemas import ChatThreadCreate as ChatThreadCreateSchema 27 | from wandbot.database.schemas import Feedback as FeedbackSchema 28 | from wandbot.database.schemas import ( 29 | QuestionAnswerCreate as QuestionAnswerCreateSchema, 30 | ) 31 | from wandbot.utils import get_logger 32 | 33 | logger = get_logger(__name__) 34 | 35 | 36 | class Database: 37 | """A class representing a database connection. 38 | 39 | This class provides a connection to the database and manages the session. 40 | 41 | Attributes: 42 | db_config: An instance of the DataBaseConfig class. 43 | SessionLocal: A sessionmaker object for creating sessions. 44 | db: The current session object. 45 | name: The name of the database. 46 | """ 47 | 48 | db_config: DataBaseConfig = DataBaseConfig() 49 | 50 | def __init__(self, database: str | None = None): 51 | """Initializes the Database instance. 52 | 53 | Args: 54 | database: The URL of the database. If None, the default URL is used. 55 | """ 56 | if database is not None: 57 | engine: Any = create_engine( 58 | url=database, connect_args=self.db_config.connect_args 59 | ) 60 | else: 61 | engine: Any = create_engine( 62 | url=self.db_config.SQLALCHEMY_DATABASE_URL, 63 | connect_args=self.db_config.connect_args, 64 | ) 65 | self.SessionLocal: Any = sessionmaker( 66 | autocommit=False, autoflush=False, bind=engine 67 | ) 68 | 69 | def __get__(self, instance, owner) -> Any: 70 | """Gets the current session object. 71 | 72 | Args: 73 | instance: The instance of the owner class. 74 | owner: The owner class. 75 | 76 | Returns: 77 | The current session object. 78 | """ 79 | if not hasattr(self, "db"): 80 | self.db: Any = self.SessionLocal() 81 | return self.db 82 | 83 | def __set__(self, instance, value) -> None: 84 | """Sets the current session object. 85 | 86 | Args: 87 | instance: The instance of the owner class. 88 | value: The new session object. 89 | """ 90 | self.db = value 91 | 92 | def __set_name__(self, owner, name) -> None: 93 | """Sets the name of the database. 94 | 95 | Args: 96 | owner: The owner class. 97 | name: The name of the database. 98 | """ 99 | self.name: str = name 100 | 101 | 102 | class DatabaseClient: 103 | database: Database = Database() 104 | 105 | def __init__(self, database: str | None = None): 106 | """Initializes the DatabaseClient instance. 107 | 108 | Args: 109 | database: The URL of the database. If None, the default URL is used. 110 | """ 111 | if database is not None: 112 | self.database = Database(database=database) 113 | 114 | def get_chat_thread( 115 | self, application: str, thread_id: str 116 | ) -> ChatThreadModel | None: 117 | """Gets a chat thread from the database. 118 | 119 | Args: 120 | application: The application name. 121 | thread_id: The ID of the chat thread. 
122 | 123 | Returns: 124 | The chat thread model if found, None otherwise. 125 | """ 126 | chat_thread: ChatThreadModel | None = ( 127 | self.database.query(ChatThreadModel) 128 | .filter( 129 | ChatThreadModel.thread_id == thread_id, 130 | ChatThreadModel.application == application, 131 | ) 132 | .first() 133 | ) 134 | return chat_thread 135 | 136 | def create_chat_thread( 137 | self, chat_thread: ChatThreadCreateSchema 138 | ) -> ChatThreadModel: 139 | """Creates a chat thread in the database. 140 | 141 | Args: 142 | chat_thread: The chat thread to create. 143 | 144 | Returns: 145 | The created chat thread model. 146 | """ 147 | try: 148 | chat_thread: ChatThreadModel = ChatThreadModel( 149 | thread_id=chat_thread.thread_id, 150 | application=chat_thread.application, 151 | ) 152 | self.database.add(chat_thread) 153 | self.database.flush() 154 | self.database.commit() 155 | self.database.refresh(chat_thread) 156 | 157 | except Exception as e: 158 | logger.error(f"Create chat thread failed with error: {e}") 159 | self.database.rollback() 160 | 161 | return chat_thread 162 | 163 | def get_question_answer( 164 | self, question_answer_id: str, thread_id: str 165 | ) -> QuestionAnswerModel | None: 166 | """Gets a question answer from the database. 167 | 168 | Args: 169 | question_answer_id: The ID of the question answer. 170 | thread_id: The ID of the chat thread. 171 | 172 | Returns: 173 | The question answer model if found, None otherwise. 174 | """ 175 | question_answer: QuestionAnswerModel | None = ( 176 | self.database.query(QuestionAnswerModel) 177 | .filter( 178 | QuestionAnswerModel.thread_id == thread_id, 179 | QuestionAnswerModel.question_answer_id == question_answer_id, 180 | ) 181 | .first() 182 | ) 183 | return question_answer 184 | 185 | def create_question_answer( 186 | self, question_answer: QuestionAnswerCreateSchema 187 | ) -> QuestionAnswerModel: 188 | """Creates a question answer in the database. 189 | 190 | Args: 191 | question_answer: The question answer to create. 192 | 193 | Returns: 194 | The created question answer model. 195 | """ 196 | try: 197 | question_answer: QuestionAnswerModel = QuestionAnswerModel( 198 | **question_answer.model_dump() 199 | ) 200 | self.database.add(question_answer) 201 | self.database.flush() 202 | self.database.commit() 203 | self.database.refresh(question_answer) 204 | except Exception as e: 205 | logger.error(f"Create question answer failed with error: {e}") 206 | self.database.rollback() 207 | return question_answer 208 | 209 | def get_feedback(self, question_answer_id: str) -> FeedBackModel | None: 210 | """Gets feedback from the database. 211 | 212 | Args: 213 | question_answer_id: The ID of the question answer. 214 | 215 | Returns: 216 | The feedback model if found, None otherwise. 217 | """ 218 | feedback: FeedBackModel | None = ( 219 | self.database.query(FeedBackModel) 220 | .filter(FeedBackModel.question_answer_id == question_answer_id) 221 | .first() 222 | ) 223 | return feedback 224 | 225 | def create_feedback(self, feedback: FeedbackSchema) -> FeedBackModel: 226 | """Creates feedback in the database. 227 | 228 | Args: 229 | feedback: The feedback to create. 230 | 231 | Returns: 232 | The created feedback model. 
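Example (illustrative; real callers typically pass a FeedbackCreate
carrying feedback_id and question_answer_id for an existing answer):

    db_client.create_feedback(FeedbackSchema(rating=1))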
233 | """ 234 | if feedback.rating: 235 | try: 236 | feedback: FeedBackModel = FeedBackModel(**feedback.model_dump()) 237 | self.database.add(feedback) 238 | self.database.flush() 239 | self.database.commit() 240 | self.database.refresh(feedback) 241 | except Exception as e: 242 | logger.error(f"Create feedback failed with error: {e}") 243 | self.database.rollback() 244 | 245 | return feedback 246 | 247 | def get_all_question_answers( 248 | self, time: Any = None 249 | ) -> List[dict[str, Any]] | None: 250 | """Gets all question answers from the database. 251 | 252 | Args: 253 | time: The time to filter the question answers by. 254 | 255 | Returns: 256 | A list of question answer dictionaries if found, None otherwise. 257 | """ 258 | question_answers = self.database.query(QuestionAnswerModel) 259 | if time is not None: 260 | question_answers = question_answers.filter( 261 | QuestionAnswerModel.end_time >= time 262 | ) 263 | question_answers = question_answers.all() 264 | if question_answers is not None: 265 | question_answers = [ 266 | json.loads( 267 | QuestionAnswerCreateSchema.model_validate( 268 | question_answer 269 | ).model_dump_json() 270 | ) 271 | for question_answer in question_answers 272 | ] 273 | return question_answers 274 | -------------------------------------------------------------------------------- /src/wandbot/database/config.py: -------------------------------------------------------------------------------- 1 | # """This module provides a DataBaseConfig class for managing database configuration. 2 | 3 | # The DataBaseConfig class uses the BaseSettings class from pydantic_settings to define and manage the database 4 | # configuration settings. It includes the SQLAlchemy database URL and connection arguments. 5 | 6 | # Typical usage example: 7 | 8 | # db_config = DataBaseConfig() 9 | # database_url = db_config.SQLALCHEMY_DATABASE_URL 10 | # connect_args = db_config.connect_args 11 | # """ 12 | 13 | # from typing import Any 14 | 15 | # from pydantic import Field 16 | # from pydantic_settings import BaseSettings 17 | 18 | 19 | # class DataBaseConfig(BaseSettings): 20 | # SQLALCHEMY_DATABASE_URL: str = Field( 21 | # "sqlite:///./data/cache/app.db", env="SQLALCHEMY_DATABASE_URL" 22 | # ) 23 | # connect_args: dict[str, Any] = Field({"check_same_thread": False}) 24 | -------------------------------------------------------------------------------- /src/wandbot/database/database.py: -------------------------------------------------------------------------------- 1 | """This module provides the setup for the SQLAlchemy database engine and session. 2 | 3 | It imports the create_engine and sessionmaker modules from SQLAlchemy, and the DataBaseConfig class from the config 4 | module. It then creates an instance of DataBaseConfig, sets up the engine with the SQLAlchemy database URL and 5 | connection arguments, and creates a sessionmaker bound to this engine. 
6 | 7 | Typical usage example: 8 | 9 | from wandbot.database.database import SessionLocal 10 | session = SessionLocal() 11 | """ 12 | 13 | from pathlib import Path 14 | 15 | from sqlalchemy import create_engine 16 | from sqlalchemy.orm import sessionmaker 17 | 18 | from wandbot.configs.database_config import DataBaseConfig 19 | 20 | db_config = DataBaseConfig() 21 | 22 | # Ensure the directory for the SQLite database exists 23 | db_url = db_config.SQLALCHEMY_DATABASE_URL 24 | if db_url.startswith("sqlite:///"): 25 | db_path_str = db_url.split("sqlite:///", 1)[1] 26 | db_path = Path(db_path_str) 27 | db_path.parent.mkdir(parents=True, exist_ok=True) 28 | 29 | engine = create_engine( 30 | db_url, connect_args=db_config.connect_args 31 | ) 32 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 33 | -------------------------------------------------------------------------------- /src/wandbot/database/models.py: -------------------------------------------------------------------------------- 1 | """This module defines the SQLAlchemy models for the ChatThread, QuestionAnswer, and FeedBack tables. 2 | 3 | Each class represents a table in the database and includes columns and relationships. The Base class is a declarative 4 | base that stores a catalog of classes and mapped tables in the Declarative system. 5 | 6 | Typical usage example: 7 | 8 | from wandbot.database.models import ChatThread, QuestionAnswer, FeedBack 9 | chat_thread = ChatThread(thread_id='123', application='app1') 10 | question_answer = QuestionAnswer(question_answer_id='456', thread_id='123') 11 | feedback = FeedBack(feedback_id='789', question_answer_id='456') 12 | """ 13 | 14 | from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String 15 | from sqlalchemy.ext.declarative import declarative_base 16 | from sqlalchemy.orm import relationship 17 | 18 | Base = declarative_base() 19 | 20 | 21 | class ChatThread(Base): 22 | __tablename__ = "chat_thread" 23 | 24 | thread_id = Column(String, primary_key=True, index=True) 25 | application = Column(String) 26 | question_answers = relationship( 27 | "QuestionAnswer", back_populates="chat_thread" 28 | ) 29 | 30 | 31 | class QuestionAnswer(Base): 32 | __tablename__ = "question_answers" 33 | thread_id = Column(String, ForeignKey("chat_thread.thread_id")) 34 | question_answer_id = Column(String, primary_key=True, index=True) 35 | system_prompt = Column(String) 36 | question = Column(String) 37 | answer = Column(String) 38 | model = Column(String) 39 | sources = Column(String) 40 | source_documents = Column(String) 41 | start_time = Column(DateTime) 42 | end_time = Column(DateTime) 43 | time_taken = Column(Float) 44 | total_tokens = Column(Integer) 45 | prompt_tokens = Column(Integer) 46 | completion_tokens = Column(Integer) 47 | successful_requests = Column(Integer) 48 | total_cost = Column(Float) 49 | chat_thread = relationship("ChatThread", back_populates="question_answers") 50 | feedback = relationship("FeedBack", back_populates="question_answer") 51 | language = Column(String) 52 | 53 | 54 | class FeedBack(Base): 55 | __tablename__ = "feedback" 56 | 57 | feedback_id = Column(String, primary_key=True, index=True) 58 | question_answer_id = Column( 59 | String, ForeignKey("question_answers.question_answer_id") 60 | ) 61 | rating = Column(Integer) 62 | question_answer = relationship("QuestionAnswer", back_populates="feedback") 63 | -------------------------------------------------------------------------------- /src/wandbot/database/schemas.py: 
-------------------------------------------------------------------------------- 1 | """This module defines the Pydantic models for the chat system. 2 | 3 | This module contains the Pydantic models that are used to validate the data 4 | for the chat system. It includes models for chat threads, chat requests, and 5 | chat responses. The models are used to ensure that the data sent to and received 6 | from the chat system is in the correct format. 7 | 8 | Typical usage example: 9 | 10 | chat_thread = ChatThread(thread_id="123", application="app1") 11 | chat_request = ChatRequest(question="What is the weather?", chat_history=None) 12 | chat_response = ChatResponse(system_prompt="Weather is sunny", question="What is the weather?", 13 | answer="It's sunny", model="model1", sources="source1", 14 | source_documents="doc1", total_tokens=10, prompt_tokens=2, 15 | completion_tokens=8, time_taken=1.0, 16 | start_time=datetime.now(), end_time=datetime.now()) 17 | """ 18 | 19 | from datetime import datetime 20 | from enum import IntEnum 21 | 22 | from pydantic import BaseModel, ConfigDict 23 | 24 | 25 | class Rating(IntEnum): 26 | positive = 1 27 | negative = -1 28 | neutral = 0 29 | 30 | 31 | class FeedbackBase(BaseModel): 32 | rating: Rating | None = None 33 | 34 | 35 | class Feedback(FeedbackBase): 36 | model_config = ConfigDict(from_attributes=True, use_enum_values=True) 37 | 38 | 39 | class FeedbackCreate(Feedback): 40 | feedback_id: str 41 | question_answer_id: str 42 | 43 | 44 | class QuestionAnswerBase(BaseModel): 45 | system_prompt: str | None = None 46 | question: str 47 | answer: str | None = None 48 | model: str | None = None 49 | sources: str | None = None 50 | source_documents: str | None = None 51 | total_tokens: int | None = None 52 | prompt_tokens: int | None = None 53 | completion_tokens: int | None = None 54 | time_taken: float | None = None 55 | start_time: datetime | None = None 56 | end_time: datetime | None = None 57 | feedback: list[Feedback] | None = [] 58 | language: str | None = None 59 | 60 | 61 | class QuestionAnswer(QuestionAnswerBase): 62 | model_config = ConfigDict(from_attributes=True, use_enum_values=True) 63 | 64 | 65 | class QuestionAnswerCreate(QuestionAnswer): 66 | question_answer_id: str 67 | thread_id: str 68 | 69 | 70 | class ChatThreadBase(BaseModel): 71 | question_answers: list[QuestionAnswer] | None = [] 72 | 73 | 74 | class ChatThread(ChatThreadBase): 75 | application: str 76 | model_config = ConfigDict(from_attributes=True, use_enum_values=True) 77 | 78 | 79 | class ChatThreadCreate(ChatThread): 80 | thread_id: str 81 | -------------------------------------------------------------------------------- /src/wandbot/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | ## Overview 4 | 5 | The following W&B reports provide an overview of the evaluation process of wandbot: 6 | 7 | 8 | - [How to evaluate an LLM Part 1: Building an Evaluation Dataset for our LLM System](https://wandb.ai/wandbot/wandbot-eval/reports/How-to-evaluate-an-LLM-Part-1-Building-an-Evaluation-Dataset-for-our-LLM-System--Vmlldzo1NTAwNTcy) 9 | - [How to evaluate an LLM Part 2: Manual Evaluation of our LLM System](https://wandb.ai/wandbot/wandbot-eval/reports/How-to-evaluate-an-LLM-Part-2-Manual-Evaluation-of-our-LLM-System--Vmlldzo1NzU4NTM3) 10 | - [How to evaluate an LLM Part 3: Auto-Evaluation; LLMs evaluating 
LLMs](https://wandb.ai/wandbot/wandbot-eval/reports/How-to-evaluate-an-LLM-Part-3-Auto-Evaluation-LLMs-evaluating-LLMs--Vmlldzo1NzEzMDcz) -------------------------------------------------------------------------------- /src/wandbot/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/evaluation/__init__.py -------------------------------------------------------------------------------- /src/wandbot/evaluation/eval_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import simple_parsing as sp 5 | 6 | 7 | @dataclass 8 | class EvalConfig: 9 | # language for eval dataset to use (en or ja) 10 | lang: Literal["en", "ja"] = "en" 11 | eval_judge_provider: Literal["anthropic", "openai"] = "openai" 12 | eval_judge_model: str = "gpt-4o-2024-11-20" 13 | eval_judge_temperature: float = 0.1 14 | experiment_name: str = "wandbot-eval" 15 | evaluation_name: str = "wandbot-eval" 16 | n_trials: int = 3 17 | n_weave_parallelism: int = 10 18 | wandbot_url: str = "http://0.0.0.0:8000" 19 | wandb_entity: str = "wandbot" 20 | wandb_project: str = "wandbot-eval" 21 | debug: bool = False 22 | n_debug_samples: int = 3 23 | max_evaluator_retries: int = 3 24 | evaluator_timeout: int = 60 25 | 26 | @property 27 | def eval_dataset(self) -> str: 28 | if self.lang == "ja": 29 | return "weave:///wandbot/wandbot-eval-jp/object/wandbot_eval_data_jp:oCWifIAtEVCkSjushP0bOEc5GnhsMUYXURwQznBeKLA" 30 | return "weave:///wandbot/wandbot-eval/object/wandbot_eval_data:eCQQ0GjM077wi4ykTWYhLPRpuGIaXbMwUGEB7IyHlFU" 31 | 32 | def get_eval_config() -> EvalConfig: 33 | return sp.parse(EvalConfig) 34 | -------------------------------------------------------------------------------- /src/wandbot/evaluation/eval_metrics/correctness.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | import regex as re 4 | import weave 5 | from pydantic import BaseModel, Field 6 | 7 | from wandbot.models.llm import LLMError, LLMModel 8 | 9 | SYSTEM_TEMPLATE = """You are a Weights & Biases support expert tasked with evaluating the correctness of answers to questions asked by users to a technical support chatbot. 10 | 11 | You are given the following information: 12 | - a user query, 13 | - the documentation used to generate the answer 14 | - a reference answer 15 | - the reason why the reference answer is correct, and 16 | - a generated answer. 17 | 18 | Your job is to judge the relevance and correctness of the generated answer. 19 | - Consider whether the answer addresses all aspects of the question. 20 | - The generated answer must provide only correct information according to the documentation. 21 | - Compare the generated answer to the reference answer for completeness and correctness. 22 | - Output a score and a decision that represents a holistic evaluation of the generated answer. 23 | - You must return your response only in the format specified below. Do not return answers in any other format. 24 | 25 | Follow these guidelines for scoring: 26 | - Your score has to be between 1 and 3, where 1 is the worst and 3 is the best. 27 | - If the generated answer is not correct in comparison to the reference, you should give a score of 1.
28 | - If the generated answer is correct in comparison to the reference but contains mistakes, you should give a score of 2. 29 | - If the generated answer is correct in comparison to the reference and completely answers the user's query, you should give a score of 3. 30 | 31 | CRITICAL: You must output ONLY a JSON object. No text before or after. No explanations. No notes. Just the JSON object in exactly this format: 32 | { 33 | "reason": <>, 34 | "score": <>, 35 | "decision": <> 36 | } 37 | 38 | Example Response 1: 39 | { 40 | "reason": "The generated answer has the same details as the reference answer and completely answers the user's query.", 41 | "score": 3, 42 | "decision": "correct" 43 | } 44 | 45 | Example Response 2: 46 | { 47 | "reason": "The generated answer doesn't match the reference answer, and deviates from the documentation provided", 48 | "score": 1, 49 | "decision": "incorrect" 50 | } 51 | 52 | Example Response 3: 53 | { 54 | "reason": "The generated answer follows the same steps as the reference answer. However, it includes assumptions about methods that are not mentioned in the documentation.", 55 | "score": 2, 56 | "decision": "incorrect" 57 | }""" 58 | 59 | USER_TEMPLATE = """ 60 | ## User Query 61 | {query} 62 | 63 | ## Documentation 64 | {context_str} 65 | 66 | ## Reference Answer 67 | {reference_answer} 68 | 69 | ## Reference Correctness Reason 70 | {reference_notes} 71 | 72 | ## Generated Answer 73 | {generated_answer} 74 | """ 75 | 76 | 77 | class CorrectnessEvaluationModel(BaseModel): 78 | reason: str = Field(..., description="Provide a brief explanation for your decision here") 79 | score: float = Field(..., description="Provide a score as per the above guidelines") 80 | decision: str = Field(..., description="Provide your final decision here, either 'correct', or 'incorrect'") 81 | 82 | 83 | class CorrectnessEvaluationResult(CorrectnessEvaluationModel): 84 | answer_correct: bool 85 | has_error: bool 86 | error_message: str 87 | 88 | 89 | class WandBotCorrectnessEvaluator: 90 | """Evaluates the correctness of a question answering system. 91 | 92 | This evaluator depends on a reference answer being provided, in addition to the 93 | query string and response string. It outputs a score between 1 and 3, where 1 94 | is the worst and 3 is the best, along with a reason for the score. 95 | """ 96 | 97 | def __init__( 98 | self, 99 | model_name: str = "gpt-4-1106-preview", 100 | provider: str = "openai", 101 | temperature: float = 0.1, 102 | system_template: Optional[str] = SYSTEM_TEMPLATE, 103 | user_template: Optional[str] = USER_TEMPLATE, 104 | max_concurrent_requests: int = 20, 105 | **kwargs 106 | ): 107 | """Initialize the evaluator.
108 | 109 | Args: 110 | model_name: Name of the model to use 111 | provider: Provider of the model (e.g., "openai" or "anthropic") 112 | temperature: Temperature for model sampling 113 | system_template: Optional custom system template to use for evaluation (user_template analogously overrides the user prompt) 114 | max_concurrent_requests: Maximum number of concurrent requests 115 | **kwargs: Additional keyword arguments for LLMModel 116 | """ 117 | self.llm = LLMModel( 118 | provider=provider, 119 | model_name=model_name, 120 | temperature=temperature, 121 | response_model=CorrectnessEvaluationModel, 122 | n_parallel_api_calls=max_concurrent_requests, 123 | **kwargs 124 | ) 125 | self.system_template = system_template 126 | self.user_template = user_template 127 | 128 | async def _get_completion(self, system_prompt: str, user_prompt: str) -> CorrectnessEvaluationModel | LLMError: 129 | """Call the LLM; all other parameters are set in the LLMModel init.""" 130 | response, api_status = await self.llm.create( 131 | messages=[ 132 | {"role": "system", "content": system_prompt}, 133 | {"role": "user", "content": user_prompt} 134 | ], 135 | ) 136 | if not api_status.success: 137 | return LLMError(error=True, error_message=api_status.error_info.error_message) 138 | return response 139 | 140 | @weave.op 141 | async def aevaluate( 142 | self, 143 | query: str, 144 | response: str, 145 | contexts: List[str], 146 | reference: str, 147 | **kwargs: Any, 148 | ) -> CorrectnessEvaluationResult: 149 | """Evaluate the correctness of a response. 150 | 151 | Args: 152 | query: The user's question 153 | response: The generated answer to evaluate 154 | contexts: List of context documents used 155 | reference: The reference answer to compare against 156 | **kwargs: Additional arguments (e.g., reference_notes) 157 | 158 | Returns: 159 | CorrectnessEvaluationResult containing the evaluation details 160 | """ 161 | 162 | try: 163 | if query is None or response is None or reference is None: 164 | raise ValueError("query, response, and reference must be provided") 165 | 166 | judge_prompt = self.user_template.format( 167 | query=query, 168 | generated_answer=response, 169 | reference_answer=reference, 170 | context_str=re.sub( 171 | "\n+", "\n", "\n---\n".join(contexts) if contexts else "" 172 | ), 173 | reference_notes=kwargs.get("reference_notes", ""), 174 | ) 175 | 176 | # Run evaluation 177 | eval_response = await self._get_completion(system_prompt=self.system_template, user_prompt=judge_prompt) 178 | 179 | if isinstance(eval_response, LLMError): 180 | return CorrectnessEvaluationResult( 181 | answer_correct=False, 182 | score=1.0, 183 | reason=f"Evaluation failed due to an LLM error: {eval_response.error_message}", 184 | decision="incorrect", 185 | has_error=True, 186 | error_message=eval_response.error_message 187 | ) 188 | 189 | decision = eval_response.decision 190 | answer_correct = decision.lower() == "correct" 191 | score = eval_response.score 192 | reason = eval_response.reason 193 | return CorrectnessEvaluationResult( 194 | answer_correct=answer_correct, 195 | score=score, 196 | reason=reason, 197 | decision=decision, 198 | has_error=False, 199 | error_message="" 200 | ) 201 | except Exception as e: 202 | error_msg = f"Error during evaluation: {str(e)}" 203 | return CorrectnessEvaluationResult( 204 | answer_correct=False, 205 | score=1.0, # Lowest score since evaluation failed 206 | reason=error_msg, 207 | decision="incorrect", 208 | has_error=True, 209 | error_message=error_msg 210 | ) 211 | --------------------------------------------------------------------------------
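For readers skimming this file, here is a minimal, hypothetical sketch of driving `WandBotCorrectnessEvaluator` end to end. The judge model, credentials, and every query/answer string below are illustrative assumptions, not values taken from this repository:

```python
import asyncio

from wandbot.evaluation.eval_metrics.correctness import WandBotCorrectnessEvaluator


async def main():
    # Hypothetical judge setup; mirrors the eval_config defaults above, but
    # any provider/model supported by LLMModel should work in principle.
    evaluator = WandBotCorrectnessEvaluator(
        model_name="gpt-4o-2024-11-20",
        provider="openai",
        temperature=0.1,
    )
    result = await evaluator.aevaluate(
        query="How do I log a metric with wandb?",  # made-up user query
        response="Call wandb.log({'loss': 0.1}) inside your training loop.",
        contexts=["wandb.log() records metrics to the current run."],
        reference="Use wandb.log({'loss': value}) to log metrics.",
        reference_notes="The reference shows the canonical logging call.",
    )
    # result is a CorrectnessEvaluationResult; has_error flags judge failures.
    print(result.decision, result.score, result.reason)


asyncio.run(main())
```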
/src/wandbot/evaluation/utils/jp_evaluation_dataprep.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, List 4 | 5 | import requests 6 | import weave 7 | from tqdm import tqdm 8 | 9 | dataset_ref = weave.ref( 10 | "weave:///wandbot/wandbot-eval/object/wandbot_eval_data:eCQQ0GjM077wi4ykTWYhLPRpuGIaXbMwUGEB7IyHlFU" 11 | ).get() 12 | question_rows = dataset_ref.rows 13 | question_rows = [ 14 | { 15 | "question": row["question"], 16 | "ground_truth": row["answer"], 17 | "notes": row["notes"], 18 | "context": row["context"], 19 | "correctness": row["correctness"], 20 | "is_wandb_query": row["is_wandb_query"], 21 | } 22 | for row in question_rows 23 | ] 24 | 25 | 26 | def translate_with_openai(text: str) -> str: 27 | # Get the OpenAI API key from environment variables 28 | api_key = os.environ.get("OPENAI_API_KEY") 29 | if not api_key: 30 | raise ValueError("OPENAI_API_KEY environment variable is not set") 31 | 32 | # Set headers for the OpenAI API request 33 | headers = { 34 | "Content-Type": "application/json", 35 | "Authorization": f"Bearer {api_key}", 36 | } 37 | 38 | # Data payload for the OpenAI chat completions API request 39 | 40 | # It might be worth adding an instruction telling the model not to answer the question itself 41 | 42 | data = { 43 | "model": "gpt-4o-2024-08-06", 44 | "max_tokens": 4000, 45 | "messages": [ 46 | { 47 | "role": "system", 48 | "content": "You are a professional translator. \n\n\ 49 | Translate the user's text into Japanese according to the specified rules. \n\ 50 | Rules of translation: \n\ 51 | - Maintain the original nuance\n\ 52 | - Use 'run' in English where appropriate, as it's a term used in Wandb.\n\ 53 | - Translate the terms 'reference artifacts' and 'lineage' into Katakana.
\n\ 54 | - Include specific terms in English or Katakana where appropriate\n\ 55 | - Keep code unchanged.\n\ 56 | - Keep URL starting from 'Source:\thttps:', but translate texts after 'Source:\thttps:'\n\ 57 | - Only return the Japanese translation without any additional explanation", 58 | }, 59 | {"role": "user", "content": text}, 60 | ], 61 | } 62 | 63 | # Make the API request to OpenAI 64 | response = requests.post( 65 | "https://api.openai.com/v1/chat/completions", headers=headers, json=data 66 | ) 67 | 68 | # Check if the request was successful 69 | if response.status_code == 200: 70 | # Return the translated text 71 | return response.json()["choices"][0]["message"]["content"].strip() 72 | else: 73 | raise Exception( 74 | f"API request failed with status code {response.status_code}: {response.text}" 75 | ) 76 | 77 | 78 | def translate_data(data: List[Dict[str, Any]], output_file: str) -> None: 79 | total_items = len(data) 80 | 81 | # Check if the file exists and get the last processed index 82 | if os.path.exists(output_file): 83 | with open(output_file, "r", encoding="utf-8") as file: 84 | processed_data = json.load(file) 85 | start_index = len(processed_data) 86 | else: 87 | processed_data = [] 88 | start_index = 0 89 | 90 | for i in tqdm( 91 | range(start_index, total_items), initial=start_index, total=total_items 92 | ): 93 | item = data[i] 94 | translated_item = item.copy() 95 | for key in ["question", "ground_truth", "notes", "context"]: 96 | if key in item: 97 | translated_item[key] = translate_with_openai(item[key]) 98 | 99 | processed_data.append(translated_item) 100 | 101 | # Save progress after each item 102 | with open(output_file, "w", encoding="utf-8") as file: 103 | json.dump(processed_data, file, ensure_ascii=False, indent=2) 104 | 105 | print(f"Translation completed. Results saved in '{output_file}'") 106 | 107 | 108 | output_file = "translated_data.json" 109 | translate_data(question_rows, output_file) 110 | -------------------------------------------------------------------------------- /src/wandbot/evaluation/utils/jp_evaluation_dataupload.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import weave 4 | from weave import Dataset 5 | 6 | 7 | def rename_key(item): 8 | if "ground_truth" in item: 9 | item["answer"] = item.pop("ground_truth") 10 | return item 11 | 12 | 13 | def create_test_file(json_file_path, test_file_path, num_lines=5): 14 | with open(json_file_path, "r") as file: 15 | data = json.load(file) 16 | 17 | test_data = data[:num_lines] 18 | 19 | with open(test_file_path, "w") as file: 20 | json.dump(test_data, file, indent=2, ensure_ascii=False) 21 | 22 | print( 23 | f"Test file with {num_lines} lines has been created at {test_file_path}" 24 | ) 25 | 26 | 27 | def publish_json_to_weave(json_file_path, dataset_name, project_name): 28 | # Initialize Weave 29 | weave.init(project_name) 30 | 31 | # Read JSON file 32 | with open(json_file_path, "r") as file: 33 | data = json.load(file) 34 | 35 | # Rename 'ground_truth' to 'answer' in each item 36 | processed_data = [rename_key(item) for item in data] 37 | 38 | # Create a dataset 39 | dataset = Dataset(name=dataset_name, rows=processed_data) 40 | 41 | # Publish the dataset 42 | weave.publish(dataset) 43 | 44 | print( 45 | f"Dataset '{dataset_name}' has been published to project '{project_name}'." 
46 | ) 47 | 48 | 49 | # Usage example 50 | json_file_path = "translated_data.json" 51 | test_file_path = "test_translated_data.json" 52 | dataset_name = "wandbot_eval_data_jp" 53 | project_name = "wandbot/wandbot-eval-jp" 54 | 55 | # Create test file 56 | # create_test_file(json_file_path, test_file_path) 57 | 58 | # Publish full dataset to Weave 59 | publish_json_to_weave(json_file_path, dataset_name, project_name) 60 | -------------------------------------------------------------------------------- /src/wandbot/evaluation/utils/log_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["WANDB_ENTITY"] = "wandbot" 4 | 5 | import pandas as pd 6 | import weave 7 | from weave import Dataset 8 | 9 | import wandb 10 | from wandbot.evaluation.eval_config import EvalConfig 11 | 12 | config = EvalConfig() 13 | 14 | wandb_project = config.wandb_project 15 | wandb_entity = config.wandb_entity 16 | 17 | eval_artifact = wandb.Api().artifact(config.eval_artifact) 18 | eval_artifact_dir = eval_artifact.download(root=config.eval_artifact_root) 19 | 20 | df = pd.read_json( 21 | f"{eval_artifact_dir}/{config.eval_annotations_file}", 22 | lines=True, 23 | orient="records", 24 | ) 25 | df.insert(0, "id", df.index) 26 | 27 | correct_df = df[ 28 | (df["is_wandb_query"] == "YES") & (df["correctness"] == "correct") 29 | ] 30 | 31 | data_rows = correct_df.to_dict("records") 32 | 33 | weave.init(wandb_project) 34 | 35 | # Create a dataset 36 | dataset = Dataset( 37 | name="wandbot_eval_data", 38 | rows=data_rows, 39 | ) 40 | 41 | # Publish the dataset 42 | weave.publish(dataset) 43 | -------------------------------------------------------------------------------- /src/wandbot/evaluation/utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Optional, Tuple 4 | 5 | from pydantic import BaseModel 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class EvaluationResult(BaseModel): 10 | """Result of an evaluation.""" 11 | query: str 12 | response: str 13 | reasoning: Optional[str] = None 14 | score: Optional[float] = None 15 | passing: Optional[bool] = None 16 | has_error: bool = False 17 | error_message: Optional[str] = None 18 | 19 | 20 | async def safe_parse_eval_response(eval_response: str, expected_decision: str) -> Tuple[bool, str, float]: 21 | """Safely parse the evaluation response.""" 22 | try: 23 | # Try to find the JSON object in the response 24 | start = eval_response.find("{") 25 | end = eval_response.rfind("}") + 1 26 | if start == -1 or end == 0: 27 | raise ValueError("No JSON object found in response") 28 | 29 | json_str = eval_response[start:end] 30 | result = json.loads(json_str) 31 | 32 | # Extract values 33 | passing = result["decision"].lower() == expected_decision.lower() 34 | reasoning = result["reason"] 35 | score = float(result["score"]) 36 | 37 | return passing, reasoning, score 38 | except (json.JSONDecodeError, KeyError, ValueError) as e: 39 | logger.error(f"Failed to parse evaluation response: {str(e)}\nResponse: {eval_response}") 40 | return False, f"Failed to parse evaluation response: {str(e)}", 1.0 41 | -------------------------------------------------------------------------------- /src/wandbot/ingestion/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/ingestion/__init__.py -------------------------------------------------------------------------------- /src/wandbot/ingestion/__main__.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | from wandbot.configs.ingestion_config import IngestionConfig 4 | from wandbot.ingestion.prepare_data import run_prepare_data_pipeline 5 | from wandbot.ingestion.preprocess_data import run_preprocessing_pipeline 6 | from wandbot.ingestion.run_ingestion_config import IngestionRunConfig, get_run_config 7 | from wandbot.ingestion.vectorstore_and_report import run_vectorstore_and_report_pipeline 8 | from wandbot.utils import get_logger 9 | 10 | ingestion_config = IngestionConfig() 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def main(): 16 | load_dotenv() 17 | # Parse command-line arguments 18 | run_config: IngestionRunConfig = get_run_config() 19 | logger.info(f"Running ingestion with config: {run_config}") 20 | 21 | project = ingestion_config.wandb_project 22 | entity = ingestion_config.wandb_entity 23 | 24 | # Adjust artifact names if in debug mode 25 | raw_data_artifact_name = run_config.raw_data_artifact_name 26 | preprocessed_data_artifact_name = run_config.preprocessed_data_artifact_name 27 | vectorstore_artifact_name = run_config.vectorstore_artifact_name 28 | 29 | if run_config.debug: 30 | logger.warning("----- RUNNING IN DEBUG MODE -----") 31 | raw_data_artifact_name += "_debug" 32 | preprocessed_data_artifact_name += "_debug" 33 | vectorstore_artifact_name += "_debug" 34 | logger.info(f"Debug mode: Artifact names adjusted to: {raw_data_artifact_name}, {preprocessed_data_artifact_name}, {vectorstore_artifact_name}") 35 | 36 | # Variables to hold artifact paths/names 37 | raw_artifact_path = None 38 | preprocessed_artifact_path = None 39 | vectorstore_artifact_path = None 40 | 41 | # Execute steps based on config 42 | if "prepare" in run_config.steps: 43 | logger.info("\n\n ------ Starting Prepare Data Step ------\n\n") 44 | raw_artifact_path = run_prepare_data_pipeline( 45 | project=project, 46 | entity=entity, 47 | result_artifact_name=raw_data_artifact_name, 48 | include_sources=run_config.include_sources, 49 | exclude_sources=run_config.exclude_sources, 50 | debug=run_config.debug, 51 | ) 52 | logger.info(f"Prepare Data Step completed. Raw artifact: {raw_artifact_path}") 53 | else: 54 | logger.info("Skipping Prepare Data Step") 55 | 56 | if "preprocess" in run_config.steps: 57 | logger.info("\n\n ------ Starting Preprocess Data Step ------\n\n") 58 | if not raw_artifact_path: 59 | raw_artifact_path = ( 60 | f"{entity}/{project}/{raw_data_artifact_name}:latest" 61 | ) 62 | logger.warning( 63 | f"Prepare step skipped, using latest raw artifact: {raw_artifact_path}" 64 | ) 65 | 66 | preprocessed_artifact_path = run_preprocessing_pipeline( 67 | project=project, 68 | entity=entity, 69 | source_artifact_path=raw_artifact_path, 70 | result_artifact_name=preprocessed_data_artifact_name, 71 | debug=run_config.debug, 72 | ) 73 | logger.info( 74 | f"Preprocess Data Step completed. 
Preprocessed artifact: {preprocessed_artifact_path}" 75 | ) 76 | else: 77 | logger.info("Skipping Preprocess Data Step") 78 | 79 | # Combine Vectorstore and Report steps if either is requested 80 | if "vectorstore" in run_config.steps or "report" in run_config.steps: 81 | logger.info("\n\n ------ Starting Combined Vector Store and Report Step ------\n\n") 82 | 83 | # Ensure raw_artifact_path is set (needed for report) 84 | if not raw_artifact_path: 85 | raw_artifact_path = f"{entity}/{project}/{raw_data_artifact_name}:latest" 86 | logger.warning( 87 | f"Prepare step skipped, using latest raw artifact for vectorstore/report: {raw_artifact_path}" 88 | ) 89 | 90 | # Ensure preprocessed_artifact_path is set (needed for vectorstore) 91 | if not preprocessed_artifact_path: 92 | preprocessed_artifact_path = f"{entity}/{project}/{preprocessed_data_artifact_name}:latest" 93 | logger.warning( 94 | f"Preprocess step skipped, using latest preprocessed artifact for vectorstore/report: {preprocessed_artifact_path}" 95 | ) 96 | 97 | create_report_flag = "report" in run_config.steps 98 | if not create_report_flag: 99 | logger.info("Report creation will be skipped as 'report' is not in the specified steps.") 100 | 101 | vectorstore_artifact_path = run_vectorstore_and_report_pipeline( 102 | project=project, 103 | entity=entity, 104 | raw_artifact_path=raw_artifact_path, 105 | preprocessed_artifact_path=preprocessed_artifact_path, 106 | vectorstore_artifact_name=vectorstore_artifact_name, 107 | debug=run_config.debug, 108 | create_report=create_report_flag, 109 | upload_to_remote_vector_store=run_config.upload_to_remote_vector_store, 110 | ) 111 | logger.info( 112 | f"Combined Vector Store and Report Step completed. Vectorstore artifact: {vectorstore_artifact_path}" 113 | ) 114 | 115 | else: 116 | logger.info("Skipping Combined Vector Store and Report Step") 117 | 118 | logger.info("Ingestion pipeline finished.") 119 | final_artifact = ( 120 | vectorstore_artifact_path 121 | or preprocessed_artifact_path 122 | or raw_artifact_path 123 | or "No artifact generated." 
124 |     )
125 |     logger.info(f"Final artifact from run: {final_artifact}")
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     main()
--------------------------------------------------------------------------------
/src/wandbot/ingestion/preprocessors/markdown.py:
--------------------------------------------------------------------------------
1 | import json
2 | from hashlib import md5
3 | from typing import Any, Callable, Dict, List, Optional, Sequence, TypedDict
4 | 
5 | from langchain.text_splitter import (
6 |     Language,
7 |     MarkdownHeaderTextSplitter,
8 |     RecursiveCharacterTextSplitter,
9 | )
10 | from langchain_core.documents import BaseDocumentTransformer
11 | 
12 | from wandbot.configs.ingestion_config import DocodileEnglishStoreConfig
13 | from wandbot.schema.document import Document
14 | from wandbot.utils import FastTextLangDetect, FasttextModelConfig
15 | 
16 | 
17 | class LineType(TypedDict):
18 |     """Line type as typed dict."""
19 | 
20 |     metadata: Dict[str, str]
21 |     content: str
22 | 
23 | 
24 | class HeaderType(TypedDict):
25 |     """Header type as typed dict."""
26 | 
27 |     level: int
28 |     name: str
29 |     data: str
30 | 
31 | 
32 | def create_id_from_document(document: Document) -> str:
33 |     contents = document.page_content + json.dumps(document.metadata)
34 |     checksum = md5(contents.encode("utf-8")).hexdigest()
35 |     return checksum
36 | 
37 | 
38 | def create_id_from_page_content(page_content: str) -> str:
39 |     return md5(page_content.encode("utf-8")).hexdigest()
40 | 
41 | 
42 | def prefix_headers_based_on_metadata(chunk):
43 |     # Headers ordered by markdown header levels
44 |     markdown_header_prefixes = ["#", "##", "###", "####", "#####", "######"]
45 |     markdown_header_prefixes_map = {f"header_{i}": prefix for i, prefix in enumerate(markdown_header_prefixes, start=1)}  # metadata keys run header_1..header_6
46 | 
47 |     # Generate headers from metadata
48 |     headers_from_metadata = [
49 |         f"{markdown_header_prefixes_map[level]} {title}" for level, title in chunk["metadata"].items()
50 |     ]
51 | 
52 |     # Join the generated headers with new lines
53 |     headers_str = "\n".join(headers_from_metadata) + "\n"
54 | 
55 |     # Check if the page_content starts with a header
56 |     if chunk["content"].lstrip().startswith(tuple(markdown_header_prefixes)):
57 |         # Find the first newline to locate the end of the existing header
58 |         first_newline_index = chunk["content"].find("\n")
59 |         if first_newline_index != -1:
60 |             # Remove the existing header and prefix with generated headers
61 |             modified_content = headers_str + chunk["content"][first_newline_index + 1 :]
62 |         else:
63 |             # If there's no newline, the entire content is a header, replace it
64 |             modified_content = headers_str
65 |     else:
66 |         # If it doesn't start with a header, just prefix with generated headers
67 |         modified_content = headers_str + chunk["content"]
68 | 
69 |     return {"metadata": chunk["metadata"], "content": modified_content}
70 | 
71 | 
72 | class CustomMarkdownTextSplitter(MarkdownHeaderTextSplitter):
73 |     """Splitting markdown files based on specified headers."""
74 | 
75 |     def __init__(self, chunk_size: Optional[int] = None, **kwargs):
76 |         headers_to_split_on = [
77 |             ("#", "header_1"),
78 |             ("##", "header_2"),
79 |             ("###", "header_3"),
80 |             ("####", "header_4"),
81 |             ("#####", "header_5"),
82 |             ("######", "header_6"),
83 |         ]
84 |         self.max_length = chunk_size
85 |         super().__init__(
86 |             headers_to_split_on=headers_to_split_on,
87 |             return_each_line=False,
88 |             strip_headers=False,
89 |         )
90 | 
91 |     def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[Document]:
92 | 
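        # Merge split lines back into section-sized chunks: each line is folded
        # into the most recent chunk whose header metadata is a subset of the
        # line's own (the same section or an ancestor section), as long as the
        # combined content stays within self.max_length; otherwise the line
        # starts a new chunk.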
aggregated_chunks: List[LineType] = [] 93 | 94 | for line in lines: 95 | should_append = True 96 | 97 | # Attempt to aggregate with an earlier chunk if possible 98 | for i in range(len(aggregated_chunks) - 1, -1, -1): 99 | previous_chunk = aggregated_chunks[i] 100 | # Check if the current line's metadata is a child or same level of the previous chunk's metadata 101 | if all(item in line["metadata"].items() for item in previous_chunk["metadata"].items()): 102 | potential_new_content = previous_chunk["content"] + " \n\n" + line["content"] 103 | if self.max_length is None or len(potential_new_content) <= self.max_length: 104 | # If adding the current line does not exceed chunk_size, merge it into the previous chunk 105 | aggregated_chunks[i]["content"] = potential_new_content 106 | should_append = False 107 | break 108 | else: 109 | # If it exceeds chunk_size, no further checks are needed, break to append as a new chunk 110 | break 111 | 112 | if should_append: 113 | # Append as a new chunk if it wasn't merged into an earlier one 114 | aggregated_chunks.append(line) 115 | 116 | # Prefix headers based on metadata 117 | aggregated_chunks = [prefix_headers_based_on_metadata(chunk) for chunk in aggregated_chunks] 118 | return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] 119 | 120 | def split_documents(self, documents: List[Document]) -> List[Document]: 121 | """Split a list of documents into smaller documents. 122 | 123 | Args: 124 | documents: A list of documents. 125 | 126 | Returns: 127 | A list of documents. 128 | """ 129 | split_documents = [] 130 | for document in documents: 131 | # Use Langchain's splitter to get basic text chunks and their metadata (like headers) 132 | # Note: langchain's split_text may return its own Document type, not ours 133 | chunks = self.split_text(document.page_content) 134 | 135 | original_doc_id = document.metadata.get("id") 136 | if not original_doc_id: 137 | # Fallback if original ID somehow missing (shouldn't happen with prepare_data changes) 138 | original_doc_id = create_id_from_document(document) 139 | 140 | for chunk in chunks: 141 | # Create new metadata, copying from parent 142 | new_metadata = document.metadata.copy() 143 | # Update with any specific metadata from the langchain chunk (e.g., header info) 144 | new_metadata.update(chunk.metadata) 145 | 146 | # Set parent_id to the original document's ID 147 | new_metadata["parent_id"] = original_doc_id 148 | 149 | # Remove the original ID from the new metadata before creating the Document 150 | if "id" in new_metadata: 151 | del new_metadata["id"] 152 | new_metadata["id"] = create_id_from_page_content(chunk.page_content) 153 | 154 | # Create our Document object 155 | new_doc = Document( 156 | page_content=chunk.page_content, 157 | metadata=new_metadata, 158 | ) 159 | split_documents.append(new_doc) 160 | 161 | return split_documents 162 | 163 | 164 | class MarkdownTextTransformer(BaseDocumentTransformer): 165 | def __init__( 166 | self, 167 | lang_detect, 168 | chunk_size: int, 169 | chunk_multiplier: int, 170 | chunk_overlap: int, 171 | length_function: Callable[[str], int] = None, 172 | ): 173 | self.fasttext_model = lang_detect 174 | self.chunk_size: int = chunk_size 175 | self.chunk_multiplier: int = chunk_multiplier 176 | self.chunk_overlap: int = chunk_overlap 177 | self.length_function: Callable[[str], int] = length_function if length_function is not None else len 178 | self.recursive_splitter = RecursiveCharacterTextSplitter.from_language( 
179 |             language=Language.MARKDOWN,
180 |             chunk_size=self.chunk_size,
181 |             chunk_overlap=self.chunk_overlap,
182 |             keep_separator=True,
183 |             length_function=self.length_function,
184 |         )
185 |         self.header_splitter = CustomMarkdownTextSplitter(
186 |             chunk_size=self.chunk_size * self.chunk_multiplier,
187 |         )
188 | 
189 |     def identify_document_language(self, document: Document) -> str:
190 |         if "language" in document.metadata:
191 |             return document.metadata["language"]
192 |         else:
193 |             return self.fasttext_model.detect_language(document.page_content)
194 | 
195 |     def split_markdown_documents(
196 |         self,
197 |         documents: List[Document],
198 |     ) -> List[Document]:
199 |         final_chunks = []
200 |         chunked_documents = []
201 |         for document in documents:
202 |             document_splits = self.header_splitter.split_documents(
203 |                 documents=[document],
204 |             )
205 |             for split in document_splits:
206 |                 chunk = Document(
207 |                     page_content=split.page_content,
208 |                     metadata=split.metadata.copy(),
209 |                 )
210 |                 chunk.metadata["parent_id"] = document.metadata["id"]
211 |                 chunk.metadata["language"] = self.identify_document_language(chunk)
212 |                 chunk.metadata["has_code"] = "```" in chunk.page_content
213 |                 chunked_documents.append(chunk)
214 | 
215 |         split_chunks = self.recursive_splitter.split_documents(chunked_documents)
216 | 
217 |         for chunk in split_chunks:
218 |             chunk = Document(
219 |                 page_content=chunk.page_content,
220 |                 metadata=chunk.metadata.copy(),
221 |             )
222 |             chunk.metadata["id"] = create_id_from_document(chunk)
223 |             final_chunks.append(chunk)
224 | 
225 |         return final_chunks
226 | 
227 |     def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
228 |         split_documents = self.split_markdown_documents(list(documents))
229 |         transformed_documents = []
230 |         for document in split_documents:
231 |             transformed_documents.append(document)
232 |         return transformed_documents
233 | 
234 | 
235 | if __name__ == "__main__":
236 |     docodile_en_config = DocodileEnglishStoreConfig()
237 |     lang_detect = FastTextLangDetect(
238 |         FasttextModelConfig(fasttext_file_path="/media/mugan/data/wandb/projects/wandbot/data/cache/models/lid.176.bin")
239 |     )
240 | 
241 |     data_file = open(
242 |         "/media/mugan/data/wandb/projects/wandbot/data/cache/raw_data/docodile_store/docodile_en/documents.jsonl"
243 |     ).readlines()
244 |     source_document = json.loads(data_file[0])
245 | 
246 |     source_document = Document(**source_document)
247 | 
248 |     markdown_transformer = MarkdownTextTransformer(
249 |         lang_detect=lang_detect,
250 |         chunk_size=docodile_en_config.chunk_size,
251 |         chunk_multiplier=docodile_en_config.chunk_multiplier, chunk_overlap=docodile_en_config.chunk_overlap,  # chunk_overlap is required by __init__; the config attribute is assumed to exist
252 |     )
253 | 
254 |     transformed_documents = markdown_transformer.transform_documents([source_document])
255 | 
256 |     for document in transformed_documents:
257 |         print(document.page_content)
258 |         print(json.dumps(document.metadata, indent=2))
259 |         print("*" * 80)
260 | 
--------------------------------------------------------------------------------
/src/wandbot/ingestion/run_ingestion_config.py:
--------------------------------------------------------------------------------
1 | # src/wandbot/ingestion/run_ingestion_config.py
2 | from dataclasses import dataclass, field
3 | from typing import List
4 | 
5 | import simple_parsing as sp
6 | 
7 | 
8 | @dataclass
9 | class IngestionRunConfig:
10 |     """Command-line arguments for controlling the ingestion pipeline."""
11 |     steps: List[str] = field(
12 |         default_factory=lambda: ["prepare", "preprocess", "vectorstore", "report"]
13 |     )
14 |     """Steps to 
run: prepare, preprocess, vectorstore, report""" 15 | include_sources: List[str] = field(default_factory=list) 16 | """List of specific source names (from ingestion_config.py) to include. If empty, includes all (respecting excludes).""" 17 | exclude_sources: List[str] = field(default_factory=list) 18 | """List of specific source names to exclude. Applied after includes.""" 19 | raw_data_artifact_name: str = "raw_data" 20 | """Override the default raw data artifact name.""" 21 | preprocessed_data_artifact_name: str = "transformed_data" 22 | """Override the default preprocessed data artifact name.""" 23 | vectorstore_artifact_name: str = "chroma_index" 24 | """Override the default vector store artifact name.""" 25 | debug: bool = False 26 | """Run in debug mode: process only the first source and first 3 documents, append _debug to artifact names.""" 27 | upload_to_remote_vector_store: bool = True 28 | """Whether to upload the final vector store collection to the configured remote ChromaDB instance.""" 29 | 30 | def get_run_config() -> IngestionRunConfig: 31 | """Parses command line arguments for ingestion run configuration.""" 32 | parser = sp.ArgumentParser(add_option_string_dash_variants=True) 33 | parser.add_arguments(IngestionRunConfig, dest="run_config") 34 | args = parser.parse_args() 35 | return args.run_config -------------------------------------------------------------------------------- /src/wandbot/models/embedding.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import os 4 | import sys 5 | import traceback 6 | from typing import Any, Dict, List, Tuple, Union 7 | 8 | import numpy as np 9 | import weave 10 | from tenacity import retry, stop_after_attempt, wait_exponential 11 | 12 | from wandbot.configs.vector_store_config import VectorStoreConfig 13 | from wandbot.schema.api_status import APIStatus 14 | from wandbot.utils import ErrorInfo, get_error_file_path, get_logger 15 | 16 | logger = get_logger(__name__) 17 | vector_store_config = VectorStoreConfig() 18 | 19 | # Valid input types for Cohere embeddings 20 | VALID_COHERE_INPUT_TYPES = ["search_document", "search_query", "classification", "clustering", "image"] 21 | 22 | class BaseEmbeddingModel: 23 | COMPONENT_NAME = "embedding" # Default component name 24 | 25 | def __init__(self, 26 | model_name: str, 27 | n_parallel_api_calls: int = 50, 28 | max_retries: int = 3, 29 | timeout: int = 30, 30 | **kwargs): 31 | self.model_name = model_name 32 | self.n_parallel_api_calls = n_parallel_api_calls 33 | self.max_retries = max_retries 34 | self.timeout = timeout 35 | 36 | @weave.op 37 | def embed(self, input: Union[str, List[str]]) -> Tuple[List[List[float]], APIStatus]: 38 | raise NotImplementedError("Subclasses must implement embed method") 39 | 40 | def __call__(self, input: Union[str, List[str]] = None) -> List[List[float]]: 41 | embeddings, api_status = self.embed(input) 42 | if not api_status.success: 43 | raise RuntimeError(api_status.error_info.error_message) 44 | return embeddings 45 | 46 | class OpenAIEmbeddingModel(BaseEmbeddingModel): 47 | COMPONENT_NAME = "openai_embedding" 48 | 49 | def __init__(self, model_name: str, dimensions: int, encoding_format: str = "float", **kwargs): 50 | super().__init__(model_name, **kwargs) 51 | self.dimensions = dimensions 52 | self.encoding_format = encoding_format 53 | from openai import AsyncOpenAI 54 | self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), max_retries=self.max_retries, 
timeout=self.timeout) 55 | 56 | async def _run_openai_embeddings(self, inputs: List[str]) -> Tuple[List[List[float]], APIStatus]: 57 | api_status = APIStatus(component=self.COMPONENT_NAME, success=True) 58 | try: 59 | semaphore = asyncio.Semaphore(self.n_parallel_api_calls) 60 | async def get_single_openai_embedding(text): 61 | async with semaphore: 62 | response = await self.client.embeddings.create( 63 | input=text, model=self.model_name, 64 | encoding_format=self.encoding_format, 65 | dimensions=self.dimensions) 66 | embedding = response.data[0].embedding 67 | if self.encoding_format == "base64" and isinstance(embedding, str): 68 | decoded_embeddings = base64.b64decode(embedding) 69 | return np.frombuffer(decoded_embeddings, dtype=np.float32).tolist() 70 | return embedding 71 | 72 | # Use return_exceptions=True to handle partial failures 73 | results = await asyncio.gather(*[get_single_openai_embedding(text) for text in inputs], return_exceptions=True) 74 | 75 | # Check if any results are exceptions 76 | exceptions = [r for r in results if isinstance(r, Exception)] 77 | if exceptions: 78 | # Get error messages and stack traces from the actual exception objects 79 | error_details = [] 80 | for e in exceptions: 81 | tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__)) 82 | error_details.append(f"{str(e)}\n{tb}") 83 | 84 | error_info = ErrorInfo( 85 | component=self.COMPONENT_NAME, 86 | has_error=True, 87 | error_message=f"Failed to embed some inputs: {'; '.join(str(e) for e in exceptions)}", 88 | error_type="PartialEmbeddingFailure", 89 | stacktrace='\n'.join(error_details), 90 | file_path=get_error_file_path(sys.exc_info()[2]) 91 | ) 92 | return None, APIStatus( 93 | component=self.COMPONENT_NAME, 94 | success=False, 95 | error_info=error_info 96 | ) 97 | 98 | return results, api_status 99 | except Exception as e: 100 | error_info = ErrorInfo( 101 | component=self.COMPONENT_NAME, 102 | has_error=True, 103 | error_message=str(e), 104 | error_type=type(e).__name__, 105 | stacktrace=''.join(traceback.format_exc()), 106 | file_path=get_error_file_path(sys.exc_info()[2]) 107 | ) 108 | return None, APIStatus( 109 | component=self.COMPONENT_NAME, 110 | success=False, 111 | error_info=error_info 112 | ) 113 | 114 | def embed(self, input: Union[str, List[str]]) -> Tuple[List[List[float]], APIStatus]: 115 | inputs = [input] if isinstance(input, str) else input 116 | return asyncio.run(self._run_openai_embeddings(inputs)) 117 | 118 | class CohereEmbeddingModel(BaseEmbeddingModel): 119 | COMPONENT_NAME = "cohere_embedding" 120 | 121 | def __init__(self, model_name: str, input_type: str, dimensions: int = None, **kwargs): 122 | super().__init__(model_name, **kwargs) 123 | if not input_type: 124 | raise ValueError("input_type must be specified for Cohere embeddings") 125 | if input_type not in VALID_COHERE_INPUT_TYPES: 126 | raise ValueError(f"Invalid input_type: {input_type}. 
Must be one of {VALID_COHERE_INPUT_TYPES}") 127 | self.input_type = input_type 128 | self.dimensions = dimensions 129 | try: 130 | import cohere 131 | from cohere.core.request_options import RequestOptions 132 | self.client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) 133 | self.RequestOptions = RequestOptions # Store the class for later use 134 | except Exception as e: 135 | logger.error(f'Unable to initialise Cohere client:\n{e}') 136 | raise 137 | 138 | def embed(self, input: Union[str, List[str]]) -> Tuple[List[List[float]], APIStatus]: 139 | api_status = APIStatus(component=self.COMPONENT_NAME, success=True) 140 | try: 141 | inputs = [input] if isinstance(input, str) else input 142 | response = self.client.embed( 143 | texts=inputs, 144 | model=self.model_name, 145 | input_type=self.input_type, 146 | embedding_types=["float"], 147 | request_options=self.RequestOptions(max_retries=self.max_retries, timeout_in_seconds=self.timeout) 148 | ) 149 | embeddings = response.embeddings.float 150 | # Convert to numpy arrays 151 | if isinstance(embeddings, list): 152 | if len(embeddings) == 1: 153 | return np.array(embeddings[0]), api_status 154 | return [np.array(emb) for emb in embeddings], api_status 155 | return np.array(embeddings), api_status 156 | except Exception as e: 157 | error_info = ErrorInfo( 158 | component=self.COMPONENT_NAME, 159 | has_error=True, 160 | error_message=str(e), 161 | error_type=type(e).__name__, 162 | stacktrace=''.join(traceback.format_exc()), 163 | file_path=get_error_file_path(sys.exc_info()[2]) 164 | ) 165 | return None, APIStatus( 166 | component=self.COMPONENT_NAME, 167 | success=False, 168 | error_info=error_info 169 | ) 170 | 171 | class EmbeddingModel: 172 | PROVIDER_MAP = { 173 | "openai": OpenAIEmbeddingModel, 174 | "cohere": CohereEmbeddingModel 175 | } 176 | 177 | def __init__(self, provider: str, **kwargs): 178 | provider = provider.lower() 179 | if provider not in self.PROVIDER_MAP: 180 | raise ValueError(f"Unsupported provider: {provider}. 
Choose from {list(self.PROVIDER_MAP.keys())}") 181 | 182 | if provider == "openai" and "dimensions" not in kwargs: 183 | raise ValueError("`dimensions` needs to be specified when using OpenAI embeddings models") 184 | 185 | if provider == "cohere" and "input_type" not in kwargs: 186 | raise ValueError("input_type must be specified for Cohere embeddings") 187 | 188 | self.model = self.PROVIDER_MAP[provider](**kwargs) 189 | 190 | def embed(self, input: Union[str, List[str]] = None) -> Tuple[List[List[float]], APIStatus]: 191 | """Required interface for Chroma's EmbeddingFunction""" 192 | # No try-except here - let errors propagate up 193 | embeddings, api_status = self.model.embed(input) 194 | return embeddings, api_status 195 | 196 | def __call__(self, input: Union[str, List[str]] = None) -> List[List[float]]: 197 | """Required interface for Chroma's EmbeddingFunction""" 198 | embeddings, api_status = self.embed(input) 199 | if not api_status.success: 200 | raise RuntimeError(api_status.error_info.error_message) # Raise for Chroma to handle 201 | return embeddings -------------------------------------------------------------------------------- /src/wandbot/rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/wandbot/29d248b9f022ee3650501b177d25c2b54c8e6d56/src/wandbot/rag/__init__.py -------------------------------------------------------------------------------- /src/wandbot/rag/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from langchain_core.prompts import PromptTemplate, format_document 4 | 5 | from wandbot.schema.document import Document 6 | 7 | # from langchain_openai import ChatOpenAI 8 | from wandbot.utils import clean_document_content 9 | 10 | # class ChatModel: 11 | # def __init__(self, max_retries: int = 2): 12 | # self.max_retries = max_retries 13 | 14 | # def __set_name__(self, owner, name): 15 | # self.public_name = name 16 | # self.private_name = "_" + name 17 | 18 | # def __get__(self, obj, obj_type=None): 19 | # value = getattr(obj, self.private_name) 20 | # return value 21 | 22 | # def __set__(self, obj, value): 23 | # model = ChatOpenAI( 24 | # model_name=value["model_name"], 25 | # temperature=value["temperature"], 26 | # max_retries=self.max_retries, 27 | # ) 28 | # setattr(obj, self.private_name, model) 29 | 30 | 31 | DEFAULT_QUESTION_PROMPT = PromptTemplate.from_template( 32 | template="""# Query 33 | 34 | {page_content} 35 | 36 | --- 37 | 38 | # Query Metadata 39 | 40 | Language: 41 | {language} 42 | 43 | Intents: 44 | {intents} 45 | 46 | Sub-queries to consider answering: 47 | 48 | {sub_queries} 49 | """ 50 | ) 51 | 52 | 53 | def create_query_str(enhanced_query, document_prompt=DEFAULT_QUESTION_PROMPT): 54 | page_content = enhanced_query["standalone_query"] 55 | metadata = { 56 | "language": enhanced_query["language"], 57 | "intents": enhanced_query["intents"], 58 | "sub_queries": "\t" 59 | + "\n\t".join(enhanced_query["sub_queries"]).strip(), 60 | } 61 | doc = Document(page_content=page_content, metadata=metadata) 62 | doc = clean_document_content(doc) 63 | doc_string = format_document(doc, document_prompt) 64 | return doc_string 65 | 66 | 67 | DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template( 68 | template="source: {source}\nsource_type: {source_type}\nhas_code: {has_code}\n\n{page_content}" 69 | ) 70 | 71 | 72 | def combine_documents( 73 | docs, 74 | document_prompt=DEFAULT_DOCUMENT_PROMPT, 75 | 
document_separator="\n\n---\n\n", 76 | ): 77 | cleaned_docs = [clean_document_content(doc) for doc in docs] 78 | doc_strings = [ 79 | format_document(doc, document_prompt) for doc in cleaned_docs 80 | ] 81 | return document_separator.join(doc_strings) 82 | 83 | 84 | def process_input_for_retrieval(retrieval_input): 85 | if isinstance(retrieval_input, list): 86 | retrieval_input = "\n".join(retrieval_input) 87 | elif isinstance(retrieval_input, dict): 88 | retrieval_input = json.dumps(retrieval_input) 89 | elif not isinstance(retrieval_input, str): 90 | retrieval_input = str(retrieval_input) 91 | return retrieval_input 92 | -------------------------------------------------------------------------------- /src/wandbot/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import VectorStore 2 | 3 | __all__ = ["VectorStore"] 4 | -------------------------------------------------------------------------------- /src/wandbot/retriever/base.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Dict, List 3 | 4 | import weave 5 | 6 | import wandb 7 | from wandbot.configs.chat_config import ChatConfig 8 | from wandbot.configs.vector_store_config import VectorStoreConfig 9 | from wandbot.models.embedding import EmbeddingModel 10 | from wandbot.retriever.chroma import ChromaVectorStore 11 | from wandbot.schema.document import Document 12 | from wandbot.utils import get_logger 13 | 14 | logger = get_logger(__name__) 15 | 16 | class VectorStore: 17 | """ 18 | Sets up vector store and embedding model. 19 | """ 20 | 21 | def __init__(self, vector_store_config: VectorStoreConfig, chat_config: ChatConfig): 22 | self.vector_store_config = vector_store_config 23 | self.chat_config = chat_config 24 | try: 25 | self.query_embedding_model = EmbeddingModel( 26 | provider = self.vector_store_config.embeddings_provider, 27 | model_name = self.vector_store_config.embeddings_model_name, 28 | dimensions = self.vector_store_config.embeddings_dimensions, 29 | input_type = self.vector_store_config.embeddings_query_input_type, 30 | encoding_format = self.vector_store_config.embeddings_encoding_format 31 | ) 32 | except Exception as e: 33 | raise RuntimeError(f"Failed to initialize embedding model:\n{str(e)}\n") from e 34 | 35 | try: 36 | self.chroma_vectorstore = ChromaVectorStore( 37 | embedding_model=self.query_embedding_model, 38 | vector_store_config=self.vector_store_config, 39 | chat_config=self.chat_config 40 | ) 41 | except Exception as e: 42 | raise RuntimeError(f"Failed to initialize vector store:\n{str(e)}\n") from e 43 | 44 | @classmethod 45 | def from_config(cls, vector_store_config: VectorStoreConfig, chat_config: ChatConfig): 46 | if vector_store_config.vectordb_index_dir.exists(): 47 | return cls(vector_store_config=vector_store_config, chat_config=chat_config) 48 | else: 49 | api = wandb.Api() 50 | art = api.artifact(vector_store_config.vectordb_index_artifact_url) # Download vectordb index from W&B 51 | _ = art.download(vector_store_config.vectordb_index_dir) 52 | return cls(vector_store_config=vector_store_config, chat_config=chat_config) 53 | 54 | @weave.op 55 | def retrieve( 56 | self, 57 | query_texts: List[str], 58 | filter_params: dict = None, 59 | ) -> Dict[str, List[Document]]: 60 | """Retrieve documents using either MMR or similarity search based on chat_config. 
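        With search_type "mmr", candidates are reranked to trade query relevance
        against diversity among the returned documents (see retriever/mmr.py);
        any other value falls back to a plain top-k similarity search per query.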
61 | 
62 |         Args:
63 |             query_texts: List of queries to search for
64 |             filter_params: Optional filtering parameters
65 |                 {"filter": dict, "where_document": dict}
66 |         """
67 |         filter_params = filter_params or {}
68 | 
69 |         if self.chat_config.search_type == "mmr":
70 |             # Use fixed parameters for MMR as per retrieval_implementation.md
71 |             results = self.chroma_vectorstore.max_marginal_relevance_search(
72 |                 query_texts=query_texts,
73 |                 top_k=self.chat_config.top_k_per_query,
74 |                 fetch_k=self.chat_config.fetch_k,
75 |                 lambda_mult=self.chat_config.mmr_lambda_mult,
76 |                 filter=filter_params.get("filter"),
77 |                 where_document=filter_params.get("where_document")
78 |             )
79 |             logger.debug(f"RETRIEVER: MMR search completed with {len(results)} results")
80 |         else:
81 |             results = self.chroma_vectorstore.similarity_search(
82 |                 query_texts=query_texts,
83 |                 top_k=self.chat_config.top_k_per_query,
84 |                 filter=filter_params.get("filter"),
85 |                 where_document=filter_params.get("where_document")
86 |             )
87 |             logger.debug(f"RETRIEVER: Similarity search completed with {len(results)} results")
88 | 
89 |         return results
90 | 
91 |     async def _async_retrieve(
92 |         self,
93 |         query_texts: List[str],
94 |         filter_params: dict = None
95 |     ) -> Dict[str, List[Document]]:
96 |         """Async version of retrieve that returns the same dictionary structure."""
97 |         return await asyncio.to_thread(
98 |             self.retrieve,
99 |             query_texts=query_texts,
100 |             filter_params=filter_params,
101 |         )
--------------------------------------------------------------------------------
/src/wandbot/retriever/mmr.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 | 
3 | import numpy as np
4 | import weave
5 | 
6 | from wandbot.retriever.utils import cosine_similarity
7 | from wandbot.schema.document import Document
8 | 
9 | 
10 | @weave.op
11 | def maximal_marginal_relevance(
12 |     query_embedding: np.ndarray,
13 |     embedding_list: list,
14 |     top_k: int,
15 |     lambda_mult: float,
16 | ) -> List[int]:
17 |     """Calculate maximal marginal relevance.
18 | 
19 |     Args:
20 |         query_embedding: Query embedding.
21 |         embedding_list: List of embeddings to select from.
22 |         top_k: Number of documents to return.
23 |         lambda_mult: Number between 0 and 1 that determines the degree
24 |             of diversity among the results with 0 corresponding
25 |             to maximum diversity and 1 to minimum diversity.
26 | 
27 |     Returns:
28 |         List of indices of embeddings selected by maximal marginal relevance. 
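
    Illustrative example (hand-computed toy vectors; lambda_mult < 0.5 favours diversity):
        query = np.array([[1.0, 0.0]])
        cands = [[1.0, 0.0], [0.996, 0.087], [0.5, 0.866]]
        maximal_marginal_relevance(query, cands, top_k=2, lambda_mult=0.4)
        # -> [0, 2]: the near-duplicate at index 1 loses to the more
        # diverse candidate at index 2.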
29 |     """
30 |     if min(top_k, len(embedding_list)) <= 0:
31 |         return []
32 |     similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
33 |     most_similar = int(np.argmax(similarity_to_query))
34 |     idxs = [most_similar]
35 |     selected = np.array([embedding_list[most_similar]])
36 | 
37 |     while len(idxs) < min(top_k, len(embedding_list)):
38 |         best_score = -np.inf
39 |         idx_to_add = -1
40 |         similarity_to_selected = cosine_similarity(embedding_list, selected)
41 |         for i, query_score in enumerate(similarity_to_query):
42 |             if i in idxs:
43 |                 continue
44 |             redundant_score = max(similarity_to_selected[i])
45 |             equation_score = (
46 |                 lambda_mult * query_score - (1 - lambda_mult) * redundant_score
47 |             )
48 |             if equation_score > best_score:
49 |                 best_score = equation_score
50 |                 idx_to_add = i
51 |         idxs.append(idx_to_add)
52 |         selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
53 |     return idxs
54 | 
55 | @weave.op
56 | def max_marginal_relevance_search_by_vector(
57 |     retrieved_results: Any,
58 |     embedding: List[float],
59 |     top_k: int,
60 |     lambda_mult: float,
61 | ) -> List[Document]:
62 |     """Return docs selected using the maximal marginal relevance.
63 | 
64 |     Maximal marginal relevance optimizes for similarity to query AND diversity
65 |     among selected documents.
66 | 
67 |     Args:
68 |         retrieved_results: Retrieved results from vector store.
69 |         embedding: Embedding to look up documents similar to.
70 |         top_k: Number of Documents to return.
71 |         lambda_mult: Number between 0 and 1 that determines the degree
72 |             of diversity among the results with 0 corresponding
73 |             to maximum diversity and 1 to minimum diversity.
74 | 
75 |     Returns:
76 |         List of Documents selected by maximal marginal relevance.
77 |     """
78 |     query_embedding = np.array(embedding, dtype=np.float32)
79 |     if query_embedding.ndim == 1:
80 |         query_embedding = np.expand_dims(query_embedding, axis=0)
81 | 
82 |     if np.array(retrieved_results["embeddings"]).shape[0] == 1 and len(np.array(retrieved_results["embeddings"]).shape) == 3:
83 |         retrieved_results["embeddings"] = retrieved_results["embeddings"][0]
84 | 
85 |     mmr_selected = maximal_marginal_relevance(
86 |         query_embedding,
87 |         retrieved_results["embeddings"],
88 |         top_k=top_k,
89 |         lambda_mult=lambda_mult,
90 |     )
91 | 
92 |     candidates = [Document(page_content=doc, metadata=meta, distance=dist) for doc, meta, dist in zip(
93 |         retrieved_results["documents"],
94 |         retrieved_results["metadatas"],
95 |         retrieved_results["distances"])]
96 |     selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
97 |     return selected_results
--------------------------------------------------------------------------------
/src/wandbot/retriever/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Union
2 | 
3 | import numpy as np
4 | import weave
5 | 
6 | from wandbot.schema.document import Document
7 | from wandbot.utils import get_logger
8 | 
9 | logger = get_logger(__name__)
10 | 
11 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
12 | 
13 | 
14 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
15 |     """Row-wise cosine similarity between two equal-width matrices.
16 | 
17 |     Raises:
18 |         ValueError: If the number of columns in X and Y is not the same. 
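
    Example (toy vectors; the result has shape (len(X), len(Y))):
        cosine_similarity([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]])
        # -> array([[1., 0.]])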
19 |     """
20 |     if len(X) == 0 or len(Y) == 0:
21 |         logger.info("COSINE SIMILARITY: Returning empty array")
22 |         return np.array([])
23 | 
24 |     X = np.array(X)
25 |     Y = np.array(Y)
26 |     if X.shape[1] != Y.shape[1]:
27 |         raise ValueError(
28 |             "Number of columns in X and Y must be the same. X has shape "
29 |             f"{X.shape} "
30 |             f"and Y has shape {Y.shape}."
31 |         )
32 | 
33 |     X_norm = np.linalg.norm(X, axis=1)
34 |     Y_norm = np.linalg.norm(Y, axis=1)
35 |     # Ignore divide-by-zero and invalid-value runtime warnings; those entries are zeroed out below.
36 |     with np.errstate(divide="ignore", invalid="ignore"):
37 |         similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
38 |         similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
39 |     return similarity
40 | 
41 | @weave.op
42 | def reciprocal_rank_fusion(results: list[list[Document]], smoothing_constant=60):
43 |     """Combine multiple ranked lists using Reciprocal Rank Fusion.
44 | 
45 |     Implements the RRF algorithm from Cormack et al. (2009) to fuse multiple
46 |     ranked lists into a single ranked list. Documents appearing in multiple
47 |     lists have their reciprocal rank scores summed.
48 | 
49 |     Args:
50 |         results: List of ranked document lists to combine
51 |         smoothing_constant: Constant that controls scoring impact (default: 60).
52 |             It smooths out the differences between ranks by adding a constant to
53 |             the denominator in the formula 1/(rank + k). This prevents very high
54 |             ranks (especially rank 1) from completely dominating the fusion results.
55 | 
56 |     Returns:
57 |         List[Document]: Combined and reranked list of documents
58 |     """
59 |     assert len(results) > 0, "No document lists passed to reciprocal rank fusion"
60 |     assert any(len(docs) > 0 for docs in results), "All document lists passed to reciprocal_rank_fusion are empty"
61 | 
62 |     text_to_doc = {}
63 |     fused_scores = {}
64 |     for docs in results:
65 |         # Assumes the docs are returned in sorted order of relevance
66 |         for rank, doc in enumerate(docs):
67 |             doc_content = doc.page_content
68 |             text_to_doc[doc_content] = doc
69 |             if doc_content not in fused_scores:
70 |                 fused_scores[doc_content] = 0.0
71 |             fused_scores[doc_content] += 1 / (rank + smoothing_constant)
72 |     logger.debug(f"Final fused scores count: {len(fused_scores)}")
73 | 
74 |     ranked_results = dict(
75 |         sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
76 |     )
77 |     ranked_results = [text_to_doc[text] for text in ranked_results.keys()]
78 |     logger.debug(f"Final reciprocal ranked results count: {len(ranked_results)}")
79 |     return ranked_results
80 | 
81 | 
82 | @weave.op
83 | def _filter_similar_embeddings(
84 |     embedded_documents: Matrix,
85 |     similarity_fn: Callable[[Matrix, Matrix], np.ndarray],
86 |     threshold: float
87 | ) -> List[int]:
88 |     """Filter redundant documents based on the similarity of their embeddings."""
89 | 
90 |     similarity = np.tril(similarity_fn(embedded_documents, embedded_documents), k=-1)
91 |     redundant = np.where(similarity > threshold)
92 |     redundant_stacked = np.column_stack(redundant)
93 |     redundant_sorted = np.argsort(similarity[redundant])[::-1]
94 |     included_idxs = set(range(len(embedded_documents)))
95 | 
96 |     for first_idx, second_idx in redundant_stacked[redundant_sorted]:
97 |         if first_idx in included_idxs and second_idx in included_idxs:
98 |             # Default to dropping the second document of any highly similar pair
99 |             included_idxs.remove(second_idx)
100 | 
101 |     return sorted(included_idxs)
102 | 
103 | 
104 | # class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
105 | #     """Filter that 
drops redundant documents by comparing their embeddings.""" 106 | 107 | # embedding_function: Any 108 | # """Embeddings to use for embedding document contents.""" 109 | 110 | # similarity_fn: Callable[[Matrix, Matrix], np.ndarray] = cosine_similarity 111 | # """Similarity function for comparing documents. Function expected to take as input 112 | # two matrices (List[List[float]] or numpy arrays) and return a matrix of scores 113 | # where higher values indicate greater similarity.""" 114 | 115 | # redundant_similarity_threshold: float = 0.95 116 | # """Threshold for determining when two documents are similar enough to be considered redundant.""" 117 | 118 | # model_config = ConfigDict( 119 | # arbitrary_types_allowed=True, 120 | # ) 121 | 122 | # @weave.op 123 | # def transform_documents( 124 | # self, 125 | # documents: Sequence[Document], 126 | # **kwargs: Any 127 | # ) -> Sequence[Document]: 128 | # """Filter down documents by removing redundant ones based on embedding similarity.""" 129 | # if not documents: 130 | # return [] 131 | 132 | # embedded_documents = self.embedding_function.embed( 133 | # [doc.page_content for doc in documents] 134 | # ) 135 | 136 | # # Filter similar documents 137 | # included_idxs = _filter_similar_embeddings( 138 | # embedded_documents, 139 | # self.similarity_fn, 140 | # self.redundant_similarity_threshold 141 | # ) 142 | 143 | # return [documents[i] for i in sorted(included_idxs)] -------------------------------------------------------------------------------- /src/wandbot/schema/api_status.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from wandbot.utils import ErrorInfo 6 | 7 | 8 | class APIStatus(BaseModel): 9 | """Track status of external API calls.""" 10 | component: str = Field(description="Name of the API component (e.g., 'web_search', 'reranker')") 11 | success: bool = Field(description="Whether the API call was successful") 12 | error_info: Optional[ErrorInfo] = Field(default=None, description="Error information if the call failed") -------------------------------------------------------------------------------- /src/wandbot/schema/document.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Set 2 | 3 | from pydantic import BaseModel, Field, field_validator 4 | 5 | # Define required metadata fields for initial validation 6 | REQUIRED_METADATA_FIELDS: Set[str] = {"source", "source_type", "has_code", "id"} 7 | 8 | 9 | class Document(BaseModel): 10 | """Class for storing a piece of text and associated metadata.""" 11 | page_content: str = Field(description="String text content of the document") 12 | metadata: Dict[str, Any] = Field(default_factory=dict, description="Associated metadata") 13 | 14 | def __init__(self, page_content: str, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None: 15 | """Initialize with page_content as positional or named arg.""" 16 | super().__init__( 17 | page_content=page_content, 18 | metadata=metadata or {}, 19 | **kwargs 20 | ) 21 | 22 | @field_validator("metadata") 23 | @classmethod 24 | def validate_required_metadata(cls, metadata: Dict[str, Any]) -> Dict[str, Any]: 25 | """Validate that required metadata fields are present.""" 26 | # During ingestion/data loading, some metadata might not be set yet 27 | # so only validate if it appears we're validating a fully formed document 28 | # (if at least one 
required field is present, validate all required fields) 29 | if any(field in metadata for field in REQUIRED_METADATA_FIELDS): 30 | missing_fields = REQUIRED_METADATA_FIELDS - metadata.keys() 31 | if missing_fields: 32 | raise ValueError(f"Required metadata fields missing: {missing_fields}") 33 | return metadata 34 | 35 | def __str__(self) -> str: 36 | """String representation focusing on page_content and metadata.""" 37 | if self.metadata: 38 | return f"page_content='{self.page_content}' metadata={self.metadata}" 39 | return f"page_content='{self.page_content}'" 40 | 41 | def ensure_required_fields(self) -> "Document": 42 | """Ensures all required metadata fields are present with default values if needed. 43 | 44 | Returns: 45 | The document with all required fields populated (self) 46 | """ 47 | from hashlib import md5 48 | 49 | for field in REQUIRED_METADATA_FIELDS: 50 | if field not in self.metadata: 51 | if field == "id": 52 | # Generate an ID based on content and existing metadata 53 | content_str = self.page_content + str(self.metadata) 54 | self.metadata["id"] = md5(content_str.encode("utf-8")).hexdigest() 55 | elif field == "source": 56 | self.metadata["source"] = "unknown" 57 | elif field == "source_type": 58 | self.metadata["source_type"] = "unknown" 59 | elif field == "has_code": 60 | # Simple heuristic - check for code blocks or common code patterns 61 | has_code = "```" in self.page_content or "def " in self.page_content 62 | self.metadata["has_code"] = has_code 63 | 64 | return self 65 | 66 | 67 | def validate_document_metadata(documents: list[Document]) -> list[Document]: 68 | """Validates and ensures all documents have the required metadata fields. 69 | 70 | Args: 71 | documents: List of documents to validate 72 | 73 | Returns: 74 | List of documents with required fields validated and populated 75 | 76 | Raises: 77 | ValueError: If any document is missing required fields and they cannot be auto-populated 78 | """ 79 | valid_documents = [] 80 | for doc in documents: 81 | try: 82 | # Try to ensure all required fields are present 83 | doc.ensure_required_fields() 84 | valid_documents.append(doc) 85 | except Exception as e: 86 | # Re-raise with more context 87 | raise ValueError(f"Failed to validate document: {e}\nDocument: {doc}") 88 | 89 | return valid_documents -------------------------------------------------------------------------------- /src/wandbot/schema/retrieval.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from wandbot.schema.document import Document 6 | from wandbot.utils import ErrorInfo 7 | 8 | 9 | class APIStatus(BaseModel): 10 | """Track status of external API calls during retrieval.""" 11 | component: str = Field(description="Name of the API component (e.g., 'web_search', 'reranker')") 12 | success: bool = Field(description="Whether the API call was successful") 13 | error_info: Optional[ErrorInfo] = Field(default=None, description="Error information if the call failed") 14 | 15 | class RetrievalResult(BaseModel): 16 | """Standardized output format for retrievers.""" 17 | documents: List[Document] = Field(description="Retrieved documents") 18 | retrieval_info: Dict[str, Any] = Field( 19 | default_factory=dict, 20 | description="Statistics and metadata about the retrieved documents (e.g., number of docs, sources, timing)" 21 | ) 22 | 23 | def __len__(self) -> int: 24 | return len(self.documents) 25 | 26 | def 
__getitem__(self, idx) -> Document: 27 | return self.documents[idx] -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from openai import AsyncOpenAI 5 | 6 | from wandbot.evaluation.eval_metrics.correctness import WandBotCorrectnessEvaluator 7 | 8 | 9 | def pytest_configure(config): 10 | """Configure pytest with custom markers.""" 11 | config.addinivalue_line( 12 | "markers", 13 | "integration: marks tests that make real API calls to external services" 14 | ) 15 | 16 | @pytest.fixture(scope="function") 17 | def evaluator(): 18 | """Create an evaluator instance for testing.""" 19 | client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) 20 | return WandBotCorrectnessEvaluator(client=client) 21 | 22 | def pytest_collection_modifyitems(config, items): 23 | """Modify test items in-place to ensure proper async test behavior.""" 24 | for item in items: 25 | # Add asyncio marker to all async tests 26 | if "async" in item.keywords: 27 | item.add_marker(pytest.mark.asyncio) 28 | 29 | # Skip integration tests if SKIP_INTEGRATION_TESTS is set 30 | if "integration" in item.keywords and os.getenv("SKIP_INTEGRATION_TESTS"): 31 | item.add_marker(pytest.mark.skip(reason="Integration tests are disabled")) -------------------------------------------------------------------------------- /tests/evaluation/test_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from unittest.mock import patch 3 | 4 | import httpx 5 | import pytest 6 | 7 | from wandbot.evaluation.eval import WandbotModel, get_answer, get_record, parse_text_to_json 8 | 9 | # Test data 10 | MOCK_API_RESPONSE = { 11 | "system_prompt": "test prompt", 12 | "answer": "test answer", 13 | "source_documents": "source: https://docs.wandb.ai/test\nThis is a test document", 14 | "model": "gpt-4", 15 | "total_tokens": 100, 16 | "prompt_tokens": 50, 17 | "completion_tokens": 50, 18 | "time_taken": 1.5 19 | } 20 | 21 | class MockAsyncClient: 22 | def __init__(self, response=None, error=None): 23 | self.response = response 24 | self.error = error 25 | 26 | async def __aenter__(self): 27 | return self 28 | 29 | async def __aexit__(self, exc_type, exc_val, exc_tb): 30 | pass 31 | 32 | async def post(self, *args, **kwargs): 33 | if self.error: 34 | raise self.error 35 | response = self.response 36 | await response.raise_for_status() 37 | return response 38 | 39 | class MockResponse: 40 | def __init__(self, data): 41 | self._data = data 42 | 43 | def json(self): 44 | if isinstance(self._data, dict): 45 | return self._data 46 | raise self._data 47 | 48 | async def raise_for_status(self): 49 | if isinstance(self._data, Exception): 50 | raise self._data 51 | 52 | @pytest.mark.asyncio(loop_scope="function") 53 | async def test_get_answer_success(): 54 | """Test successful API call in get_answer.""" 55 | mock_response = MockResponse(MOCK_API_RESPONSE) 56 | mock_client = MockAsyncClient(response=mock_response) 57 | 58 | with patch('httpx.AsyncClient', return_value=mock_client): 59 | result = await get_answer( 60 | question="test question", 61 | wandbot_url="http://test.url", 62 | application="test-app", 63 | language="en" 64 | ) 65 | 66 | assert json.loads(result) == MOCK_API_RESPONSE 67 | 68 | @pytest.mark.asyncio(loop_scope="function") 69 | async def test_get_answer_retry(): 70 | """Test retry behavior in get_answer.""" 71 | attempts = [] 
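    # The mock client below fails the first post() and succeeds on the second,
    # so the assertions prove get_answer retried exactly once.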
72 | 73 | class MockClient: 74 | async def __aenter__(self): 75 | return self 76 | 77 | async def __aexit__(self, exc_type, exc_val, exc_tb): 78 | pass 79 | 80 | async def post(self, *args, **kwargs): 81 | attempts.append(1) 82 | if len(attempts) == 1: 83 | raise httpx.HTTPError("Test error") 84 | response = MockResponse(MOCK_API_RESPONSE) 85 | return response 86 | 87 | with patch('httpx.AsyncClient', return_value=MockClient()): 88 | result = await get_answer( 89 | question="test question", 90 | wandbot_url="http://test.url" 91 | ) 92 | 93 | assert json.loads(result) == MOCK_API_RESPONSE 94 | assert len(attempts) == 2 95 | 96 | @pytest.mark.asyncio(loop_scope="function") 97 | async def test_get_answer_failure(): 98 | """Test complete failure in get_answer.""" 99 | error = httpx.HTTPError("Test error") 100 | mock_client = MockAsyncClient(error=error) 101 | 102 | with patch('httpx.AsyncClient', return_value=mock_client): 103 | result = await get_answer( 104 | question="test question", 105 | wandbot_url="http://test.url" 106 | ) 107 | 108 | result_dict = json.loads(result) 109 | # The error message will contain RetryError due to the retry decorator 110 | assert "RetryError" in result_dict["error"] 111 | # Check all other fields match 112 | assert result_dict["answer"] == "" 113 | assert result_dict["system_prompt"] == "" 114 | assert result_dict["source_documents"] == "" 115 | assert result_dict["model"] == "" 116 | assert result_dict["total_tokens"] == 0 117 | assert result_dict["prompt_tokens"] == 0 118 | assert result_dict["completion_tokens"] == 0 119 | assert result_dict["time_taken"] == 0 120 | 121 | @pytest.mark.asyncio(loop_scope="function") 122 | async def test_get_record_success(): 123 | """Test successful record retrieval.""" 124 | with patch('wandbot.evaluation.eval.get_answer') as mock_get_answer: 125 | mock_get_answer.return_value = json.dumps(MOCK_API_RESPONSE) 126 | 127 | result = await get_record( 128 | question="test question", 129 | wandbot_url="http://test.url" 130 | ) 131 | 132 | assert result["system_prompt"] == "test prompt" 133 | assert result["generated_answer"] == "test answer" 134 | assert len(result["retrieved_contexts"]) == 1 135 | assert result["retrieved_contexts"][0]["source"] == "https://docs.wandb.ai/test" 136 | assert not result["has_error"] 137 | assert result["error_message"] is None 138 | 139 | @pytest.mark.asyncio(loop_scope="function") 140 | async def test_get_record_empty_response(): 141 | """Test get_record with empty API response.""" 142 | with patch('wandbot.evaluation.eval.get_answer') as mock_get_answer: 143 | mock_get_answer.return_value = json.dumps({}) 144 | 145 | result = await get_record( 146 | question="test question", 147 | wandbot_url="http://test.url" 148 | ) 149 | 150 | assert result["has_error"] 151 | assert result["error_message"] == "Unknown API error" 152 | assert result["generated_answer"] == "" 153 | 154 | @pytest.mark.asyncio(loop_scope="function") 155 | async def test_get_record_api_error(): 156 | """Test get_record with API error.""" 157 | with patch('wandbot.evaluation.eval.get_answer') as mock_get_answer: 158 | mock_get_answer.side_effect = Exception("API Error") 159 | 160 | result = await get_record( 161 | question="test question", 162 | wandbot_url="http://test.url" 163 | ) 164 | 165 | assert result["has_error"] 166 | assert "Error getting response from wandbotAPI" in result["error_message"] 167 | assert result["generated_answer"] == "" 168 | 169 | def test_parse_text_to_json(): 170 | """Test parsing of source documents 
text.""" 171 | text = """source: https://docs.wandb.ai/test1 172 | This is document 1 173 | source: https://docs.wandb.ai/test2 174 | This is document 2""" 175 | 176 | result = parse_text_to_json(text) 177 | 178 | assert len(result) == 2 179 | assert result[0]["source"] == "https://docs.wandb.ai/test1" 180 | assert result[0]["content"] == "This is document 1" 181 | assert result[1]["source"] == "https://docs.wandb.ai/test2" 182 | assert result[1]["content"] == "This is document 2" 183 | 184 | @pytest.mark.asyncio(loop_scope="function") 185 | async def test_wandbot_model(): 186 | """Test WandbotModel prediction.""" 187 | with patch('wandbot.evaluation.eval.get_record') as mock_get_record: 188 | mock_get_record.return_value = { 189 | "system_prompt": "test prompt", 190 | "generated_answer": "test answer", 191 | "retrieved_contexts": [], 192 | "model": "gpt-4", 193 | "total_tokens": 100, 194 | "has_error": False, 195 | "error_message": None 196 | } 197 | 198 | model = WandbotModel( 199 | language="en", 200 | application="test-app", 201 | wandbot_url="http://test.url" 202 | ) 203 | 204 | result = await model.predict("test question") 205 | 206 | assert result["generated_answer"] == "test answer" 207 | assert not result["has_error"] 208 | mock_get_record.assert_called_once_with( 209 | "test question", 210 | wandbot_url="http://test.url", 211 | application="test-app", 212 | language="en" 213 | ) -------------------------------------------------------------------------------- /tests/evaluation/test_eval_config.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from wandbot.evaluation.eval_config import EvalConfig, get_eval_config 7 | 8 | 9 | def test_eval_config_defaults(): 10 | """Test default values of EvalConfig.""" 11 | config = EvalConfig() 12 | assert config.lang == "en" 13 | assert config.eval_judge_model == "gpt-4-1106-preview" 14 | assert config.eval_judge_temperature == 0.1 15 | assert config.experiment_name == "wandbot-eval" 16 | assert config.evaluation_name == "wandbot-eval" 17 | assert config.n_trials == 3 18 | assert config.n_weave_parallelism == 20 19 | assert config.wandbot_url == "http://0.0.0.0:8000" 20 | assert config.wandb_entity == "wandbot" 21 | assert config.wandb_project == "wandbot-eval" 22 | assert config.debug is False 23 | assert config.n_debug_samples == 3 24 | 25 | def test_eval_dataset_property(): 26 | """Test eval_dataset property returns correct dataset based on language.""" 27 | config = EvalConfig() 28 | 29 | # Test English dataset 30 | config.lang = "en" 31 | assert "wandbot-eval/object/wandbot_eval_data:" in config.eval_dataset 32 | 33 | # Test Japanese dataset 34 | config.lang = "ja" 35 | assert "wandbot-eval-jp/object/wandbot_eval_data_jp:" in config.eval_dataset 36 | 37 | def test_get_eval_config_with_args(): 38 | """Test get_eval_config with command line arguments.""" 39 | test_args = [ 40 | "--lang", "ja", 41 | "--eval_judge_model", "gpt-4", 42 | "--eval_judge_temperature", "0.2", 43 | "--experiment_name", "test-exp", 44 | "--debug", "true" 45 | ] 46 | 47 | with patch.object(sys, 'argv', ['script.py'] + test_args): 48 | config = get_eval_config() 49 | assert config.lang == "ja" 50 | assert config.eval_judge_model == "gpt-4" 51 | assert config.eval_judge_temperature == 0.2 52 | assert config.experiment_name == "test-exp" 53 | assert config.debug is True 54 | 55 | def test_get_eval_config_invalid_args(): 56 | """Test get_eval_config with 
invalid arguments.""" 57 | test_args = [ 58 | "--lang", "invalid", # Invalid language 59 | "--eval_judge_temperature", "invalid" # Invalid float 60 | ] 61 | 62 | with patch.object(sys, 'argv', ['script.py'] + test_args): 63 | with pytest.raises(SystemExit): 64 | get_eval_config() 65 | 66 | def test_get_eval_config_type_validation(): 67 | """Test type validation in get_eval_config.""" 68 | test_cases = [ 69 | (["--n_trials", "abc"], "n_trials should be an integer"), 70 | (["--debug", "not_bool"], "debug should be a boolean"), 71 | (["--eval_judge_temperature", "abc"], "eval_judge_temperature should be a float"), 72 | ] 73 | 74 | for args, error_msg in test_cases: 75 | with patch.object(sys, 'argv', ['script.py'] + args): 76 | with pytest.raises(SystemExit) as exc_info: 77 | get_eval_config() 78 | assert exc_info.value.code == 2 # Standard argparse error exit code -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from pydantic import ConfigDict 2 | 3 | from wandbot.configs.chat_config import ChatConfig 4 | 5 | 6 | class TestConfig(ChatConfig): 7 | """Test configuration with minimal retry settings for faster tests""" 8 | 9 | model_config = ConfigDict( 10 | arbitrary_types_allowed=True, 11 | extra="allow" 12 | ) 13 | 14 | # Override LLM retry settings 15 | llm_max_retries: int = 1 16 | llm_retry_min_wait: int = 1 17 | llm_retry_max_wait: int = 2 18 | llm_retry_multiplier: int = 1 19 | 20 | # Override Embedding retry settings 21 | embedding_max_retries: int = 1 22 | embedding_retry_min_wait: int = 1 23 | embedding_retry_max_wait: int = 2 24 | embedding_retry_multiplier: int = 1 25 | 26 | # Override Reranker retry settings 27 | reranker_max_retries: int = 1 28 | reranker_retry_min_wait: int = 1 29 | reranker_retry_max_wait: int = 2 30 | reranker_retry_multiplier: int = 1 31 | 32 | # Override retry settings for faster tests 33 | max_retries: int = 1 # Only try once 34 | retry_min_wait: int = 1 # Wait 1 second minimum 35 | retry_max_wait: int = 2 # Wait 2 seconds maximum 36 | retry_multiplier: int = 1 # No exponential increase -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pytest 6 | from dotenv import load_dotenv 7 | 8 | from wandbot.models.embedding import VALID_COHERE_INPUT_TYPES, EmbeddingModel 9 | from wandbot.schema.api_status import APIStatus 10 | 11 | # Load environment variables from .env in project root 12 | ENV_PATH = Path(__file__).parent.parent / '.env' 13 | load_dotenv(ENV_PATH, override=True) 14 | 15 | cohere_models = [ 16 | "embed-english-v3.0", 17 | ] 18 | 19 | openai_models = [ 20 | "text-embedding-3-small", 21 | ] 22 | 23 | # Basic model creation tests 24 | @pytest.mark.parametrize("model_name", cohere_models) 25 | def test_cohere_embedding_creation(model_name): 26 | model = EmbeddingModel( 27 | provider="cohere", 28 | model_name=model_name, 29 | dimensions=1024, 30 | encoding_format="float", 31 | input_type=VALID_COHERE_INPUT_TYPES[0] # Use first valid type 32 | ) 33 | assert model.model.model_name == model_name 34 | assert model.model.dimensions == 1024 35 | assert model.model.client is not None # Just verify client exists 36 | 37 | @pytest.mark.parametrize("model_name", openai_models) 38 | def 
test_openai_embedding_creation(model_name): 39 | model = EmbeddingModel( 40 | provider="openai", 41 | model_name=model_name, 42 | dimensions=1536, 43 | encoding_format="float" 44 | ) 45 | assert model.model.model_name == model_name 46 | assert model.model.dimensions == 1536 47 | assert model.model.client.api_key == os.getenv("OPENAI_API_KEY") 48 | 49 | # Embedding generation tests 50 | @pytest.mark.parametrize("model_name", cohere_models) 51 | def test_cohere_embed_single(model_name): 52 | model = EmbeddingModel( 53 | provider="cohere", 54 | model_name=model_name, 55 | dimensions=1024, 56 | encoding_format="float", 57 | input_type=VALID_COHERE_INPUT_TYPES[0] 58 | ) 59 | embeddings, api_status = model.embed("This is a test sentence.") 60 | 61 | assert isinstance(embeddings, np.ndarray) 62 | assert embeddings.shape == (1024,) 63 | assert isinstance(api_status, APIStatus) 64 | assert api_status.success 65 | assert api_status.error_info is None 66 | assert api_status.component == "cohere_embedding" 67 | 68 | @pytest.mark.parametrize("model_name", cohere_models) 69 | def test_cohere_embed_batch(model_name): 70 | model = EmbeddingModel( 71 | provider="cohere", 72 | model_name=model_name, 73 | dimensions=1024, 74 | encoding_format="float", 75 | input_type=VALID_COHERE_INPUT_TYPES[0] 76 | ) 77 | texts = ["First test sentence.", "Second test sentence.", "Third test sentence."] 78 | embeddings, api_status = model.embed(texts) 79 | 80 | assert isinstance(embeddings, list) 81 | assert len(embeddings) == len(texts) 82 | assert all(isinstance(emb, np.ndarray) for emb in embeddings) 83 | assert all(emb.shape == (1024,) for emb in embeddings) 84 | assert isinstance(api_status, APIStatus) 85 | assert api_status.success 86 | assert api_status.error_info is None 87 | assert api_status.component == "cohere_embedding" 88 | 89 | @pytest.mark.parametrize("model_name", openai_models) 90 | def test_openai_embed_single(model_name): 91 | model = EmbeddingModel( 92 | provider="openai", 93 | model_name=model_name, 94 | dimensions=1536, 95 | encoding_format="float" 96 | ) 97 | embeddings, api_status = model.embed("This is a test sentence.") 98 | 99 | # Convert to numpy array if it's not already 100 | if not isinstance(embeddings, np.ndarray): 101 | embeddings = np.array(embeddings).squeeze() 102 | 103 | assert isinstance(embeddings, np.ndarray) 104 | assert embeddings.shape == (1536,) 105 | assert isinstance(api_status, APIStatus) 106 | assert api_status.success 107 | assert api_status.error_info is None 108 | assert api_status.component == "openai_embedding" 109 | 110 | @pytest.mark.parametrize("model_name", openai_models) 111 | def test_openai_embed_batch(model_name): 112 | model = EmbeddingModel( 113 | provider="openai", 114 | model_name=model_name, 115 | dimensions=1536, 116 | encoding_format="float" 117 | ) 118 | texts = ["First test sentence.", "Second test sentence.", "Third test sentence."] 119 | embeddings, api_status = model.embed(texts) 120 | 121 | # Convert to numpy arrays if they're not already 122 | if isinstance(embeddings, list): 123 | embeddings = [np.array(emb).squeeze() for emb in embeddings] 124 | 125 | assert isinstance(embeddings, list) 126 | assert len(embeddings) == len(texts) 127 | assert all(isinstance(emb, np.ndarray) for emb in embeddings) 128 | assert all(emb.shape == (1536,) for emb in embeddings) 129 | assert isinstance(api_status, APIStatus) 130 | assert api_status.success 131 | assert api_status.error_info is None 132 | assert api_status.component == "openai_embedding" 133 | 134 | # 
Error handling tests 135 | def test_invalid_cohere_model(): 136 | model = EmbeddingModel( 137 | provider="cohere", 138 | model_name="invalid-model", 139 | dimensions=1024, 140 | encoding_format="float", 141 | input_type=VALID_COHERE_INPUT_TYPES[0] 142 | ) 143 | embeddings, api_status = model.embed("Test sentence.") 144 | 145 | assert embeddings is None 146 | assert isinstance(api_status, APIStatus) 147 | assert not api_status.success 148 | assert api_status.error_info is not None 149 | # The error message could be about model invalidity or internal server error 150 | assert any(phrase in api_status.error_info.error_message.lower() for phrase in [ 151 | "model", "not found", "invalid", "does not exist", "internal server error" 152 | ]) 153 | assert api_status.component == "cohere_embedding" 154 | assert api_status.error_info.error_type is not None 155 | assert api_status.error_info.stacktrace is not None 156 | assert api_status.error_info.file_path is not None 157 | 158 | def test_invalid_openai_model(): 159 | model = EmbeddingModel( 160 | provider="openai", 161 | model_name="invalid-model", 162 | dimensions=1536, 163 | encoding_format="float" 164 | ) 165 | embeddings, api_status = model.embed("Test sentence.") 166 | 167 | assert embeddings is None 168 | assert isinstance(api_status, APIStatus) 169 | assert not api_status.success 170 | assert api_status.error_info is not None 171 | assert "model" in api_status.error_info.error_message.lower() 172 | assert api_status.component == "openai_embedding" 173 | assert api_status.error_info.error_type is not None 174 | assert api_status.error_info.stacktrace is not None 175 | assert api_status.error_info.file_path is not None 176 | 177 | def test_invalid_provider(): 178 | with pytest.raises(ValueError, match="Unsupported provider"): 179 | EmbeddingModel( 180 | provider="invalid", 181 | model_name="some-model", 182 | dimensions=1024, 183 | encoding_format="float" 184 | ) 185 | 186 | def test_missing_input_type_cohere(): 187 | with pytest.raises(ValueError, match="input_type must be specified for Cohere embeddings"): 188 | EmbeddingModel( 189 | provider="cohere", 190 | model_name="embed-english-v3.0", 191 | dimensions=1024, 192 | encoding_format="float" 193 | ) 194 | 195 | def test_invalid_input_type_cohere(): 196 | with pytest.raises(ValueError, match="Invalid input_type"): 197 | EmbeddingModel( 198 | provider="cohere", 199 | model_name="embed-english-v3.0", 200 | dimensions=1024, 201 | encoding_format="float", 202 | input_type="invalid_type" 203 | ) -------------------------------------------------------------------------------- /tests/test_error_propagation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from dotenv import load_dotenv 5 | 6 | from tests.test_config import TestConfig 7 | from wandbot.chat.rag import RAGPipeline 8 | from wandbot.configs.chat_config import ChatConfig 9 | from wandbot.configs.vector_store_config import VectorStoreConfig 10 | from wandbot.models.embedding import EmbeddingModel 11 | from wandbot.rag.retrieval import FusionRetrievalEngine 12 | from wandbot.retriever.chroma import ChromaVectorStore 13 | from wandbot.schema.api_status import APIStatus 14 | from wandbot.schema.document import Document 15 | 16 | # Load environment variables from .env in project root 17 | ENV_PATH = Path(__file__).parent.parent / '.env' 18 | load_dotenv(ENV_PATH, override=True) 19 | 20 | @pytest.fixture 21 | def chat_config(): 22 | # Create base config 
with all settings 23 | config = ChatConfig() 24 | 25 | # Override with test retry settings 26 | test_config = TestConfig() 27 | 28 | # Update only retry-related settings that exist in ChatConfig 29 | config.llm_max_retries = test_config.llm_max_retries 30 | config.llm_retry_min_wait = test_config.llm_retry_min_wait 31 | config.llm_retry_max_wait = test_config.llm_retry_max_wait 32 | config.llm_retry_multiplier = test_config.llm_retry_multiplier 33 | 34 | config.embedding_max_retries = test_config.embedding_max_retries 35 | config.embedding_retry_min_wait = test_config.embedding_retry_min_wait 36 | config.embedding_retry_max_wait = test_config.embedding_retry_max_wait 37 | config.embedding_retry_multiplier = test_config.embedding_retry_multiplier 38 | 39 | config.reranker_max_retries = test_config.reranker_max_retries 40 | config.reranker_retry_min_wait = test_config.reranker_retry_min_wait 41 | config.reranker_retry_max_wait = test_config.reranker_retry_max_wait 42 | config.reranker_retry_multiplier = test_config.reranker_retry_multiplier 43 | 44 | return config 45 | 46 | @pytest.fixture 47 | def vector_store_config(): 48 | return VectorStoreConfig() 49 | 50 | @pytest.fixture 51 | def embedding_model(vector_store_config, chat_config): 52 | return EmbeddingModel( 53 | provider=vector_store_config.embeddings_provider, 54 | model_name=vector_store_config.embeddings_model_name, 55 | dimensions=vector_store_config.embeddings_dimensions, 56 | encoding_format=vector_store_config.embeddings_encoding_format, 57 | max_retries=chat_config.embedding_max_retries 58 | ) 59 | 60 | @pytest.fixture 61 | def vector_store(embedding_model, vector_store_config, chat_config): 62 | return ChromaVectorStore( 63 | embedding_function=embedding_model, 64 | vector_store_config=vector_store_config, 65 | chat_config=chat_config 66 | ) 67 | 68 | @pytest.fixture 69 | def retrieval_engine(vector_store, chat_config): 70 | return FusionRetrievalEngine( 71 | vector_store=vector_store, 72 | chat_config=chat_config 73 | ) 74 | 75 | @pytest.fixture 76 | def rag_pipeline(vector_store, chat_config): 77 | return RAGPipeline( 78 | vector_store=vector_store, 79 | chat_config=chat_config 80 | ) 81 | 82 | def test_successful_embedding(vector_store_config, chat_config): 83 | """Test successful embedding with API status""" 84 | model = EmbeddingModel( 85 | provider=vector_store_config.embeddings_provider, 86 | model_name=vector_store_config.embeddings_model_name, 87 | dimensions=vector_store_config.embeddings_dimensions, 88 | encoding_format=vector_store_config.embeddings_encoding_format, 89 | max_retries=chat_config.embedding_max_retries 90 | ) 91 | 92 | embeddings, api_status = model.embed("test query") 93 | 94 | assert embeddings is not None 95 | assert isinstance(api_status, APIStatus) 96 | assert api_status.success 97 | assert api_status.error_info is None 98 | assert api_status.component == f"{vector_store_config.embeddings_provider}_embedding" 99 | 100 | def test_invalid_embedding_model(vector_store_config, chat_config): 101 | """Test error propagation with invalid embedding model""" 102 | model = EmbeddingModel( 103 | provider=vector_store_config.embeddings_provider, 104 | model_name="invalid-model", 105 | dimensions=vector_store_config.embeddings_dimensions, 106 | encoding_format=vector_store_config.embeddings_encoding_format, 107 | max_retries=chat_config.embedding_max_retries 108 | ) 109 | 110 | embeddings, api_status = model.embed("test query") 111 | 112 | assert embeddings is None 113 | assert isinstance(api_status, 
APIStatus) 114 | assert not api_status.success 115 | assert api_status.error_info is not None 116 | assert "does not exist" in api_status.error_info.error_message.lower() 117 | assert api_status.component == f"{vector_store_config.embeddings_provider}_embedding" 118 | assert api_status.error_info.stacktrace is not None 119 | assert api_status.error_info.file_path is not None 120 | 121 | @pytest.mark.asyncio 122 | async def test_successful_reranking(retrieval_engine, chat_config): 123 | """Test successful reranking with API status""" 124 | docs = [ 125 | Document(page_content="test doc 1", metadata={"id": "1"}), 126 | Document(page_content="test doc 2", metadata={"id": "2"}) 127 | ] 128 | 129 | reranked_docs, api_status = await retrieval_engine._async_rerank_results( 130 | query="test query", 131 | context=docs, 132 | top_k=chat_config.top_k, 133 | language="en" 134 | ) 135 | 136 | assert len(reranked_docs) == 2 137 | assert isinstance(api_status, APIStatus) 138 | assert api_status.success 139 | assert api_status.error_info is None 140 | assert api_status.component == "reranker_api" 141 | 142 | @pytest.mark.asyncio 143 | async def test_invalid_reranker_model(retrieval_engine, chat_config): 144 | """Test error propagation with invalid reranker model""" 145 | # Save original model name 146 | original_model = chat_config.english_reranker_model 147 | chat_config.english_reranker_model = "invalid-model" 148 | 149 | docs = [ 150 | Document(page_content="test doc 1", metadata={"id": "1"}), 151 | Document(page_content="test doc 2", metadata={"id": "2"}) 152 | ] 153 | 154 | reranked_docs, api_status = await retrieval_engine._async_rerank_results( 155 | query="test query", 156 | context=docs, 157 | top_k=chat_config.top_k, 158 | language="en" 159 | ) 160 | 161 | # Restore original model name 162 | chat_config.english_reranker_model = original_model 163 | 164 | assert len(reranked_docs) == 0 # Empty list returned on error 165 | assert isinstance(api_status, APIStatus) 166 | assert not api_status.success 167 | assert api_status.error_info is not None 168 | assert api_status.component == "reranker_api" 169 | assert api_status.error_info.stacktrace is not None 170 | assert api_status.error_info.file_path is not None 171 | 172 | def test_error_propagation_in_retrieval(retrieval_engine, chat_config): 173 | """Test error propagation through the retrieval pipeline using MMR search""" 174 | # Save original model name 175 | original_model = chat_config.english_reranker_model 176 | chat_config.english_reranker_model = "invalid-model" 177 | 178 | inputs = { 179 | "standalone_query": "test query", 180 | "all_queries": ["test query"], 181 | "language": "en" 182 | } 183 | 184 | results = retrieval_engine.vectorstore.max_marginal_relevance_search( 185 | query_texts=inputs["all_queries"], 186 | top_k=chat_config.top_k, 187 | fetch_k=chat_config.fetch_k, 188 | lambda_mult=chat_config.mmr_lambda_mult 189 | ) 190 | 191 | # Restore original model name 192 | chat_config.english_reranker_model = original_model 193 | 194 | # Check API status objects are properly propagated 195 | assert "_embedding_status" in results 196 | assert isinstance(results["_embedding_status"], APIStatus) 197 | assert results["_embedding_status"].component == f"{retrieval_engine.vectorstore.embedding_function.model.COMPONENT_NAME}" 198 | 199 | # Check that we got results 200 | assert len(results) > 0 201 | 202 | def test_error_propagation_in_similarity_search(retrieval_engine, chat_config): 203 | """Test error propagation through the retrieval 
pipeline using similarity search""" 204 | # Save original model name 205 | original_model = chat_config.english_reranker_model 206 | chat_config.english_reranker_model = "invalid-model" 207 | 208 | inputs = { 209 | "standalone_query": "test query", 210 | "all_queries": ["test query"], 211 | "language": "en" 212 | } 213 | 214 | results = retrieval_engine.vectorstore.similarity_search( 215 | query_texts=inputs["all_queries"], 216 | top_k=chat_config.top_k 217 | ) 218 | 219 | # Restore original model name 220 | chat_config.english_reranker_model = original_model 221 | 222 | # Check that we got results for our query 223 | assert inputs["standalone_query"] in results 224 | assert len(results[inputs["standalone_query"]]) > 0 225 | 226 | # Check API status is properly propagated 227 | assert "_embedding_status" in results 228 | assert isinstance(results["_embedding_status"], APIStatus) 229 | assert results["_embedding_status"].component == f"{retrieval_engine.vectorstore.embedding_function.model.COMPONENT_NAME}" -------------------------------------------------------------------------------- /tests/test_model_config.py: -------------------------------------------------------------------------------- 1 | """Test configurations for language models.""" 2 | 3 | from typing import List 4 | 5 | # Available Models for Testing 6 | ANTHROPIC_MODELS: List[str] = [ 7 | "claude-3-5-sonnet-20241022", 8 | "claude-3-5-haiku-20241022" 9 | ] 10 | 11 | OPENAI_MODELS: List[str] = [ 12 | "gpt-4-1106-preview", 13 | "gpt-4o-mini-2024-07-18", 14 | "gpt-4o-2024-08-06", 15 | "o1-2024-12-17", 16 | # "o1-mini-2024-09-12", 17 | "o3-mini-2025-01-31" 18 | ] 19 | 20 | # Default configurations for each provider in tests 21 | MODEL_CONFIGS = { 22 | "anthropic": { 23 | "provider": "anthropic", 24 | "temperature": 0.7, 25 | }, 26 | "openai": { 27 | "provider": "openai", 28 | "temperature": 0.7, 29 | } 30 | } 31 | 32 | # Test configurations 33 | TEST_CONFIG = { 34 | "primary": { 35 | "provider": "anthropic", 36 | "model_name": "claude-3-5-sonnet-20241022", 37 | "temperature": 0.7, 38 | }, 39 | "fallback": { 40 | "provider": "openai", 41 | "model_name": "gpt-4-1106-preview", 42 | "temperature": 0.7, 43 | } 44 | } 45 | 46 | # Invalid model configurations for testing fallback behavior 47 | TEST_INVALID_MODELS = { 48 | "primary": { 49 | "provider": "anthropic", 50 | "model_name": "invalid-model-1", 51 | "temperature": 0.7, 52 | }, 53 | "fallback": { 54 | "provider": "openai", 55 | "model_name": "invalid-model-2", 56 | "temperature": 0.7, 57 | } 58 | } -------------------------------------------------------------------------------- /tests/test_query_handler.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, patch 2 | 3 | import pytest 4 | from pydantic import ValidationError 5 | 6 | from wandbot.rag.query_handler import EnhancedQuery, Labels, QueryEnhancer, clean_question, format_chat_history 7 | from wandbot.schema.api_status import APIStatus, ErrorInfo 8 | 9 | 10 | @pytest.fixture 11 | def mock_llm_model(): 12 | with patch('wandbot.rag.query_handler.LLMModel') as mock: 13 | instance = mock.return_value 14 | instance.create = AsyncMock() 15 | yield instance 16 | 17 | @pytest.fixture 18 | def query_enhancer(mock_llm_model): 19 | return QueryEnhancer( 20 | model_name="gpt-4", 21 | temperature=0, 22 | fallback_model_name="gpt-3.5-turbo", 23 | fallback_temperature=0 24 | ) 25 | 26 | @pytest.fixture 27 | def sample_enhanced_query(): 28 | return { 29 | "language": 
"en", 30 | "intents": [ 31 | { 32 | "reasoning": "The query is about using W&B features", 33 | "label": Labels.PRODUCT_FEATURES 34 | } 35 | ], 36 | "keywords": [ 37 | {"keyword": "wandb features"} 38 | ], 39 | "sub_queries": [ 40 | {"query": "How to use wandb?"} 41 | ], 42 | "vector_search_queries": [ 43 | {"query": "wandb basic usage"} 44 | ], 45 | "standalone_query": "How do I use wandb?" 46 | } 47 | 48 | def test_clean_question(): 49 | assert clean_question("<@U123456> hello") == "hello" 50 | assert clean_question("@bot how are you?") == "how are you?" 51 | assert clean_question("regular question") == "regular question" 52 | 53 | def test_format_chat_history(): 54 | history = [ 55 | ("Hello", "Hi there!"), 56 | ("How are you?", "I'm good!") 57 | ] 58 | formatted = format_chat_history(history) 59 | assert "User: Hello" in formatted 60 | assert "Assistant: Hi there!" in formatted 61 | assert "User: How are you?" in formatted 62 | assert "Assistant: I'm good!" in formatted 63 | 64 | assert format_chat_history(None) == "No chat history available." 65 | assert format_chat_history([]) == "No chat history available." 66 | 67 | @pytest.mark.asyncio 68 | async def test_query_enhancer_success(query_enhancer, mock_llm_model, sample_enhanced_query): 69 | # Mock successful response 70 | enhanced_query = EnhancedQuery(**sample_enhanced_query) 71 | api_status = APIStatus(component="query_enhancer_llm_api", success=True) 72 | mock_llm_model.create.return_value = (enhanced_query, api_status) 73 | 74 | result = await query_enhancer({"query": "How do I use wandb?"}) 75 | 76 | assert result is not None 77 | assert result["api_statuses"]["query_enhancer_llm_api"].success 78 | assert result["api_statuses"]["query_enhancer_llm_api"].error_info is None 79 | 80 | @pytest.mark.asyncio 81 | async def test_query_enhancer_validation_error_retry(query_enhancer, mock_llm_model, sample_enhanced_query): 82 | # Mock validation error then success 83 | enhanced_query = EnhancedQuery(**sample_enhanced_query) 84 | error_status = APIStatus(component="query_enhancer_llm_api", success=False, error_info=ErrorInfo(has_error=True, error_message="Validation error")) 85 | success_status = APIStatus(component="query_enhancer_llm_api", success=True) 86 | 87 | mock_llm_model.create.side_effect = [ 88 | (None, error_status), 89 | (enhanced_query, success_status) 90 | ] 91 | 92 | result = await query_enhancer({"query": "How do I use wandb?"}) 93 | 94 | assert result is not None 95 | assert result["api_statuses"]["query_enhancer_llm_api"].success 96 | assert result["api_statuses"]["query_enhancer_llm_api"].error_info is None 97 | 98 | @pytest.mark.asyncio 99 | async def test_query_enhancer_llm_error(query_enhancer, mock_llm_model): 100 | # Mock LLM error that persists through retries 101 | error_status = APIStatus(component="query_enhancer_llm_api", success=False, error_info=ErrorInfo(has_error=True, error_message="API error")) 102 | mock_llm_model.create.return_value = (None, error_status) 103 | 104 | # Mock the fallback model to also fail 105 | query_enhancer.fallback_model.create.return_value = (None, error_status) 106 | 107 | with pytest.raises(Exception) as exc_info: 108 | await query_enhancer({"query": "How do I use wandb?"}) 109 | 110 | # The error will be wrapped in a retry error message 111 | assert "API error" in str(exc_info.value) 112 | 113 | @pytest.mark.asyncio 114 | async def test_query_enhancer_fallback(query_enhancer, mock_llm_model, sample_enhanced_query): 115 | # Mock primary model failure and fallback success 116 | 
primary_model = mock_llm_model 117 | fallback_model = AsyncMock() 118 | 119 | enhanced_query = EnhancedQuery(**sample_enhanced_query) 120 | error_status = APIStatus(component="query_enhancer_llm_api", success=False, error_info=ErrorInfo(has_error=True, error_message="Primary failed")) 121 | success_status = APIStatus(component="query_enhancer_llm_api", success=True) 122 | 123 | primary_model.create.return_value = (None, error_status) 124 | fallback_model.create.return_value = (enhanced_query, success_status) 125 | 126 | query_enhancer.fallback_model = fallback_model 127 | 128 | result = await query_enhancer({"query": "How do I use wandb?"}) 129 | 130 | assert result is not None 131 | assert result["api_statuses"]["query_enhancer_llm_api"].success 132 | assert result["api_statuses"]["query_enhancer_llm_api"].error_info is None -------------------------------------------------------------------------------- /tests/test_response_synthesis.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Dict, List 3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | import pytest 6 | 7 | from tests.test_model_config import TEST_CONFIG, TEST_INVALID_MODELS 8 | from wandbot.models.llm import LLMModel 9 | from wandbot.rag.response_synthesis import ResponseSynthesizer 10 | from wandbot.schema.api_status import APIStatus 11 | from wandbot.schema.document import Document 12 | from wandbot.schema.retrieval import RetrievalResult 13 | 14 | 15 | @pytest.fixture 16 | def mock_retrieval_result(): 17 | return RetrievalResult( 18 | documents=[Document( 19 | page_content="Test content", 20 | metadata={ 21 | "source": "test_source.md", 22 | "source_type": "documentation", 23 | "has_code": False 24 | } 25 | )], 26 | retrieval_info={ 27 | "query": "How do I use wandb?", 28 | "language": "en", 29 | "intents": ["test intent"], 30 | "sub_queries": [] 31 | } 32 | ) 33 | 34 | @pytest.fixture 35 | def mock_api_status(): 36 | return APIStatus( 37 | success=True, 38 | error_info=None, 39 | request_id="test_id", 40 | model_info={"model": "test-model"}, 41 | component="response_synthesis" 42 | ) 43 | 44 | @pytest.fixture 45 | def mock_llm_model(): 46 | with patch('wandbot.rag.response_synthesis.LLMModel') as mock: 47 | instance = mock.return_value 48 | instance.model_name = "test_model" 49 | instance.create = AsyncMock() 50 | yield instance 51 | 52 | @pytest.fixture 53 | def synthesizer(): 54 | synth = ResponseSynthesizer( 55 | primary_provider=TEST_CONFIG["primary"]["provider"], 56 | primary_model_name=TEST_CONFIG["primary"]["model_name"], 57 | primary_temperature=TEST_CONFIG["primary"]["temperature"], 58 | fallback_provider=TEST_CONFIG["fallback"]["provider"], 59 | fallback_model_name=TEST_CONFIG["fallback"]["model_name"], 60 | fallback_temperature=TEST_CONFIG["fallback"]["temperature"], 61 | max_retries=1 # Reduce retries for faster tests 62 | ) 63 | 64 | # Mock only the get_messages method to avoid prompt template issues 65 | synth.get_messages = MagicMock(return_value=[ 66 | {"role": "system", "content": "You are a helpful AI assistant."}, 67 | {"role": "user", "content": "What is 2+2?"} 68 | ]) 69 | 70 | return synth 71 | 72 | @pytest.mark.asyncio 73 | async def test_successful_primary_model(synthesizer, mock_retrieval_result): 74 | """Test when primary model succeeds""" 75 | result = await synthesizer(mock_retrieval_result) 76 | 77 | assert result["response"] is not None 78 | assert result["response_model"] == 
TEST_CONFIG["primary"]["model_name"] 79 | assert "response_synthesis_llm_messages" in result 80 | assert result["api_statuses"]["response_synthesis_llm_api"].success == True 81 | assert isinstance(result["response"], str) 82 | assert len(result["response"]) > 0 83 | 84 | @pytest.mark.asyncio 85 | async def test_fallback_to_secondary_model(synthesizer, mock_retrieval_result): 86 | """Test fallback when primary model fails""" 87 | # Make the primary model fail by using an invalid model name 88 | synthesizer.model = LLMModel( 89 | provider=TEST_INVALID_MODELS["primary"]["provider"], 90 | model_name=TEST_INVALID_MODELS["primary"]["model_name"], 91 | temperature=TEST_INVALID_MODELS["primary"]["temperature"], 92 | max_retries=1 93 | ) 94 | 95 | result = await synthesizer(mock_retrieval_result) 96 | 97 | assert result["response"] is not None 98 | assert result["response_model"] == TEST_CONFIG["fallback"]["model_name"] 99 | assert "response_synthesis_llm_messages" in result 100 | assert result["api_statuses"]["response_synthesis_llm_api"].success == True 101 | assert isinstance(result["response"], str) 102 | assert len(result["response"]) > 0 103 | 104 | @pytest.mark.asyncio 105 | async def test_both_models_fail(synthesizer, mock_retrieval_result): 106 | """Test behavior when both models fail""" 107 | # Make both models fail by using invalid model names 108 | synthesizer.model = LLMModel( 109 | provider=TEST_INVALID_MODELS["primary"]["provider"], 110 | model_name=TEST_INVALID_MODELS["primary"]["model_name"], 111 | temperature=TEST_INVALID_MODELS["primary"]["temperature"], 112 | max_retries=1 113 | ) 114 | synthesizer.fallback_model = LLMModel( 115 | provider=TEST_INVALID_MODELS["fallback"]["provider"], 116 | model_name=TEST_INVALID_MODELS["fallback"]["model_name"], 117 | temperature=TEST_INVALID_MODELS["fallback"]["temperature"], 118 | max_retries=1 119 | ) 120 | 121 | with pytest.raises(Exception) as exc_info: 122 | await synthesizer(mock_retrieval_result) 123 | 124 | assert "Response synthesis failed" in str(exc_info.value) 125 | 126 | @pytest.mark.asyncio 127 | async def test_input_formatting(synthesizer, mock_retrieval_result): 128 | """Test that inputs are formatted correctly""" 129 | result = await synthesizer(mock_retrieval_result) 130 | 131 | # Check that query and context are formatted correctly 132 | assert "query_str" in result 133 | assert "context_str" in result 134 | assert mock_retrieval_result.retrieval_info["query"] in result["query_str"] 135 | assert "Test content" in result["context_str"] 136 | 137 | # Verify the response is meaningful 138 | assert result["response"] is not None 139 | assert isinstance(result["response"], str) 140 | assert len(result["response"]) > 0 141 | 142 | if __name__ == "__main__": 143 | asyncio.run(test_successful_primary_model(synthesizer, mock_retrieval_result)) -------------------------------------------------------------------------------- /tests/test_retrieval.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | 5 | from wandbot.configs.chat_config import ChatConfig 6 | from wandbot.rag.retrieval import FusionRetrievalEngine 7 | from wandbot.retriever.base import VectorStore 8 | 9 | 10 | @pytest.fixture 11 | def retrieval_engine(): 12 | chat_config = ChatConfig() 13 | return FusionRetrievalEngine(chat_config=chat_config) 14 | 15 | @pytest.mark.asyncio 16 | async def test_retrieval_input_validation(retrieval_engine): 17 | # Test missing all_queries 18 | with 
-------------------------------------------------------------------------------- /tests/test_retrieval.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | 5 | from wandbot.configs.chat_config import ChatConfig 6 | from wandbot.rag.retrieval import FusionRetrievalEngine 7 | from wandbot.retriever.base import VectorStore 8 | 9 | 10 | @pytest.fixture 11 | def retrieval_engine(): 12 | chat_config = ChatConfig() 13 | return FusionRetrievalEngine(chat_config=chat_config) 14 | 15 | @pytest.mark.asyncio 16 | async def test_retrieval_input_validation(retrieval_engine): 17 | # Test missing all_queries 18 | with pytest.raises(ValueError, match="Missing required key 'all_queries'"): 19 | await retrieval_engine._run_retrieval_common({"standalone_query": "test"}, use_async=True) 20 | 21 | # Test missing standalone_query 22 | with pytest.raises(ValueError, match="Missing required key 'standalone_query'"): 23 | await retrieval_engine._run_retrieval_common({"all_queries": ["test"]}, use_async=True) 24 | 25 | # Test wrong type for all_queries 26 | with pytest.raises(TypeError, match="Expected 'all_queries' to be a list or tuple"): 27 | await retrieval_engine._run_retrieval_common( 28 | {"all_queries": "not a list", "standalone_query": "test"}, 29 | use_async=True 30 | ) 31 | 32 | # Test wrong type for standalone_query 33 | with pytest.raises(TypeError, match="Expected 'standalone_query' to be a string"): 34 | await retrieval_engine._run_retrieval_common( 35 | {"all_queries": ["test"], "standalone_query": ["not a string"]}, 36 | use_async=True 37 | ) 38 | 39 | # Test non-dict input 40 | with pytest.raises(TypeError, match="Expected inputs to be a dictionary"): 41 | await retrieval_engine._run_retrieval_common("not a dict", use_async=True) --------------------------------------------------------------------------------
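Taken together, these assertions define the input contract for `_run_retrieval_common`: a dict containing a string `standalone_query` and a list (or tuple) `all_queries`. A minimal sketch of a payload that satisfies the contract, assuming an engine wired up as in the fixture above; the query strings, and the `language` key that the error-propagation tests also pass, are illustrative:

```python
inputs = {
    "standalone_query": "How do I resume a crashed wandb run?",  # must be a str
    "all_queries": [                                             # must be a list or tuple
        "How do I resume a crashed wandb run?",
        "wandb resume run after crash",
    ],
    "language": "en",  # optional in the validation above; illustrative here
}
# results = await retrieval_engine._run_retrieval_common(inputs, use_async=True)
```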