├── mcpb ├── requirements.txt ├── .gitignore ├── .dxtignore ├── icon.png ├── README.md ├── manifest.json └── server.py ├── kumo_rfm_mcp ├── _version.py ├── tools │ ├── __init__.py │ ├── auth.py │ ├── io.py │ ├── docs.py │ ├── model.py │ └── graph.py ├── __init__.py ├── session.py ├── resources │ ├── overview.md │ ├── explainability.md │ ├── graph-setup.md │ └── predictive-query.md ├── server.py └── config.py ├── .gitignore ├── test ├── tools │ ├── test_docs.py │ ├── test_io.py │ ├── test_model.py │ └── test_graph.py └── conftest.py ├── .github └── workflows │ ├── mcpb.yml │ ├── test.yml │ ├── lint.yml │ ├── _release.yml │ └── release.yml ├── LICENSE ├── .pre-commit-config.yaml ├── pyproject.toml └── README.md /mcpb/requirements.txt: -------------------------------------------------------------------------------- 1 | kumo-rfm-mcp==0.2.0 2 | -------------------------------------------------------------------------------- /mcpb/.gitignore: -------------------------------------------------------------------------------- 1 | *.dxt 2 | sys_venv/ 3 | uv_venv/ 4 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.3.0.dev0' 2 | -------------------------------------------------------------------------------- /mcpb/.dxtignore: -------------------------------------------------------------------------------- 1 | README.md 2 | sys_venv/ 3 | uv_venv/ 4 | -------------------------------------------------------------------------------- /mcpb/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumo-ai/kumo-rfm-mcp/HEAD/mcpb/icon.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[oc] 3 | build/ 4 | dist/ 
5 | wheels/ 6 | *.egg-info 7 | .venv 8 | .env 9 | .mypy_cache/ 10 | tmp/ 11 | mcp.json 12 | server-info.json 13 | .DS_Store 14 | -------------------------------------------------------------------------------- /mcpb/README.md: -------------------------------------------------------------------------------- 1 | # MCPB Packaging 2 | 3 | 1. Install: 4 | ```bash 5 | npm install -g @anthropic-ai/dxt 6 | ``` 7 | 1. Create `dxt` file: 8 | ```bash 9 | dxt pack 10 | ``` 11 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .docs import register_docs_tools 2 | from .auth import register_auth_tools 3 | from .io import register_io_tools 4 | from .graph import register_graph_tools 5 | from .model import register_model_tools 6 | 7 | __all__ = [ 8 | 'register_docs_tools', 9 | 'register_auth_tools', 10 | 'register_io_tools', 11 | 'register_graph_tools', 12 | 'register_model_tools', 13 | ] 14 | -------------------------------------------------------------------------------- /test/tools/test_docs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fastmcp import Client 3 | 4 | from kumo_rfm_mcp.server import mcp 5 | 6 | 7 | @pytest.mark.asyncio 8 | async def test_get_docs() -> None: 9 | async with Client(mcp) as client: 10 | result = await client.call_tool('get_docs', { 11 | 'resource_uri': 'kumo://docs/overview', 12 | }) 13 | assert result.content[0].text.startswith('# Overview of KumoRFM\n') 14 | 15 | result = await client.call_tool( 16 | 'get_docs', { 17 | 'resource_uri': 'kumo://docs/graph-setup', 18 | }) 19 | assert result.content[0].text.startswith('# Graph Setup\n') 20 | 21 | result = await client.call_tool( 22 | 'get_docs', { 23 | 'resource_uri': 'kumo://docs/predictive-query', 24 | }) 25 | assert result.content[0].text.startswith('# Predictive Query\n') 26 | 
-------------------------------------------------------------------------------- /.github/workflows/mcpb.yml: -------------------------------------------------------------------------------- 1 | name: MCPBundle 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} 10 | # Only cancel intermediate builds if on a PR: 11 | cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} 12 | 13 | jobs: 14 | mcpb: 15 | runs-on: ${{ matrix.os }} 16 | 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ubuntu-latest, windows-latest, macos-latest] 21 | 22 | steps: 23 | - name: Checkout repository 24 | uses: actions/checkout@v5 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | 31 | - name: Run MCPB 32 | run: python mcpb/server.py 33 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | from .config import ( 3 | TableSource, 4 | TableSourcePreview, 5 | TableMetadata, 6 | AddTableMetadata, 7 | UpdateTableMetadata, 8 | LinkMetadata, 9 | GraphMetadata, 10 | UpdateGraphMetadata, 11 | UpdatedGraphMetadata, 12 | MaterializedGraphInfo, 13 | PredictResponse, 14 | EvaluateResponse, 15 | ExplanationResponse, 16 | ) 17 | from .session import Session, SessionManager 18 | 19 | __all__ = [ 20 | '__version__', 21 | 'TableSource', 22 | 'TableSourcePreview', 23 | 'TableMetadata', 24 | 'AddTableMetadata', 25 | 'UpdateTableMetadata', 26 | 'LinkMetadata', 27 | 'GraphMetadata', 28 | 'UpdateGraphMetadata', 29 | 'UpdatedGraphMetadata', 30 | 'MaterializedGraphInfo', 31 | 'PredictResponse', 32 | 'EvaluateResponse', 33 | 'ExplanationResponse', 
34 | 'Session', 35 | 'SessionManager', 36 | ] 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} 10 | # Only cancel intermediate builds if on a PR: 11 | cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} 12 | 13 | jobs: 14 | pytest: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v5 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: '3.10' 25 | 26 | - name: Install uv 27 | uses: astral-sh/setup-uv@v6 28 | with: 29 | python-version: '3.10' 30 | activate-environment: true 31 | 32 | - name: Install Kumo RFM MCP 33 | run: uv pip install '.[dev]' 34 | 35 | - name: Run test suite 36 | run: uv run --no-project pytest 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Kumo.ai, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice (including the next paragraph) shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }} 10 | # Only cancel intermediate builds if on a PR: 11 | cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} 12 | 13 | jobs: 14 | mypy: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v5 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: '3.10' 25 | 26 | - name: Install uv 27 | uses: astral-sh/setup-uv@v6 28 | with: 29 | python-version: '3.10' 30 | activate-environment: true 31 | 32 | - name: Install Kumo RFM MCP 33 | run: uv pip install . 
34 | 35 | - name: Install dependencies 36 | run: uv pip install mypy 37 | 38 | - name: Check type hints 39 | run: uv run --no-project mypy 40 | -------------------------------------------------------------------------------- /test/tools/test_io.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from kumo_rfm_mcp.tools.io import find_table_files, inspect_table_files 6 | 7 | 8 | @pytest.mark.asyncio 9 | @pytest.mark.parametrize('recursive', [False, True]) 10 | async def test_find_table_files(root_dir: Path, recursive: bool) -> None: 11 | sources = await find_table_files(root_dir, recursive) 12 | assert len(sources) == 3 13 | filenames = {source.path.name for source in sources} 14 | assert filenames == {'USERS.csv', 'ORDERS.parquet', 'STORES.csv'} 15 | 16 | 17 | @pytest.mark.asyncio 18 | async def test_inspect_table_files(root_dir: Path) -> None: 19 | previews = await inspect_table_files( 20 | paths=[(root_dir / 'USERS.csv').as_posix()], 21 | num_rows=4, 22 | ) 23 | assert len(previews) == 1 24 | preview = previews[(root_dir / 'USERS.csv').as_posix()] 25 | assert preview.rows == [ 26 | { 27 | 'USER_ID': 0, 28 | 'AGE': 20.0, 29 | 'GENDER': 'male', 30 | }, 31 | { 32 | 'USER_ID': 1, 33 | 'AGE': 30.0, 34 | 'GENDER': 'female', 35 | }, 36 | { 37 | 'USER_ID': 2, 38 | 'AGE': 40.0, 39 | 'GENDER': 'female', 40 | }, 41 | { 42 | 'USER_ID': 3, 43 | 'AGE': None, 44 | 'GENDER': None, 45 | }, 46 | ] 47 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/tools/auth.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import Literal 4 | 5 | from fastmcp import FastMCP 6 | from fastmcp.exceptions import ToolError 7 | from kumoai.experimental import rfm 8 | 9 | from kumo_rfm_mcp import SessionManager 10 | 11 | 12 | async def authenticate( 13 | ) -> Literal["KumoRFM session 
successfully authenticated"]: 14 | """Authenticate the current KumoRFM session. 15 | 16 | Authentication is needed once before predicting or evaluating with the 17 | KumoRFM model. If the 'KUMO_API_KEY' environment variable is not set, 18 | initiates an OAuth2 authentication flow by opening a browser window for 19 | user login. Sets the 'KUMO_API_KEY' environment variable upon successful 20 | authentication. 21 | """ 22 | session = SessionManager.get_default_session() 23 | 24 | if session.is_initialized: 25 | raise ToolError("KumoRFM session is already authenticated") 26 | 27 | if os.getenv('KUMO_API_KEY') in {None, '', '${user_config.KUMO_API_KEY}'}: 28 | try: 29 | await asyncio.to_thread(rfm.authenticate) 30 | except Exception as e: 31 | raise ToolError( 32 | f"Failed to authenticate KumoRFM session: {e}") from e 33 | 34 | session.initialize() 35 | return "KumoRFM session successfully authenticated" 36 | 37 | 38 | def register_auth_tools(mcp: FastMCP) -> None: 39 | """Register all authentication tools to the MCP server.""" 40 | mcp.tool(annotations=dict( 41 | title="🔑 Signing in to KumoRFM…", 42 | readOnlyHint=False, 43 | destructiveHint=False, 44 | idempotentHint=False, 45 | openWorldHint=False, 46 | ))(authenticate) 47 | -------------------------------------------------------------------------------- /mcpb/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "dxt_version": "0.1", 3 | "name": "kumo-rfm-mcp", 4 | "display_name": "KumoRFM", 5 | "version": "0.1.0", 6 | "description": "KumoRFM MCP Server", 7 | "long_description": "KumoRFM is a pre-trained Relational Foundation Model (RFM) that generates training-free predictions on any relational multi-table data by interpreting the data as a (temporal) heterogeneous graph. 
It can be queried via the Predictive Query Language (PQL).", 8 | "author": { 9 | "name": "Kumo.AI", 10 | "email": "hello@kumo.ai" 11 | }, 12 | "repository": { 13 | "type": "git", 14 | "url": "https://github.com/kumo-ai/kumo-rfm-mcp" 15 | }, 16 | "homepage": "https://kumorfm.ai", 17 | "documentation": "https://kumo-ai.github.io/kumo-sdk/docs", 18 | "support": "https://github.com/kumo-ai/kumo-rfm-mcp/issues", 19 | "icon": "icon.png", 20 | "server": { 21 | "type": "python", 22 | "entry_point": "server.py", 23 | "mcp_config": { 24 | "command": "python", 25 | "args": ["${__dirname}/server.py"], 26 | "env": { 27 | "KUMO_API_KEY": "${user_config.KUMO_API_KEY}" 28 | }, 29 | "platform_overrides": { 30 | "darwin": { 31 | "command": "python3" 32 | } 33 | } 34 | } 35 | }, 36 | "keywords": [ 37 | "rfm", 38 | "mcp", 39 | "foundation model", 40 | "relational data", 41 | "kumo" 42 | ], 43 | "license": "MIT", 44 | "user_config": { 45 | "KUMO_API_KEY": { 46 | "type": "string", 47 | "title": "KumoRFM API key", 48 | "description": "You can generate your KumoRFM API key for free at https://kumorfm.ai. If not provided, you can authenticate on-the-fly in individual sessions.", 49 | "required": false, 50 | "sensitive": true 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /mcpb/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import subprocess 4 | import sys 5 | from pathlib import Path 6 | 7 | ROOT = Path(__file__).parent 8 | SYS_VENV = ROOT / 'sys_venv' 9 | UV_VENV = ROOT / 'uv_venv' 10 | REQUIREMENTS = ROOT / 'requirements.txt' 11 | 12 | 13 | def get_venv() -> Path: 14 | """Create two virtual environments for KumoRFM MCP. 15 | 1. A venv that inherits system Python version and installs `uv`. 16 | 2. A `uv` venv with Python 3.11. 
17 | """ 18 | if os.name == 'nt': 19 | sys_python = SYS_VENV / 'Scripts' / 'python.exe' 20 | uv_python = UV_VENV / 'Scripts' / 'python.exe' 21 | else: 22 | sys_python = SYS_VENV / 'bin' / 'python' 23 | uv_python = UV_VENV / 'bin' / 'python' 24 | 25 | if not SYS_VENV.exists(): 26 | subprocess.check_call( 27 | [sys.executable, '-m', 'venv', SYS_VENV], 28 | stdout=subprocess.DEVNULL, 29 | ) 30 | 31 | if not UV_VENV.exists(): 32 | subprocess.check_call( 33 | [str(sys_python), '-m', 'pip', 'install', 'uv'], 34 | stdout=subprocess.DEVNULL, 35 | ) 36 | subprocess.check_call( 37 | [ 38 | str(sys_python), '-m', 'uv', 'venv', '--python', '3.11', 39 | str(UV_VENV) 40 | ], 41 | stdout=subprocess.DEVNULL, 42 | ) 43 | 44 | subprocess.check_call( 45 | [ 46 | str(sys_python), '-m', 'uv', 'pip', 'install', '-r', 47 | str(REQUIREMENTS), '--python', 48 | str(uv_python) 49 | ], 50 | stdout=subprocess.DEVNULL, 51 | ) 52 | 53 | return uv_python 54 | 55 | 56 | if __name__ == '__main__': 57 | python = get_venv() 58 | sys.exit(subprocess.call([str(python), '-m', 'kumo_rfm_mcp.server'])) 59 | -------------------------------------------------------------------------------- /test/tools/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from kumoapi.rfm import Explanation 3 | 4 | from kumo_rfm_mcp import UpdateGraphMetadata 5 | from kumo_rfm_mcp.tools.graph import materialize_graph, update_graph_metadata 6 | from kumo_rfm_mcp.tools.model import evaluate, explain, predict 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_predict(graph: UpdateGraphMetadata) -> None: 11 | update_graph_metadata(graph) 12 | await materialize_graph() 13 | 14 | out = await predict( 15 | 'PREDICT USERS.AGE>20 FOR EACH USERS.USER_ID', 16 | indices=[0], 17 | anchor_time=None, 18 | run_mode='fast', 19 | num_neighbors=[16, 16], 20 | max_pq_iterations=20, 21 | ) 22 | assert len(out.predictions) == 1 23 | assert len(out.logs) == 0 24 | 25 | 26 | 
@pytest.mark.asyncio 27 | async def test_evaluate(graph: UpdateGraphMetadata) -> None: 28 | update_graph_metadata(graph) 29 | await materialize_graph() 30 | 31 | out = await evaluate( 32 | 'PREDICT USERS.AGE>20 FOR EACH USERS.USER_ID', 33 | metrics=None, 34 | anchor_time=None, 35 | run_mode='fast', 36 | num_neighbors=[16, 16], 37 | max_pq_iterations=20, 38 | ) 39 | assert set(out.metrics.keys()) == {'ap', 'auprc', 'auroc'} 40 | assert len(out.logs) == 0 41 | 42 | 43 | @pytest.mark.asyncio 44 | async def test_explain(graph: UpdateGraphMetadata) -> None: 45 | update_graph_metadata(graph) 46 | await materialize_graph() 47 | 48 | out = await explain( 49 | 'PREDICT USERS.AGE>20 FOR EACH USERS.USER_ID', 50 | index=0, 51 | anchor_time=None, 52 | num_neighbors=[16, 16], 53 | max_pq_iterations=20, 54 | ) 55 | assert isinstance(out.prediction, dict) 56 | assert isinstance(out.explanation, Explanation) 57 | assert len(out.logs) == 0 58 | -------------------------------------------------------------------------------- /.github/workflows/_release.yml: -------------------------------------------------------------------------------- 1 | name: Release on tag 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | tags: 6 | - 'v[0-9]+.[0-9]+.[0-9]+' # vX.Y.Z 7 | - 'v[0-9]+.[0-9]+.[0-9]+rc[0-9]+' # vX.Y.ZrcN 8 | 9 | jobs: 10 | release: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v5 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | 20 | - name: Build package 21 | run: | 22 | python -m pip install --upgrade pip build twine 23 | python -m build 24 | twine check dist/* 25 | 26 | - name: Create a GitHub release (pre-release) 27 | if: ${{ contains(github.ref_name, 'rc') }} 28 | run: | 29 | gh release create ${{ github.ref_name }} --verify-tag --generate-notes --title ${{ github.ref_name }} --prerelease 30 | env: 31 | GH_TOKEN: ${{ github.token }} 32 | 33 | - name: Create a GitHub release 34 | if: 
${{ !contains(github.ref_name, 'rc') }} 35 | run: | 36 | gh release create ${{ github.ref_name }} --verify-tag --generate-notes --title ${{ github.ref_name }} 37 | env: 38 | GH_TOKEN: ${{ github.token }} 39 | 40 | - name: Upload to TestPyPI 41 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 42 | with: 43 | repository-url: https://test.pypi.org/legacy/ 44 | user: __token__ 45 | password: ${{ secrets.TEST_PYPI_TOKEN }} 46 | verbose: true 47 | 48 | - name: Upload to PyPI 49 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 50 | with: 51 | repository-url: https://upload.pypi.org/legacy/ 52 | user: __token__ 53 | password: ${{ secrets.PYPI_TOKEN }} 54 | verbose: true 55 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/session.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | from fastmcp.exceptions import ToolError 5 | from kumoai.experimental import rfm 6 | from typing_extensions import Self 7 | 8 | 9 | @dataclass(init=False, repr=False) 10 | class Session: 11 | name: str 12 | _graph: rfm.LocalGraph = field(default_factory=lambda: rfm.LocalGraph([])) 13 | _model: rfm.KumoRFM | None = None 14 | 15 | def __init__(self, name: str) -> None: 16 | self.name = name 17 | self._graph = rfm.LocalGraph([]) 18 | self._model = None 19 | 20 | @property 21 | def is_initialized(self) -> bool: 22 | from kumoai import global_state 23 | return global_state.initialized 24 | 25 | @property 26 | def graph(self) -> rfm.LocalGraph: 27 | return self._graph 28 | 29 | @property 30 | def model(self) -> rfm.KumoRFM: 31 | if self._model is None: 32 | raise ToolError("Graph is not yet materialized") 33 | self.initialize() 34 | return self._model 35 | 36 | def clear(self) -> Self: 37 | """Clear the current session.""" 38 | self._graph = rfm.LocalGraph([]) 39 | self._model = None 
40 | return self 41 | 42 | def initialize(self) -> Self: 43 | """Initialize a session from environment variables.""" 44 | if not self.is_initialized: 45 | if os.getenv('KUMO_API_KEY') is None: 46 | raise ToolError("Missing required environment variable " 47 | "'KUMO_API_KEY'. Please set your API key via " 48 | "`export KUMO_API_KEY='your-api-key'` or " 49 | "call the 'authenticate' tool to " 50 | "automatically generate an API key.") 51 | 52 | rfm.init() 53 | 54 | return self 55 | 56 | def __repr__(self) -> str: 57 | return f'{self.__class__.__name__}(name={self.name})' 58 | 59 | 60 | class SessionManager: 61 | _default: Session = Session(name='default') 62 | 63 | @classmethod 64 | def get_default_session(cls) -> Session: 65 | r"""Returns the default session.""" 66 | return cls._default 67 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: true 3 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 4 | autoupdate_schedule: monthly 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v6.0.0 9 | hooks: 10 | - id: no-commit-to-branch 11 | name: No commits to master 12 | - id: end-of-file-fixer 13 | name: End-of-file fixer 14 | - id: mixed-line-ending 15 | name: Fix mixed line endings 16 | args: [--fix, lf] 17 | - id: trailing-whitespace 18 | name: Remove trailing whitespaces 19 | - id: check-toml 20 | name: Check toml 21 | 22 | - repo: https://github.com/adrienverge/yamllint 23 | rev: v1.37.1 24 | hooks: 25 | - id: yamllint 26 | name: Lint yaml 27 | args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}'] 28 | 29 | - repo: https://github.com/asottile/pyupgrade 30 | rev: v3.21.2 31 | hooks: 32 | - id: pyupgrade 33 | name: Upgrade Python syntax 34 | args: [--py310-plus] 35 | 36 | 
- repo: https://github.com/PyCQA/autoflake 37 | rev: v2.3.1 38 | hooks: 39 | - id: autoflake 40 | name: Remove unused imports and variables 41 | args: [ 42 | --remove-all-unused-imports, 43 | --remove-unused-variables, 44 | --remove-duplicate-keys, 45 | --ignore-init-module-imports, 46 | --in-place, 47 | ] 48 | 49 | - repo: https://github.com/google/yapf 50 | rev: v0.43.0 51 | hooks: 52 | - id: yapf 53 | name: Format code 54 | additional_dependencies: [toml] 55 | 56 | - repo: https://github.com/pycqa/isort 57 | rev: 7.0.0 58 | hooks: 59 | - id: isort 60 | name: Sort imports 61 | 62 | - repo: https://github.com/PyCQA/flake8 63 | rev: 7.3.0 64 | hooks: 65 | - id: flake8 66 | name: Check PEP8 67 | 68 | - repo: https://github.com/astral-sh/ruff-pre-commit 69 | rev: v0.14.7 70 | hooks: 71 | - id: ruff 72 | name: Ruff formatting 73 | args: [--fix, --exit-non-zero-on-fix] 74 | 75 | - repo: https://github.com/executablebooks/mdformat 76 | rev: 0.7.22 77 | hooks: 78 | - id: mdformat 79 | name: Format Markdown 80 | additional_dependencies: 81 | - mdformat-gfm 82 | - mdformat_frontmatter 83 | - mdformat_footnote 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "kumo-rfm-mcp" 7 | dynamic = ["version"] 8 | description = "Model Context Protocol server for KumoRFM" 9 | readme = "README.md" 10 | license = { text = "MIT" } 11 | authors = [ 12 | { name = "Kumo.AI", email = "hello@kumo.ai" } 13 | ] 14 | keywords = [ 15 | "rfm", 16 | "mcp", 17 | "foundation model", 18 | "relational data", 19 | "kumo", 20 | ] 21 | requires-python = ">=3.10" 22 | dependencies = [ 23 | "kumoai==2.10.1", 24 | "fastmcp>=2.2.7,<3", 25 | ] 26 | classifiers = [ 27 | "License :: OSI Approved :: MIT License", 28 | "Development Status :: 3 - Alpha", 29 | "Environment 
:: Console", 30 | "Operating System :: OS Independent", 31 | "Intended Audience :: Developers", 32 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 33 | "Topic :: Scientific/Engineering :: Information Analysis", 34 | "Topic :: Software Development :: Libraries :: Python Modules", 35 | "Programming Language :: Python :: 3", 36 | "Programming Language :: Python :: 3.10", 37 | "Programming Language :: Python :: 3.11", 38 | "Programming Language :: Python :: 3.12", 39 | "Programming Language :: Python :: 3.13", 40 | ] 41 | 42 | [project.optional-dependencies] 43 | s3 = ["s3fs"] 44 | dev = [ 45 | "pytest", 46 | "pytest-asyncio", 47 | ] 48 | 49 | [project.scripts] 50 | kumo-rfm-mcp = "kumo_rfm_mcp.server:main" 51 | 52 | [tool.hatch.build] 53 | exclude = [ 54 | ".github", 55 | "mcpb", 56 | ".pre-commit-config.yaml", 57 | ] 58 | 59 | [tool.hatch.build.targets.wheel] 60 | packages = ["kumo_rfm_mcp"] 61 | 62 | [tool.hatch.version] 63 | path = "kumo_rfm_mcp/_version.py" 64 | 65 | [project.urls] 66 | Homepage = "https://kumorfm.ai" 67 | Repository = "https://github.com/kumo-ai/kumo-rfm-mcp.git" 68 | Issues = "https://github.com/kumo-ai/kumo-rfm-mcp/issues" 69 | 70 | [tool.yapf] 71 | based_on_style="pep8" 72 | split_before_named_assigns=false 73 | blank_line_before_nested_class_or_def=false 74 | 75 | [tool.isort] 76 | multi_line_output=3 77 | include_trailing_comma=true 78 | skip=[".gitignore", "__init__.py"] 79 | 80 | [tool.pytest.ini_options] 81 | addopts=[ 82 | "--capture=no", 83 | "--color=yes", 84 | "-vv", 85 | ] 86 | testpaths=["test"] 87 | 88 | [tool.mypy] 89 | files=["kumo_rfm_mcp", "test", "mcpb"] 90 | python_version="3.10" 91 | install_types=true 92 | non_interactive=true 93 | ignore_missing_imports=true 94 | show_error_codes=true 95 | warn_redundant_casts=true 96 | warn_unused_configs=true 97 | warn_unused_ignores=true 98 | disallow_untyped_defs=true 99 | disallow_incomplete_defs=true 100 | check_untyped_defs=true 101 | 
-------------------------------------------------------------------------------- /kumo_rfm_mcp/resources/overview.md: -------------------------------------------------------------------------------- 1 | # Overview of KumoRFM 2 | 3 | KumoRFM is a **pre-trained Relational Foundation Model (RFM)** that generates training-free predictions on any relational multi-table data by interpreting the data as a (temporal) heterogeneous graph. 4 | It can be queried via the **Predictive Query Language (PQL)**. 5 | 6 | **Internal note:** It is NOT related to "Recency, Frequency, Monetary" analysis. 7 | 8 | Internally, KumoRFM utilizes in-context learning to transfer patterns from historical examples to new unseen examples. 9 | Specifically, it constructs training/in-context subgraphs with known ground-truth labels and relates them to unseen subgraphs. 10 | 11 | KumoRFM can discover table-like files (e.g., CSV, Parquet), inspect them, and structure them into a graph via foreign key-primary key relationships. 12 | A time column in a table dictates the create time of a row, which is used downstream to receive and order historical interactions and prevent temporal leakage. 13 | Each column within a table is assigned a semantic type (`"numerical"`, `"categorical"`, `"multicategorical"`, `"ID"`, `"text"`, `"timestamp"`, `"sequence"`, etc) that denotes the semantic meaning of the column and how it should be processed within the model. 14 | 15 | See the `kumo://docs/graph-setup` resource for more information. 16 | 17 | After a graph is set up and materialized, KumoRFM can generate predictions (e.g., missing value imputation, temporal forecasts) and evaluations by querying the graph via the Predictive Query Language (PQL), a declarative language to formulate machine learning tasks. 18 | Understanding PQL and how it maps to a machine learning task is critical to achieve good model predictions. 
19 | Besides PQL, various other options exist to tune model output, e.g., optimizing the `run_mode` of the model, specifying how subgraphs are formed via `num_neighbors`, or adjusting the `anchor_time` to denote the point in time for when a prediction should be made. 20 | 21 | See the `kumo://docs/predictive-query` resource for more information. 22 | 23 | ## Getting Started 24 | 25 | 1. Finding, inspecting and understanding table-like files via `find_table_files` and `inspect_table_files` tools 26 | 1. Constructing and updating the graph schema via `update_graph_metadata` by adding/updating and removing tables and their schema, and inter-connecting them via foreign key-primary key relationships 27 | 1. Visualizing the graph schema as a Mermaid entity relationship diagram via `get_mermaid` 28 | 1. Materializing the graph via `materialize_graph` to make it available for inference operations; this step is necessary to efficiently form subgraphs around entities at any given point in time 29 | 1. 
Predicting and evaluating predictive queries on top of the materialized graph via `predict` and `evaluate` to obtain valuable insights for the future 30 | 31 | ## Quick Access 32 | 33 | Use the `get_docs` tool to access any resource: 34 | 35 | ``` 36 | get_docs('kumo://docs/graph-setup') 37 | get_docs('kumo://docs/predictive-query') 38 | ``` 39 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import logging 3 | import sys 4 | from pathlib import Path 5 | 6 | from fastmcp import FastMCP 7 | from fastmcp.resources import FileResource 8 | from pydantic import AnyUrl 9 | 10 | import kumo_rfm_mcp 11 | from kumo_rfm_mcp import tools 12 | 13 | logging.basicConfig( 14 | level=logging.INFO, 15 | format='[%(levelname)s] - %(asctime)s - %(name)s - %(message)s', 16 | stream=sys.stderr) 17 | logger = logging.getLogger('kumo-rfm-mcp') 18 | 19 | mcp = FastMCP( 20 | name='KumoRFM (Relational Foundation Model)', 21 | instructions=("KumoRFM is a pre-trained Relational Foundation Model (RFM) " 22 | "that generates training-free predictions on any relational " 23 | "multi-table data by interpreting the data as a (temporal) " 24 | "heterogeneous graph. 
It can be queried via the Predictive " 25 | "Query Language (PQL)."), 26 | version=kumo_rfm_mcp.__version__, 27 | ) 28 | 29 | # Tools ###################################################################### 30 | tools.register_docs_tools(mcp) 31 | tools.register_auth_tools(mcp) 32 | tools.register_io_tools(mcp) 33 | tools.register_graph_tools(mcp) 34 | tools.register_model_tools(mcp) 35 | 36 | # Resources ################################################################## 37 | mcp.add_resource( 38 | FileResource( 39 | uri=AnyUrl('kumo://docs/overview'), 40 | path=Path(__file__).parent / 'resources' / 'overview.md', 41 | name="Overview of KumoRFM", 42 | description="Overview of KumoRFM (Relational Foundation Model)", 43 | mime_type='text/markdown', 44 | tags={'documentation'}, 45 | )) 46 | mcp.add_resource( 47 | FileResource( 48 | uri=AnyUrl('kumo://docs/graph-setup'), 49 | path=Path(__file__).parent / 'resources' / 'graph-setup.md', 50 | name="Graph Setup", 51 | description="How to set up graphs in KumoRFM", 52 | mime_type='text/markdown', 53 | tags={'documentation'}, 54 | )) 55 | mcp.add_resource( 56 | FileResource( 57 | uri=AnyUrl('kumo://docs/predictive-query'), 58 | path=Path(__file__).parent / 'resources' / 'predictive-query.md', 59 | name="Predictive Query", 60 | description="How to query and generate predictions in KumoRFM", 61 | mime_type='text/markdown', 62 | tags={'documentation'}, 63 | )) 64 | mcp.add_resource( 65 | FileResource( 66 | uri=AnyUrl('kumo://docs/explainability'), 67 | path=Path(__file__).parent / 'resources' / 'explainability.md', 68 | name="Explainability", 69 | description="How to interpret and summarize explanations of KumoRFM", 70 | mime_type='text/markdown', 71 | tags={'documentation'}, 72 | )) 73 | 74 | 75 | def main() -> None: 76 | """Main entry point for the CLI command.""" 77 | try: 78 | mcp.run(transport='stdio') 79 | except KeyboardInterrupt: 80 | logger.info("Server shutdown requested by user") 81 | sys.exit(0) 82 | except Exception 
as e: 83 | logger.error(f"Failed to start KumoRFM MCP server: {e}") 84 | sys.exit(1) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/tools/io.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os.path as osp 3 | from pathlib import Path 4 | from typing import Annotated 5 | 6 | import pandas as pd 7 | from fastmcp import FastMCP 8 | from fastmcp.exceptions import ToolError 9 | from pydantic import Field 10 | 11 | from kumo_rfm_mcp import TableSource, TableSourcePreview 12 | 13 | 14 | async def find_table_files( 15 | path: Annotated[Path, "Local root directory to scan"], 16 | recursive: Annotated[ 17 | bool, 18 | Field( 19 | default=False, 20 | description=("Whether to scan subdirectories recursively. Use " 21 | "with caution in large directories such as home " 22 | "folders or system directories."), 23 | ), 24 | ], 25 | ) -> list[TableSource]: 26 | """Finds all table-like files (e.g., CSV, Parquet) in a directory. 27 | 28 | This tool is for local directories only. It cannot search, e.g., in S3 29 | buckets. 30 | """ 31 | path = path.expanduser() 32 | 33 | if not path.exists() or not path.is_dir(): 34 | raise ToolError(f"Directory '{path}' does not exist") 35 | 36 | def _find_table_files() -> list[TableSource]: 37 | pattern = "**/*" if recursive else "*" 38 | suffixes = {'.csv', '.parquet'} 39 | files = [f for f in path.glob(pattern) if f.suffix.lower() in suffixes] 40 | return [ 41 | TableSource(path=f, bytes=f.stat().st_size) for f in sorted(files) 42 | ] 43 | 44 | return await asyncio.to_thread(_find_table_files) 45 | 46 | 47 | async def inspect_table_files( 48 | paths: Annotated[ 49 | list[str], 50 | ("File paths to inspect. 
Can be a mix of local file paths, S3 URIs " 51 | "(s3://...), or HTTP/HTTPS URLs."), 52 | ], 53 | num_rows: Annotated[ 54 | int, 55 | Field( 56 | default=20, 57 | ge=1, 58 | le=1000, 59 | description="Number of rows to read per file", 60 | ), 61 | ], 62 | ) -> dict[str, TableSourcePreview]: 63 | """Inspect the first rows of table-like files. 64 | 65 | Each row in a file is represented as a dictionary mapping column 66 | names to their corresponding values. 67 | """ 68 | def read_file(path: str) -> TableSourcePreview: 69 | path = osp.expanduser(path) 70 | suffix = path.rsplit('.', maxsplit=1)[-1].lower() 71 | 72 | if suffix not in {'csv', 'parquet'}: 73 | raise ToolError(f"'{path}' is not a valid CSV or Parquet file") 74 | 75 | try: 76 | if suffix == 'csv': 77 | df = pd.read_csv(path, nrows=num_rows) 78 | else: 79 | assert suffix == 'parquet' 80 | # TODO Read first row groups via `pyarrow` instead. 81 | df = pd.read_parquet(path).head(num_rows) 82 | except Exception as e: 83 | raise ToolError(f"Could not read file '{path}': {e}") from e 84 | 85 | df = df.astype(object).where(df.notna(), None) 86 | return TableSourcePreview(rows=df.to_dict(orient='records')) 87 | 88 | tasks = [asyncio.to_thread(read_file, path) for path in paths] 89 | previews = await asyncio.gather(*tasks) 90 | return {path: preview for path, preview in zip(paths, previews)} 91 | 92 | 93 | def register_io_tools(mcp: FastMCP) -> None: 94 | """Register all I/O tools to the MCP server.""" 95 | mcp.tool(annotations=dict( 96 | title="🔍 Searching for tabular files…", 97 | readOnlyHint=True, 98 | destructiveHint=False, 99 | idempotentHint=True, 100 | openWorldHint=False, 101 | ))(find_table_files) 102 | 103 | mcp.tool(annotations=dict( 104 | title="🧐 Analyzing table structure…", 105 | readOnlyHint=True, 106 | destructiveHint=False, 107 | idempotentHint=True, 108 | openWorldHint=False, 109 | ))(inspect_table_files) 110 | -------------------------------------------------------------------------------- 
/test/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import pandas as pd 5 | import pytest 6 | from kumoai.experimental import rfm 7 | from kumoai.experimental.rfm.rfm import Explanation 8 | from kumoapi.rfm import Explanation as ExplanationConfig 9 | from kumoapi.task import TaskType 10 | from pytest import TempPathFactory 11 | 12 | from kumo_rfm_mcp import ( 13 | AddTableMetadata, 14 | LinkMetadata, 15 | SessionManager, 16 | UpdateGraphMetadata, 17 | ) 18 | 19 | 20 | @pytest.fixture(autouse=True) 21 | def clear_session() -> None: 22 | SessionManager.get_default_session().clear() 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def mock_model(monkeypatch: pytest.MonkeyPatch) -> None: 27 | monkeypatch.setenv('KUMO_API_KEY', 'DUMMY') 28 | monkeypatch.setattr(rfm, 'init', lambda *args, **kwargs: None) 29 | 30 | def predict( 31 | *args: Any, 32 | explain: bool = False, 33 | **kwargs: Any, 34 | ) -> pd.DataFrame | Explanation: 35 | 36 | df = pd.DataFrame({ 37 | 'ENTITY': [0], 38 | 'ANCHOR_TIMESTAMP': ['2025-01-1'], 39 | 'TARGET_PRED': [True], 40 | 'False_PROB': [0.4], 41 | 'True_PROB': [0.6], 42 | }) 43 | 44 | if not explain: 45 | return df 46 | 47 | return Explanation( 48 | prediction=df, 49 | summary='', 50 | details=ExplanationConfig( 51 | task_type=TaskType.BINARY_CLASSIFICATION, 52 | cohorts=[], 53 | subgraphs=[], 54 | ), 55 | ) 56 | 57 | def evaluate(*args: Any, **kwargs: Any) -> pd.DataFrame: 58 | return pd.DataFrame({ 59 | 'metric': ['ap', 'auprc', 'auroc'], 60 | 'value': [0.8, 0.8, 0.9], 61 | }) 62 | 63 | monkeypatch.setattr(rfm.KumoRFM, 'predict', predict) 64 | monkeypatch.setattr(rfm.KumoRFM, 'evaluate', evaluate) 65 | 66 | 67 | @pytest.fixture(scope='session') 68 | def root_dir(tmp_path_factory: TempPathFactory) -> Path: 69 | path = tmp_path_factory.mktemp('table_files') 70 | 71 | df = pd.DataFrame({ 72 | 'USER_ID': [0, 1, 2, 3], 73 | 'AGE': [20, 
30, 40, float('NaN')], 74 | 'GENDER': ['male', 'female', 'female', None], 75 | }) 76 | df.to_csv(path / 'USERS.csv', index=False) 77 | 78 | df = pd.DataFrame({ 79 | 'USER_ID': [0, 1, 2, 3], 80 | 'STORE_ID': [0, 1, 0, 1], 81 | 'AMOUNT': [10, 15, float('NaN'), 20], 82 | 'TIME': ['2025-01-01', '2025-01-02', '2025-01-03', '2025-01-04'], 83 | }) 84 | df.to_parquet(path / 'ORDERS.parquet') 85 | 86 | df = pd.DataFrame({ 87 | 'STORE_ID': [0, 1], 88 | 'CAT': ['burger', 'pizza'], 89 | }) 90 | df.to_csv(path / 'STORES.csv', index=False) 91 | 92 | return path 93 | 94 | 95 | @pytest.fixture 96 | def graph(root_dir: Path) -> UpdateGraphMetadata: 97 | return UpdateGraphMetadata( # type: ignore[call-arg] 98 | tables_to_add=[ 99 | AddTableMetadata( 100 | path=(root_dir / 'USERS.csv').as_posix(), 101 | name='USERS', 102 | primary_key='USER_ID', 103 | time_column=None, 104 | end_time_column=None, 105 | ), 106 | AddTableMetadata( 107 | path=(root_dir / 'ORDERS.parquet').as_posix(), 108 | name='ORDERS', 109 | primary_key=None, 110 | time_column='TIME', 111 | end_time_column=None, 112 | ), 113 | AddTableMetadata( 114 | path=(root_dir / 'STORES.csv').as_posix(), 115 | name='STORES', 116 | primary_key='STORE_ID', 117 | time_column=None, 118 | end_time_column=None, 119 | ), 120 | ], 121 | links_to_add=[ 122 | LinkMetadata( 123 | source_table='ORDERS', 124 | foreign_key='USER_ID', 125 | destination_table='USERS', 126 | ), 127 | LinkMetadata( 128 | source_table='ORDERS', 129 | foreign_key='STORE_ID', 130 | destination_table='STORES', 131 | ), 132 | ], 133 | ) 134 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/tools/docs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal 3 | 4 | from fastmcp import FastMCP 5 | from fastmcp.exceptions import ToolError 6 | from fastmcp.resources import FileResource 7 | 8 | 9 | def register_docs_tools(mcp: FastMCP) -> 
None: 10 | """Register all documentation tools to the MCP server.""" 11 | @mcp.tool(annotations=dict( 12 | title="📖 Reading…", 13 | readOnlyHint=True, 14 | destructiveHint=False, 15 | idempotentHint=False, 16 | openWorldHint=False, 17 | )) 18 | async def get_docs( 19 | resource_uri: Literal[ 20 | 'kumo://docs/overview', 21 | 'kumo://docs/graph-setup', 22 | 'kumo://docs/predictive-query', 23 | 'kumo://docs/explainability', 24 | ] 25 | ) -> str: 26 | """Get documentation on how to use KumoRFM. 27 | 28 | KumoRFM is a pre-trained Relational Foundation Model (RFM) that 29 | generates training-free predictions on any relational multi-table 30 | data by interpreting the data as a (temporal) heterogeneous graph. It 31 | can be queried via the Predictive Query Language (PQL). 32 | 33 | Internal note: It is NOT related to "Recency, Frequency, Monetary" 34 | analysis. 35 | 36 | Internally, KumoRFM utilizes in-context learning to transfer patterns 37 | from historical examples to new unseen examples. Specifically, it 38 | constructs training/in-context subgraphs with known ground-truth 39 | labels and relates them to unseen subgraphs. 40 | 41 | See the 'kumo://docs/overview' resource for more information. 42 | 43 | KumoRFM can discover table-like files (e.g., CSV, Parquet), inspect 44 | them, and structure them into a graph via foreign key-primary key 45 | relationships. A time column in a table dictates the create time of a 46 | row, which is used downstream to retrieve and order historical 47 | interactions and prevent temporal leakage. Each column within a table 48 | is assigned a semantic type (numerical, categorical, multi-categorical, 49 | ID, text, timestamp, sequence, etc) that denotes the semantic meaning 50 | of the column and how it should be processed within the model. 51 | 52 | Important: Before creating and updating graphs, read the 53 | documentation first at 'kumo://docs/graph-setup'.
54 | 55 | After a graph is set up and materialized, KumoRFM can generate 56 | predictions (e.g., missing value imputation, temporal forecasts) and 57 | evaluations by querying the graph via the Predictive Query Language 58 | (PQL), a declarative language to formulate machine learning tasks. 59 | Understanding PQL and how it maps to a machine learning task is 60 | critical to achieve good model predictions. Besides PQL, various other 61 | options exist to tune model output, e.g., optimizing the `run_mode` of 62 | the model, specifying how subgraphs are formed via `num_neighbors`, or 63 | adjusting the `anchor_time` to denote the point in time for when a 64 | prediction should be made. 65 | 66 | Important: Before executing or suggesting any predictive queries, 67 | read the documentation first at 'kumo://docs/predictive-query'. 68 | 69 | KumoRFM can additionally generate explanations for predictions, 70 | providing both a global column-level analysis and a local, cell-level 71 | attribution view. 72 | Together, these views enable comprehensive interpretation. 73 | 74 | Important: Before analyzing the explanation output, read the 75 | documentation first at 'kumo://docs/explainability'. 76 | """ 77 | resources = await mcp.get_resources() 78 | if resource_uri not in resources: 79 | raise ToolError(f"Resource '{resource_uri}' not found. 
Available " 80 | f"resources: {list(resources.keys())}") 81 | 82 | resource = resources[resource_uri] 83 | 84 | if isinstance(resource, FileResource): 85 | if getattr(resource, 'path', None): 86 | path = Path(resource.path) 87 | else: # Construct path from URI: 88 | name = f"{str(resource.uri).rsplit('/', 1)[-1]}.md" 89 | path = Path(__file__).parent.parent / 'resources' / name 90 | 91 | if not path.exists(): 92 | raise ToolError(f"File resource '{resource_uri}' not found at " 93 | f"'{path}'") 94 | 95 | return path.read_text(encoding='utf-8') 96 | 97 | raise ToolError(f"Resource '{resource_uri}' is not accessible") 98 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/resources/explainability.md: -------------------------------------------------------------------------------- 1 | # Explainability 2 | 3 | KumoRFM explanations provide two complementary views of model predictions: 4 | 5 | 1. **Global View (Cohorts):** Column-level patterns across in-context examples that reveal what data characteristics drive predictions 6 | 1. **Local View (Subgraph):** Cell-level attribution scores showing which specific values in this entity's subgraph influenced the prediction 7 | 8 | Together, these views answer: "What patterns does the model see globally?" and "Which specific data points matter for this prediction?" 9 | 10 | ## Understanding the Global View: Cohorts 11 | 12 | Cohorts reveal how different value ranges or categories in columns correlate with prediction outcomes across all in-context examples. 13 | 14 | - `table_name`: Which table this analysis covers 15 | - `column_name`: Which column or statistic (e.g., `COUNT(*)`) this analysis covers 16 | - `hop`: Distance from the entity table (0 = entity attributes, 1 = direct neighbors, 2 = second-degree neighbors, ...)
17 | - `stype`: Semantic type (numerical, categorical, timestamp, etc) 18 | - `cohorts`: List of value ranges/categories (e.g., `["[0-5]", "(5-10]", "(10-20+]"]`) 19 | - `populations`: Proportion of in-context examples in each cohort 20 | - `targets`: Average prediction score within each cohort 21 | 22 | High-impact columns usually have large variance in `targets` across different cohorts. 23 | 24 | **Example for a churn predictive query:** 25 | 26 | ``` 27 | table_name: "orders" 28 | column_name: "COUNT(*)" 29 | hop: 1 30 | cohorts: ["[0-0]", "(0-1]", "(1-2]", "(2-4]", "(4-6+]"] 31 | populations: [0.20, 0.08, 0.07, 0.11, 0.54] 32 | targets: [0.0, 0.78, 0.74, 0.64, 0.35] 33 | ``` 34 | 35 | **What this means:** 36 | 37 | - Users with 0 orders have 0% churn risk (they already churned) 38 | - Users with 1-2 orders have ~75% churn risk (early stage, not sticky) 39 | - Users with 6+ orders have 35% churn risk (established, but not immune) 40 | - Key insight: Order count is strongly predictive; more orders = lower churn 41 | 42 | ## Understanding the Local View: Subgraph 43 | 44 | Subgraphs show the actual data neighborhood around the specific entity being predicted, with attribution scores indicating importance. 45 | Node indices are different from primary keys and are mapped to a contiguous range from 0 to N. 46 | The entity being predicted is guaranteed to have ID 0. 47 | Some cells may have a `null` value with non-zero scores, indicating missingness itself is informative. 48 | 49 | Each node represents a row from a table, containing: 50 | 51 | - `cells`: Dictionary of column values with attribution scores 52 | - `value`: Actual data value 53 | - `score`: Gradient-based importance between 0 and 1 (higher = more influential) 54 | - `links`: Connections to other nodes via foreign keys 55 | 56 | Scores reflect how much changing this value would change the prediction. 57 | High scores on specific cells explain "why this prediction, not another". 
58 | 59 | **Score Magnitude Interpretation:** 60 | 61 | - 0.00 - 0.05: Negligible influence 62 | - 0.05 - 0.15: Moderate influence 63 | - 0.15 - 0.30: Strong influence 64 | - 0.30+: Critical influence 65 | 66 | **Example:** 67 | 68 | ``` 69 | cells: { 70 | "club_member_status": {value: "ACTIVE", score: 1.0}, 71 | "age": {value: 49, score: 0.089}, 72 | "fashion_news_frequency": {value: "Regularly", score: 0.411} 73 | } 74 | links: { 75 | "user_id->orders": [1,2,3,...,32] 76 | } 77 | ``` 78 | 79 | **What this means:** 80 | 81 | Club membership status is the most important attribute (score=1.0). 82 | Fashion news subscription is moderately important (score=0.411). 83 | Age contributes but is less critical (score=0.089). 84 | The user has 32 orders linked (indicating high activity). 85 | 86 | You can follow paths in the subgraph to understand data connectivity and how tables/cells far away may contribute to the prediction. 87 | 88 | ## Connecting Global and Local Views 89 | 90 | Oftentimes, you can understand high subgraph attribution scores by relating their cell values to the average prediction of the cohort. 91 | 92 | 1. **Find influential cells for the prediction in the local view:** 93 | Which cells have scores > 0.15? 94 | 1. **Locate the entity in global context:** 95 | Find which cohorts the specific entity falls into and compare the entity's values to high/low risk cohorts. 96 | Focus on highest-scoring cells and most divergent cohorts. 97 | 1. **Relate attribution score and cohort prediction:** 98 | Check if the entity exhibits typical or atypical patterns. 99 | 1. **Find general global trends** in the data that might explain the prediction. 100 | Additionally, look for missing expected signals (why ISN'T something important?) 101 | 102 | Tell a coherent story connecting global patterns to local evidence. 103 | Use concrete numbers from the subgraph. 104 | Avoid jargon; explain in business terms.
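The cross-referencing steps above can be sketched programmatically. The helper below is illustrative only (it is not part of the KumoRFM API) and assumes numeric cohort labels in the bracketed-interval format from the churn example:

```python
# Illustrative helper (not part of the KumoRFM API): map a local cell
# value to its global cohort and that cohort's average prediction score.
def find_cohort(value, cohorts, targets):
    """Return (cohort_label, avg_target) for the cohort containing `value`.

    Assumes numeric interval labels such as "[0-0]", "(2-4]", or "(4-6+]",
    where a trailing "+" marks an open-ended upper bound.
    """
    for label, target in zip(cohorts, targets):
        lo_inclusive = label.startswith('[')
        open_ended = '+' in label
        lo_str, hi_str = label.strip('[]()').rstrip('+').split('-')
        lo = float(lo_str)
        hi = float('inf') if open_ended else float(hi_str)
        if (lo <= value if lo_inclusive else lo < value) and value <= hi:
            return label, target
    return None, None


# Cohorts from the churn example above:
cohorts = ["[0-0]", "(0-1]", "(1-2]", "(2-4]", "(4-6+]"]
targets = [0.0, 0.78, 0.74, 0.64, 0.35]

label, avg_target = find_cohort(3, cohorts, targets)  # -> ("(2-4]", 0.64)
```

For the churn cohorts above, a user with 3 orders falls into the `(2-4]` cohort with an average score of 0.64, which can then be contrasted with the local attribution score of the corresponding cell.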
105 | 106 | ## Common Interpretation Pitfalls 107 | 108 | - **Don't assume correlation = causation:** 109 | High scores show model importance, not real-world causality. 110 | For example, "black clothing" might correlate with churn, but color isn't the cause. 111 | - **Consider data distribution:** 112 | Rare cohorts may show extreme `targets` with small `populations`. 113 | Focus on cohorts with both significant population AND divergent targets. 114 | - **Missing cohort analysis:** 115 | Not all columns have a cohort analysis since some semantic types are unsupported. 116 | For example, text and ID columns typically only appear in local view. 117 | -------------------------------------------------------------------------------- /test/tools/test_graph.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from kumoapi.typing import Stype 5 | 6 | from kumo_rfm_mcp import ( 7 | AddTableMetadata, 8 | LinkMetadata, 9 | UpdateGraphMetadata, 10 | UpdateTableMetadata, 11 | ) 12 | from kumo_rfm_mcp.tools.graph import ( 13 | get_mermaid, 14 | inspect_graph_metadata, 15 | lookup_table_rows, 16 | materialize_graph, 17 | update_graph_metadata, 18 | ) 19 | 20 | 21 | def test_graph_metadata(root_dir: Path) -> None: 22 | graph = inspect_graph_metadata() 23 | assert len(graph.tables) == 0 24 | assert len(graph.links) == 0 25 | 26 | update = UpdateGraphMetadata( # type: ignore[call-arg] 27 | tables_to_add=[ 28 | AddTableMetadata( 29 | path=(root_dir / 'USERS.csv').as_posix(), 30 | name='USERS', 31 | primary_key='USER_ID', 32 | time_column=None, 33 | end_time_column=None, 34 | ), 35 | AddTableMetadata( 36 | path=(root_dir / 'ORDERS.parquet').as_posix(), 37 | name='ORDERS', 38 | primary_key=None, 39 | time_column='TIME', 40 | end_time_column=None, 41 | ), 42 | AddTableMetadata( 43 | path=(root_dir / 'STORES.csv').as_posix(), 44 | name='STORES', 45 | primary_key='STORE_ID', 46 | time_column=None, 47 | 
end_time_column=None, 48 | ), 49 | ]) 50 | update_graph_metadata(update) 51 | out = update_graph_metadata(update) # idempotent 52 | assert len(out.graph.tables) == 3 53 | assert len(out.graph.links) == 0 54 | assert len(out.errors) == 0 55 | 56 | update = UpdateGraphMetadata( # type: ignore[call-arg] 57 | tables_to_update={ 58 | 'USERS': UpdateTableMetadata( # type: ignore[call-arg] 59 | stypes={ 60 | 'AGE': Stype.categorical, 61 | 'GENDER': None, 62 | }, 63 | ) 64 | } 65 | ) 66 | update_graph_metadata(update) 67 | out = update_graph_metadata(update) # idempotent 68 | assert len(out.graph.tables) == 3 69 | assert len(out.graph.links) == 0 70 | assert len(out.errors) == 0 71 | assert out.graph.tables[0].stypes['AGE'] == Stype.categorical 72 | assert out.graph.tables[0].stypes['GENDER'] is None 73 | 74 | update = UpdateGraphMetadata( # type: ignore[call-arg] 75 | links_to_add=[ 76 | LinkMetadata( 77 | source_table='ORDERS', 78 | foreign_key='USER_ID', 79 | destination_table='USERS', 80 | ), 81 | LinkMetadata( 82 | source_table='ORDERS', 83 | foreign_key='STORE_ID', 84 | destination_table='STORES', 85 | ), 86 | ]) 87 | update_graph_metadata(update) 88 | out = update_graph_metadata(update) # idempotent 89 | assert len(out.graph.tables) == 3 90 | assert len(out.graph.links) == 2 91 | assert len(out.errors) == 0 92 | 93 | update = UpdateGraphMetadata( # type: ignore[call-arg] 94 | links_to_remove=[ 95 | LinkMetadata( 96 | source_table='ORDERS', 97 | foreign_key='USER_ID', 98 | destination_table='USERS', 99 | ), 100 | ], 101 | tables_to_remove=['STORES'], 102 | ) 103 | update_graph_metadata(update) 104 | out = update_graph_metadata(update) # idempotent 105 | assert len(out.graph.tables) == 2 106 | assert len(out.graph.links) == 0 107 | assert len(out.errors) == 0 108 | 109 | 110 | def test_get_mermaid(graph: UpdateGraphMetadata) -> None: 111 | update_graph_metadata(graph) 112 | mermaid = get_mermaid(show_columns=False) 113 | assert mermaid == ('erDiagram\n' 114 | ' USERS 
{\n' 115 | ' ID USER_ID PK\n' 116 | ' }\n' 117 | ' ORDERS {\n' 118 | ' ID USER_ID FK\n' 119 | ' ID STORE_ID FK\n' 120 | ' timestamp TIME\n' 121 | ' }\n' 122 | ' STORES {\n' 123 | ' ID STORE_ID PK\n' 124 | ' }\n' 125 | '\n' 126 | ' USERS o|--o{ ORDERS : USER_ID\n' 127 | ' STORES o|--o{ ORDERS : STORE_ID') 128 | 129 | 130 | @pytest.mark.asyncio 131 | async def test_materialize_graph(graph: UpdateGraphMetadata) -> None: 132 | update_graph_metadata(graph) 133 | await materialize_graph() 134 | out = await materialize_graph() # idempotent 135 | assert out.num_nodes == 10 136 | assert out.num_edges == 16 137 | assert out.time_ranges == { 138 | 'ORDERS': '2025-01-01 00:00:00 - 2025-01-04 00:00:00' 139 | } 140 | 141 | 142 | @pytest.mark.asyncio 143 | async def test_lookup_table_rows(graph: UpdateGraphMetadata) -> None: 144 | update_graph_metadata(graph) 145 | await materialize_graph() 146 | preview = await lookup_table_rows('USERS', ids=[1, 0]) 147 | assert preview.rows == [ 148 | { 149 | 'USER_ID': 1, 150 | 'AGE': 30.0, 151 | 'GENDER': 'female', 152 | }, 153 | { 154 | 'USER_ID': 0, 155 | 'AGE': 20.0, 156 | 'GENDER': 'male', 157 | }, 158 | ] 159 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/resources/graph-setup.md: -------------------------------------------------------------------------------- 1 | # Graph Setup 2 | 3 | This guide outlines the data requirements and best practices for setting up graphs from relational data in KumoRFM. 4 | 5 | KumoRFM operates on relational data organized as inter-connected tables forming a graph structure. The foundation of this process starts with a set of CSV or Parquet files, which are registered as table schemas and assembled into a graph schema. 
6 | 7 | ## Table Schema 8 | 9 | A table schema is defined by four concepts: 10 | 11 | - **Semantic types (`stypes`):** Semantic types denote the semantic meaning of columns in a table and how they should be processed within the model 12 | - **Primary key (`primary_key`):** A unique identifier for each row in the table 13 | - **Time column (`time_column`):** The column that denotes the create time of rows (marking when the row became active) 14 | - **End time column (`end_time_column`):** The column that denotes the end time of rows (marking when the row stopped being active) 15 | 16 | ### Semantic Types 17 | 18 | The semantic type of a column will determine how it will be encoded downstream. 19 | Correctly setting each column's semantic type is critical for model performance. 20 | For instance, for missing value imputation queries, the semantic type determines whether the task is treated as regression (`stype="numerical"`) or as classification (`stype="categorical"`). 21 | 22 | The following semantic types are available: 23 | 24 | | `stype` | Explanation | Supported data types | Example | 25 | | -------------------- | ------------------------------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------- | 26 | | `"numerical"` | Numerical values (e.g., `price`, `age`) | `int`, `float` | `25`, `3.14`, `-10` | 27 | | `"categorical"` | Discrete categories with limited cardinality | `int`, `float`, `string` | Color: `"red"`, `"blue"`, `"green"` (one cell may only have one category) | 28 | | `"multicategorical"` | Multiple categories in a single cell | `string`, `stringlist`, `intlist`, `floatlist` | `"Action\|Drama\|Comedy"`, `["Action", "Drama", "Comedy"]` | 29 | | `"ID"` | An identifier, e.g., primary keys or foreign keys | `int`, `float`, `string` | `123`, `PRD-8729453` | 30 | | `"text"` | Natural language text | `string` | Descriptions of products | 31 | | `"timestamp"` | Specific point in
time | `date`, `string` | `"2025-07-11"`, `"2023-02-12 09:47:58"` | 32 | | `"sequence"` | Custom embeddings or sequential data | `floatlist`, `intlist` | `[0.25, -0.75, 0.50, ...]` | 33 | 34 | Upon table registration, semantic types of columns are estimated based on simple heuristics (e.g., data types, cardinality), but may not be ideal. 35 | For example, low cardinality columns may be mistakenly treated as `"categorical"` rather than `"numerical"`. 36 | You can use your world knowledge and common sense to analyze and correct such mistakes. 37 | 38 | If certain columns should be discarded, e.g., in case they have such high cardinality that proper model generalization becomes infeasible, a semantic type of `None` can be used to discard the column from being encoded. 39 | 40 | ### Primary Key 41 | 42 | The primary key is a unique identifier of each row in a table. 43 | Each table can have at most one primary key. 44 | If there are duplicated primary keys, the system will only keep the first one. 45 | A primary key can be used later to link tables through foreign key-primary key relationships. 46 | However, a primary key does not necessarily need to link to other tables. 47 | Setting a primary key will automatically assign the semantic type `"ID"` to this column. 48 | A primary key may not exist for all tables, but will be required whenever tables need to be linked together or whenever the table is used as the entity in a predictive query. 49 | 50 | ### Time Column 51 | 52 | A time column specifies the timestamp at which an event occurred or when this row became active. 53 | It is used to prevent temporal leakage during subgraph sampling, i.e., for a given anchor time, only events with a timestamp less than or equal to the given anchor time are preserved. 54 | Time column data must follow a datetime format that can be correctly parsed by `pandas.to_datetime`. 55 | Each table can have at most one time column.
56 | A time column may not exist for all tables, but will be required when predicting future aggregates over fact tables, e.g., the count of all orders in the next seven days. 57 | The system will only keep rows with non-N/A timestamps. 58 | If a table contains multiple time columns, pick the one that most likely refers to the create time of the event as the time column. 59 | For example, `create_time` should be preferred over `update_time`. 60 | 61 | ### End Time Column 62 | 63 | An end time column specifies the timestamp at which an event or row stopped being active. 64 | It is used to exclude in-context examples that have already expired relative to a given anchor time. 65 | End time column data must follow a datetime format that can be correctly parsed by `pandas.to_datetime`. 66 | Each table can have at most one end time column. 67 | If both a time column and an end time column are present, they must refer to different columns in the dataset. 68 | An end time column is optional and typically only appears alongside a time column. 69 | 70 | ## Graph Schema 71 | 72 | Links between tables are defined via foreign key-primary key relationships, describing many-to-one relations. 73 | Links are the crucial element that transforms individual tables into a connected relational structure, enabling KumoRFM to understand and leverage relationships in your data. 74 | However, it is also possible to use KumoRFM in single-table settings or within multiple disjoint graph schemas registered within the same graph. 75 | 76 | A link is defined by a source table (`source_table`), the foreign key column in the source table (`foreign_key`), and a destination table (`destination_table`) holding a primary key. 77 | For example, the `orders` source table may hold a foreign key `user_id` to link to the destination table `users`, holding a unique identifier for each user.
78 | Oftentimes, links can be naturally inferred by inspecting the table schemas and relying on name matching to find meaningful connections. 79 | However, this may not always be the case. 80 | You can use your world knowledge and common sense to analyze meaningful connections between tables. 81 | 82 | Note that KumoRFM only supports foreign key-primary key links. 83 | In order to connect primary keys to primary keys, you have to remove the primary key in one of the tables. 84 | 85 | **Important:** Make sure that tables are correctly linked before proceeding. 86 | 87 | ## Graph Updates 88 | 89 | You can use the `update_graph_metadata` tool to perform partial graph schema updates by registering new tables, changing their semantic types, and linking them together. 90 | Note that all operations can be performed in a single batch, e.g., one can add new tables and directly link them together. 91 | 92 | ## Graph Visualization 93 | 94 | You can visualize the graph at any given point in time by rendering it as a Mermaid entity relationship diagram via the `get_mermaid` tool. 95 | Depending on the number of columns in each table, it is recommended to set `show_columns` to `False` to avoid cluttering the diagram with less relevant details. 96 | 97 | ## Graph Materialization 98 | 99 | Once a graph is set up, you can materialize the graph to make it ready for model inference operations via the `materialize_graph` tool. 100 | This step creates the relational entity graph, which converts each row into a node, and each primary key-foreign key link into an edge. 101 | Most importantly, it converts the relational data into a data structure from which KumoRFM can efficiently perform graph traversal and sample subgraphs, which are later used as inputs into the model. 102 | 103 | Any updates to the graph schema will require re-materializing the graph before the KumoRFM model can start making predictions again.
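As an illustration, the batched update workflow described above can be expressed as a single payload. This is a sketch of the arguments a client might pass to the `update_graph_metadata` tool; the file paths are placeholders, and the field names mirror the table, link, and update metadata concepts described in this guide:

```python
# A sketch of a batched `update_graph_metadata` payload (file paths are
# placeholders): register three tables and link the fact table to both
# dimension tables in a single call.
update = {
    "tables_to_add": [
        {
            "path": "data/USERS.csv",
            "name": "USERS",
            "primary_key": "USER_ID",  # dimension table: unique row identifier
            "time_column": None,
            "end_time_column": None,
        },
        {
            "path": "data/ORDERS.parquet",
            "name": "ORDERS",
            "primary_key": None,    # fact table: no primary key required
            "time_column": "TIME",  # create time; guards against temporal leakage
            "end_time_column": None,
        },
        {
            "path": "data/STORES.csv",
            "name": "STORES",
            "primary_key": "STORE_ID",
            "time_column": None,
            "end_time_column": None,
        },
    ],
    "links_to_add": [
        # ORDERS.USER_ID (foreign key) -> USERS.USER_ID (primary key)
        {
            "source_table": "ORDERS",
            "foreign_key": "USER_ID",
            "destination_table": "USERS",
        },
        # ORDERS.STORE_ID (foreign key) -> STORES.STORE_ID (primary key)
        {
            "source_table": "ORDERS",
            "foreign_key": "STORE_ID",
            "destination_table": "STORES",
        },
    ],
}
```

All three tables and both links are registered in one call; afterwards, the graph still needs to be (re-)materialized via the `materialize_graph` tool before predictions can be made.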
104 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Trigger Release 2 | 3 | description: > 4 | To release `vX.Y.0rc1`, run this workflow on `main`. 5 | To release `vX.Y.0rc{N+1}`, run this workflow on `vX.Y`. 6 | To release `vX.Y.0`, run this workflow on `vX.Y` with `switch-from-rc-to-final` set to `true`. 7 | To release `vX.Y.{Z+1}`, run this workflow on `vX.Y`. 8 | 9 | on: # yamllint disable-line rule:truthy 10 | workflow_dispatch: 11 | inputs: 12 | dry-run: 13 | type: boolean 14 | required: false 15 | default: true 16 | switch-from-rc-to-final: 17 | description: > 18 | Whether to switch from `vX.Y.0rcN` to `vX.Y.0`. Only applicable when 19 | triggering this workflow on vX.Y that has any RC release(s). 20 | type: boolean 21 | required: false 22 | default: false 23 | 24 | jobs: 25 | release: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/checkout@v5 29 | with: 30 | # This token is necessary to push a commit and tag. 
31 | token: ${{ secrets.KUMO_GITHUB_BOT_TOKEN }} 32 | fetch-depth: 0 # for verbosity 33 | fetch-tags: true # for verbosity 34 | 35 | - name: Configure release 36 | id: config 37 | run: | 38 | if [[ ${{ github.ref_name }} == main ]]; then 39 | echo "create-rc-from-dev=true" >> $GITHUB_OUTPUT 40 | elif [[ ${{ github.ref_name }} =~ ^v[0-9]+\.[0-9]+$ ]]; then 41 | if [[ ${{ inputs.switch-from-rc-to-final }} == true ]]; then 42 | echo "create-final-from-rc=true" >> $GITHUB_OUTPUT 43 | else 44 | is_rc=$(python -c 'exec(input()); print("rc" in __version__)' < kumo_rfm_mcp/_version.py) 45 | if [[ $is_rc == "True" ]]; then 46 | echo "create-rc-from-rc=true" >> $GITHUB_OUTPUT 47 | else 48 | echo "create-final-from-final=true" >> $GITHUB_OUTPUT 49 | fi 50 | fi 51 | else 52 | echo "Error: Unexpected branch name: ${{ github.ref_name }}" 53 | exit 1 54 | fi 55 | 56 | - uses: fregante/setup-git-user@024bc0b8e177d7e77203b48dab6fb45666854b35 # v2.0.2 57 | 58 | - name: Set up Node.js 59 | uses: actions/setup-node@v4 60 | 61 | - name: Create vX.Y and vX.Y.0rc1 from vX.Y.0.dev0 on main 62 | if: ${{ steps.config.outputs.create-rc-from-dev == 'true' }} 63 | run: | 64 | # Get the release candidate version: 65 | # - version_dev: X.Y.0.dev0 66 | # - version_rc: X.Y.0rc1 67 | # - version_final: X.Y.0 68 | version_dev=$(python -c 'exec(input()); print(__version__)' < kumo_rfm_mcp/_version.py) # X.Y.0.dev0 69 | version_rc=$(npx semver --coerce $version_dev)rc1 # X.Y.0.dev0 -> X.Y.0rc1 70 | version_final=$(npx semver --coerce $version_dev) # X.Y.0.dev0 -> X.Y.0 71 | 72 | # Update the version in kumo_rfm_mcp/_version.py 73 | echo "__version__ = '$version_rc'" > kumo_rfm_mcp/_version.py 74 | git diff --color=always 75 | git add kumo_rfm_mcp/_version.py 76 | git commit -m "Update version from $version_dev to $version_rc (https://github.com/kumo-ai/kumo/actions/runs/${{ github.run_id }})" 77 | 78 | # Create a release tag: 79 | git tag v$version_rc # vX.Y.0rc1 80 | 81 | # Create a new release 
branch vX.Y on the commit right before it 82 | # starts diverging from main: 83 | git checkout -b v${version_final//.0} 84 | 85 | # Update the version in kumo_rfm_mcp/_version.py again in main (vX.Y.0rc1 -> vX.{Y+1}.0.dev0) 86 | git checkout ${{ github.ref_name }} # This is normally main, but it's set to ref_name for dry-run. 87 | version_dev_new=$(npx semver -i minor $version_final).dev0 # X.Y.0 -> X.{Y+1}.0.dev0 88 | echo "__version__ = '$version_dev_new'" > kumo_rfm_mcp/_version.py 89 | git diff --color=always 90 | git add kumo_rfm_mcp/_version.py 91 | git commit -m "Update version from $version_rc to $version_dev_new (https://github.com/kumo-ai/kumo/actions/runs/${{ github.run_id }})" 92 | 93 | # Push updated main branch, the new vX.Y branch, and vX.Y.0rc1 tag 94 | # Note that ref_name is usually main, but it's set to ref_name for dry-run. 95 | echo UPDATED_REFS="${{ github.ref_name }} v${version_final//.0} v$version_rc" >> $GITHUB_ENV 96 | 97 | - name: Create vX.Y.0rc{N+1} from vX.Y.0rcN on vX.Y 98 | if: ${{ steps.config.outputs.create-rc-from-rc == 'true' }} 99 | run: | 100 | # Get the release candidate version: 101 | # - version_rc: X.Y.0rcN 102 | # - version_rc_new: X.Y.0rc{N+1} 103 | # - version_final: X.Y.0 104 | version_rc=$(python -c 'exec(input()); print(__version__)' < kumo_rfm_mcp/_version.py) # X.Y.0rcN 105 | version_final=$(npx semver --coerce $version_rc) # X.Y.0rcN -> X.Y.0 106 | n=${version_rc##*rc} # N 107 | version_rc_new=${version_final}rc$((n + 1)) # X.Y.0rc{N+1} 108 | 109 | # Update the version in kumo_rfm_mcp/_version.py 110 | echo "__version__ = '$version_rc_new'" > kumo_rfm_mcp/_version.py 111 | git diff --color=always 112 | git add kumo_rfm_mcp/_version.py 113 | git commit -m "Update version from $version_rc to $version_rc_new (https://github.com/kumo-ai/kumo/actions/runs/${{ github.run_id }})" 114 | 115 | # Create a release tag: 116 | git tag v$version_rc_new # vX.Y.0rc{N+1} 117 | 118 | # Push updated vX.Y branch and vX.Y.0rc{N+1} tag 
119 | echo UPDATED_REFS="v${version_final//.0} v$version_rc_new" >> $GITHUB_ENV 120 | 121 | - name: Create vX.Y.0 from vX.Y.0rcN on vX.Y 122 | if: ${{ steps.config.outputs.create-final-from-rc == 'true' }} 123 | run: | 124 | # Get the final release version: 125 | # - version_rc: X.Y.0rcN 126 | # - version_final: X.Y.0 127 | version_rc=$(python -c 'exec(input()); print(__version__)' < kumo_rfm_mcp/_version.py) # X.Y.0rcN 128 | version_final=$(npx semver --coerce $version_rc) # X.Y.0rcN -> X.Y.0 129 | 130 | # Update the version in kumo_rfm_mcp/_version.py 131 | echo "__version__ = '$version_final'" > kumo_rfm_mcp/_version.py 132 | git diff --color=always 133 | git add kumo_rfm_mcp/_version.py 134 | git commit -m "Update version from $version_rc to $version_final (https://github.com/kumo-ai/kumo/actions/runs/${{ github.run_id }})" 135 | 136 | # Create a release tag: 137 | git tag v$version_final # vX.Y.0 138 | 139 | # Push updated vX.Y branch and vX.Y.0 tag 140 | echo UPDATED_REFS="v${version_final//.0} v$version_final" >> $GITHUB_ENV 141 | 142 | - name: Create vX.Y.{Z+1} from vX.Y.Z on vX.Y 143 | if: ${{ steps.config.outputs.create-final-from-final == 'true' }} 144 | run: | 145 | # Get the next patch release version: 146 | # - version_final: X.Y.Z 147 | # - version_final_new: X.Y.{Z+1} 148 | version_final=$(python -c 'exec(input()); print(__version__)' < kumo_rfm_mcp/_version.py) 149 | version_final=$(npx semver --coerce $version_final) # X.Y.Z -> X.Y.Z (just in case) 150 | version_final_new=$(npx semver -i patch $version_final) # X.Y.Z -> X.Y.{Z+1} 151 | 152 | # Update the version in kumo_rfm_mcp/_version.py: 153 | echo "__version__ = '$version_final_new'" > kumo_rfm_mcp/_version.py 154 | git diff --color=always 155 | git add kumo_rfm_mcp/_version.py 156 | git commit -m "Update version from $version_final to $version_final_new (https://github.com/kumo-ai/kumo/actions/runs/${{ github.run_id }})" 157 | 158 | # Create a release tag: 159 | git tag v$version_final_new # 
vX.Y.{Z+1} 160 | 161 | # Push updated vX.Y branch and vX.Y.{Z+1} tag 162 | echo UPDATED_REFS="${{ github.ref_name }} v$version_final_new" >> $GITHUB_ENV 163 | 164 | - name: Push changes to the remote repository 165 | run: | 166 | # For verbosity, show git tags and git logs: 167 | git tag --list --sort=-v:refname 168 | for ref in ${{ env.UPDATED_REFS }}; do 169 | echo "git log on $ref:" 170 | git log -n 10 --oneline --decorate $ref 171 | done 172 | 173 | if [[ ${{ inputs.dry-run }} == false ]]; then 174 | for ref in ${{ env.UPDATED_REFS }}; do 175 | git push origin $ref 176 | done 177 | fi 178 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Annotated, Any 3 | 4 | from kumoapi.rfm import Explanation 5 | from kumoapi.typing import Dtype, Stype 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class TableSource(BaseModel): 10 | """Source information of a table.""" 11 | path: Annotated[ 12 | Path, 13 | "Path to a local file. Only CSV or Parquet files are supported.", 14 | ] 15 | bytes: Annotated[int, "Size in bytes of the file"] 16 | 17 | 18 | class TableSourcePreview(BaseModel): 19 | """Preview of the first rows of a table-like file.""" 20 | rows: Annotated[ 21 | list[dict[str, Any]], 22 | Field( 23 | default_factory=list, 24 | description=("Each row in the table source is represented as a " 25 | "dictionary mapping column names to their " 26 | "corresponding values."), 27 | ), 28 | ] 29 | 30 | 31 | class TableMetadata(BaseModel): 32 | """Metadata for a table.""" 33 | path: Annotated[ 34 | str, 35 | ("Path to the table. 
Can be a local file path, an S3 URI " 36 | "(s3://...), or an HTTP/HTTPS URL."), 37 | ] 38 | name: Annotated[str, "Name of the table"] 39 | num_rows: Annotated[int, "Number of rows in the table"] 40 | dtypes: Annotated[ 41 | dict[str, Dtype], 42 | "Column names mapped to their data types", 43 | ] 44 | stypes: Annotated[ 45 | dict[str, Stype | None], 46 | "Column names mapped to their semantic types or `None` if they have " 47 | "been discarded", 48 | ] 49 | primary_key: Annotated[str | None, "Name of the primary key column"] 50 | time_column: Annotated[ 51 | str | None, 52 | "Name of the time column marking when the record becomes active", 53 | ] 54 | end_time_column: Annotated[ 55 | str | None, 56 | ("Name of the end time column marking when the record stops being " 57 | "active"), 58 | ] 59 | 60 | 61 | class AddTableMetadata(BaseModel): 62 | """Metadata to add a new table.""" 63 | path: Annotated[ 64 | str, 65 | ("Path to the table. Can be a local file path, an S3 URI " 66 | "(s3://...), or an HTTP/HTTPS URL."), 67 | ] 68 | name: Annotated[str, "Name of the table"] 69 | primary_key: Annotated[ 70 | str | None, 71 | Field( 72 | default=None, 73 | description="Name of the primary key column", 74 | ), 75 | ] 76 | time_column: Annotated[ 77 | str | None, 78 | Field( 79 | default=None, 80 | description=("Name of the time column marking when the record " 81 | "becomes active"), 82 | ), 83 | ] 84 | end_time_column: Annotated[ 85 | str | None, 86 | Field( 87 | default=None, 88 | description=("Name of the end time column marking when the record " 89 | "stops being active"), 90 | ), 91 | ] 92 | 93 | 94 | class UpdateTableMetadata(BaseModel): 95 | """Metadata updates to perform for a table.""" 96 | stypes: Annotated[ 97 | dict[str, Stype | None], 98 | Field( 99 | default_factory=dict, 100 | description=("Update the semantic type of column names. Set to " 101 | "`None` if the column should be discarded. 
Omitted " 102 | "columns will be untouched."), 103 | ), 104 | ] 105 | primary_key: Annotated[ 106 | str | None, 107 | Field( 108 | default=None, 109 | description=("Update the primary key column. Set to `None` if the " 110 | "primary key should be discarded. If omitted, the " 111 | "current primary key will be untouched."), 112 | ), 113 | ] 114 | time_column: Annotated[ 115 | str | None, 116 | Field( 117 | default=None, 118 | description=("Update the time column. Set to `None` if the time " 119 | "column should be discarded. If omitted, the current " 120 | "time column will be untouched."), 121 | ), 122 | ] 123 | end_time_column: Annotated[ 124 | str | None, 125 | Field( 126 | default=None, 127 | description=("Update the end time column. Set to `None` if the " 128 | "end time column should be discarded. If omitted, " 129 | "the current end time column will be untouched."), 130 | ), 131 | ] 132 | 133 | 134 | class LinkMetadata(BaseModel): 135 | """Metadata for defining a link between two tables via foreign key-primary 136 | key relationships.""" 137 | source_table: Annotated[ 138 | str, 139 | "Name of the source table containing the foreign key", 140 | ] 141 | foreign_key: Annotated[str, "Name of the foreign key column"] 142 | destination_table: Annotated[ 143 | str, 144 | "Name of the destination table containing the primary key to link to", 145 | ] 146 | 147 | 148 | class GraphMetadata(BaseModel): 149 | """Metadata of a graph holding multiple tables connected via foreign 150 | key-primary key relationships.""" 151 | tables: Annotated[list[TableMetadata], "List of tables"] 152 | links: Annotated[list[LinkMetadata], "List of links"] 153 | 154 | 155 | class UpdateGraphMetadata(BaseModel): 156 | """Metadata updates to perform for a graph holding multiple tables 157 | connected via foreign key-primary key relationships.""" 158 | tables_to_add: Annotated[ 159 | list[AddTableMetadata], 160 | Field(default_factory=list, description="Tables to add"), 161 | ] 162 | 
tables_to_update: Annotated[ 163 | dict[str, UpdateTableMetadata], 164 | Field( 165 | default_factory=dict, 166 | description="Tables to update. Omitted tables will be untouched.", 167 | ), 168 | ] 169 | links_to_remove: Annotated[ 170 | list[LinkMetadata], 171 | Field(default_factory=list, description="Links to remove"), 172 | ] 173 | links_to_add: Annotated[ 174 | list[LinkMetadata], 175 | Field(default_factory=list, description="Links to add"), 176 | ] 177 | tables_to_remove: Annotated[ 178 | list[str], 179 | Field(default_factory=list, description="Tables to remove"), 180 | ] 181 | 182 | 183 | class UpdatedGraphMetadata(BaseModel): 184 | """Updated metadata of a graph holding multiple tables connected via 185 | foreign key-primary key relationships.""" 186 | graph: Annotated[GraphMetadata, "Updated graph metadata"] 187 | errors: Annotated[ 188 | list[str], 189 | Field( 190 | default_factory=list, 191 | description="Any errors encountered during the update process", 192 | ), 193 | ] 194 | 195 | 196 | class MaterializedGraphInfo(BaseModel): 197 | """Information about the materialized graph.""" 198 | num_nodes: Annotated[int, "Number of nodes in the graph"] 199 | num_edges: Annotated[int, "Number of edges in the graph"] 200 | time_ranges: Annotated[ 201 | dict[str, str], 202 | Field( 203 | default_factory=dict, 204 | description=("Earliest to latest timestamp for each table in the " 205 | "graph that contains a time column"), 206 | ), 207 | ] 208 | 209 | 210 | class PredictResponse(BaseModel): 211 | predictions: Annotated[ 212 | list[dict[str, Any]], 213 | Field( 214 | default_factory=list, 215 | description=( 216 | "The predictions, where each row holds information about the " 217 | "entity, the anchor time, and the prediction scores"), 218 | ), 219 | ] 220 | logs: Annotated[ 221 | list[str], 222 | Field( 223 | default_factory=list, 224 | description=("Prediction-specific log messages such as number of " 225 | "context examples, the underlying task type and 
the " 226 | "label distribution"), 227 | ), 228 | ] 229 | 230 | 231 | class EvaluateResponse(BaseModel): 232 | metrics: Annotated[ 233 | dict[str, float | None], 234 | Field( 235 | default_factory=dict, 236 | description="The metric value for every metric", 237 | ), 238 | ] 239 | logs: Annotated[ 240 | list[str], 241 | Field( 242 | default_factory=list, 243 | description=("Evaluation-specific log messages such as number of " 244 | "context and test examples, the underlying task type " 245 | "and the label distribution"), 246 | ), 247 | ] 248 | 249 | 250 | class ExplanationResponse(BaseModel): 251 | prediction: Annotated[ 252 | dict[str, Any], 253 | ("The prediction, holding information about the entity, the anchor " 254 | "time, and the prediction scores"), 255 | ] 256 | explanation: Annotated[ 257 | Explanation, 258 | ("The explanation of the prediction. Provides both a global, " 259 | "column-level analysis and a local, cell-level attribution view. " 260 | "The global analysis clusters column distributions of in-context " 261 | "examples into cohorts and relates them to their relevance with " 262 | "respect to ground-truth labels. The local view computes " 263 | "gradient-based attribution scores over prediction subgraphs. " 264 | "Together, these views enable comprehensive interpretation."), 265 | ] 266 | logs: Annotated[ 267 | list[str], 268 | Field( 269 | default_factory=list, 270 | description=("Prediction-specific log messages such as number of " 271 | "context examples, the underlying task type and the " 272 | "label distribution"), 273 | ), 274 | ] 275 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

KumoRFM MCP Server

4 |
5 | 6 |
7 |

8 | KumoRFM • 9 | Notebooks • 10 | Blog • 11 | Get an API key 12 |

13 | 14 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/kumo-rfm-mcp?color=FC1373)](https://pypi.org/project/kumo-rfm-mcp/) 15 | [![PyPI Status](https://img.shields.io/pypi/v/kumo-rfm-mcp.svg?color=FC1373)](https://pypi.org/project/kumo-rfm-mcp/) 16 | [![Slack](https://img.shields.io/badge/slack-join-pink.svg?logo=slack&color=FC1373)](https://join.slack.com/t/kumoaibuilders/shared_invite/zt-2z9uih3lf-fPM1z2ACZg~oS3ObmiQLKQ) 17 | 18 | 🔬 MCP server to query [KumoRFM](https://kumorfm.ai) in your agentic flows 19 | 20 |
 21 | 22 | ## 📖 Introduction 23 | 24 | KumoRFM is a pre-trained *Relational Foundation Model (RFM)* that generates training-free predictions on any relational multi-table data by interpreting the data as a (temporal) heterogeneous graph. 25 | It can be queried via the *Predictive Query Language (PQL)*. 26 | 27 | This repository hosts a full-featured *MCP (Model Context Protocol)* server that empowers AI assistants with KumoRFM intelligence. 28 | This server enables you to: 29 | 30 | - 🕸️ Build, manage, and visualize graphs directly from CSV or Parquet files 31 | - 💬 Convert natural language into PQL queries for seamless interaction 32 | - 🤖 Query, analyze, and evaluate predictions from KumoRFM (missing value imputation, temporal forecasting, *etc.*), all without any training required 33 | 34 | ## 🚀 Installation 35 | 36 | ### 🐍 Traditional MCP Server 37 | 38 | The KumoRFM MCP server is available for Python 3.10 and above. To install, simply run: 39 | 40 | ```bash 41 | pip install kumo-rfm-mcp 42 | ``` 43 | 44 | Add to your MCP configuration file (*e.g.*, Claude Desktop's `mcp_config.json`): 45 | 46 | ```json 47 | { 48 | "mcpServers": { 49 | "kumo-rfm": { 50 | "command": "python", 51 | "args": ["-m", "kumo_rfm_mcp.server"], 52 | "env": { 53 | "KUMO_API_KEY": "your_api_key_here" 54 | } 55 | } 56 | } 57 | } 58 | ``` 59 | 60 | ### ⚡ MCP Bundle 61 | 62 | We provide a single-click installation via our [MCP Bundle (MCPB)](https://github.com/anthropics/mcpb) (*e.g.*, for integration into Claude Desktop): 63 | 64 | 1. Download the `dxt` file from [here](https://kumo-sdk-public.s3.us-west-2.amazonaws.com/dxt/kumo-rfm-mcp-0.2.0.dxt) 65 | 1. Double-click to install 66 | 67 | 68 | 69 | The MCP Bundle supports Linux, macOS and Windows, but requires a Python executable to be discoverable in order to create a new virtual environment. 
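If you script your MCP client setup, the stdio server entry from the "Traditional MCP Server" section above can also be generated instead of hand-edited. This is a minimal sketch: the `mcpServers` layout follows the Claude Desktop convention shown above, the config file path depends on your client, and the fallback placeholder is only illustrative.

```python
import json
import os

# Build the "kumo-rfm" stdio server entry shown above, reading the API key
# from the environment instead of hard-coding it into the config file:
config = {
    "mcpServers": {
        "kumo-rfm": {
            "command": "python",
            "args": ["-m", "kumo_rfm_mcp.server"],
            "env": {
                "KUMO_API_KEY": os.environ.get("KUMO_API_KEY", "your_api_key_here"),
            },
        }
    }
}

# Print the JSON to merge into your client's MCP configuration file:
print(json.dumps(config, indent=2))
```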
 70 | 71 | ### Claude Code 72 | 73 | To add the server to Claude Code, run: 74 | 75 | ```bash 76 | claude mcp add --transport stdio kumo-rfm-mcp --env KUMO_API_KEY= -- python -m kumo_rfm_mcp.server 77 | ``` 78 | 79 | ## 🎬 Claude Desktop Demo 80 | 81 | See [here](https://claude.ai/share/d2a34e63-b1d2-4255-b3e9-a6cb55004497) for the transcript. 82 | 83 | https://github.com/user-attachments/assets/56192b0b-d9df-425f-9c10-8517c754420f 84 | 85 | ## 🔬 Agentic Workflows 86 | 87 | You can use the KumoRFM MCP directly in your agentic workflows: 88 | 89 | 90 | 91 |
92 | 93 | 94 | 95 |
96 | [Example] 97 |

 99 | from crewai import Agent
100 | from crewai_tools import MCPServerAdapter
101 | from mcp import StdioServerParameters
102 | 
103 | params = StdioServerParameters( 104 | command='python', 105 | args=['-m', 'kumo_rfm_mcp.server'], 106 | env={'KUMO_API_KEY': ...}, 107 | ) 108 |
109 | with MCPServerAdapter(params) as mcp_tools: 110 | agent = Agent( 111 | role=..., 112 | goal=..., 113 | backstory=..., 114 | tools=mcp_tools, 115 | ) 116 |
120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 |
128 | [Example] 129 |

 131 | from langchain_mcp_adapters.client import MultiServerMCPClient
132 | from langgraph.prebuilt import create_react_agent
133 | 
134 | client = MultiServerMCPClient({ 135 | 'kumo-rfm': { 136 | 'command': 'python', 137 | 'args': ['-m', 'kumo_rfm_mcp.server'], 138 | 'env': {'KUMO_API_KEY': ...}, 139 | } 140 | }) 141 |
142 | agent = create_react_agent( 143 | llm=..., 144 | tools=await client.get_tools(), 145 | ) 146 |
150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 |
158 | [Example] 159 |

161 | from agents import Agent
162 | from agents.mcp import MCPServerStdio
163 | 
164 | async with MCPServerStdio(params={ 165 | 'command': 'python', 166 | 'args': ['-m', 'kumo_rfm_mcp.server'], 167 | 'env': {'KUMO_API_KEY': ...}, 168 | }) as server: 169 | agent = Agent( 170 | name=..., 171 | instructions=..., 172 | mcp_servers=[server], 173 | ) 174 |
178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 |

187 | from claude_code_sdk import query, ClaudeCodeOptions
188 | 
189 | mcp_servers = { 190 | 'kumo-rfm': { 191 | 'command': 'python', 192 | 'args': ['-m', 'kumo_rfm_mcp.server'], 193 | 'env': {'KUMO_API_KEY': ...}, 194 | } 195 | } 196 |
197 | async for message in query( 198 | prompt=..., 199 | options=ClaudeCodeOptions( 200 | system_prompt=..., 201 | mcp_servers=mcp_servers, 202 | permission_mode='default', 203 | ), 204 | ): 205 | ... 206 |
 209 | 210 | Browse our [examples](https://github.com/kumo-ai/kumo-rfm/tree/master/notebooks) to get started with agentic workflows powered by KumoRFM. 211 | 212 | ## 📚 Available Tools 213 | 214 | ### I/O Operations 215 | 216 | - **🔍 `find_table_files` - Searching for tabular files:** Find all table-like files (*e.g.*, CSV, Parquet) in a directory. 217 | - **🧐 `inspect_table_files` - Analyzing table structure:** Inspect the first rows of table-like files. 218 | 219 | ### Graph Management 220 | 221 | - **🗂️ `inspect_graph_metadata` - Reviewing graph schema:** Inspect the current graph metadata. 222 | - **🔄 `update_graph_metadata` - Updating graph schema:** Partially update the current graph metadata. 223 | - **🖼️ `get_mermaid` - Creating graph diagram:** Return the graph as a Mermaid entity relationship diagram. 224 | - **🕸️ `materialize_graph` - Assembling graph:** Materialize the graph based on the current state of the graph metadata to make it available for inference operations. 225 | - **📂 `lookup_table_rows` - Retrieving table entries:** Look up rows in the raw data frame of a table for a list of primary keys. 226 | 227 | ### Model Execution 228 | 229 | - **🤖 `predict` - Running predictive query:** Execute a predictive query and return model predictions. 230 | - **📊 `evaluate` - Evaluating predictive query:** Evaluate a predictive query and return performance metrics that compare predictions against known ground-truth labels from historical examples. 231 | - **🧠 `explain` - Explaining prediction:** Execute a predictive query and explain the model prediction. 232 | 233 | ## 🔧 Configuration 234 | 235 | ### Environment Variables 236 | 237 | - **`KUMO_API_KEY`:** Authentication is needed once before predicting or evaluating with the 238 | KumoRFM model. 239 | You can generate your KumoRFM API key for free [here](https://kumorfm.ai). 240 | If not set, you can also authenticate on-the-fly in individual sessions via an OAuth2 flow. 241 | 242 | ## We love your feedback! 
:heart: 243 | 244 | As you work with KumoRFM, if you encounter any problems or things that are confusing or don't work quite right, please open a new :octocat:[issue](https://github.com/kumo-ai/kumo-rfm-mcp/issues/new). 245 | You can also submit general feedback and suggestions [here](https://docs.google.com/forms/d/e/1FAIpQLSfr2HYgJN8ghaKyvU0PSRkqrGd_BijL3oyQTnTxLrf8AEk-EA/viewform). 246 | Join [our Slack](https://join.slack.com/t/kumoaibuilders/shared_invite/zt-2z9uih3lf-fPM1z2ACZg~oS3ObmiQLKQ)! 247 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/tools/model.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | from typing import Annotated, Literal 4 | 5 | import pandas as pd 6 | from fastmcp import FastMCP 7 | from fastmcp.exceptions import ToolError 8 | from kumoai.utils import ProgressLogger 9 | from pydantic import Field 10 | 11 | from kumo_rfm_mcp import ( 12 | EvaluateResponse, 13 | ExplanationResponse, 14 | PredictResponse, 15 | SessionManager, 16 | ) 17 | 18 | query_doc = ("The predictive query string, e.g., " 19 | "'PREDICT COUNT(orders.*, 0, 30, days)>0 FOR EACH users.user_id' " 20 | "or 'PREDICT users.age FOR EACH users.user_id'") 21 | indices_doc = ("The primary keys (entity indices) to generate predictions " 22 | "for. Up to 1000 entities are supported for an individual " 23 | "query. Predictions will be generated for all indices, " 24 | "regardless of whether they match any entity filter " 25 | "constraints.") 26 | anchor_time_doc = ( 27 | "The anchor time for which we are making a prediction for the " 28 | "the future. If `None`, will use the maximum timestamp in the " 29 | "data as anchor time. 
If 'entity', will use the timestamp of " 30 | "the entity's time column as anchor time (only valid for " 31 | "static predictive queries for which the entity table " 32 | "contains a time column), which is useful to prevent future " 33 | "data leakage when imputing missing values on facts, e.g., " 34 | "predicting whether a transaction is fraudulent should " 35 | "happen at the point in time the transaction was created.") 36 | run_mode_doc = ( 37 | "The run mode for the query. Trades runtime with model performance. The " 38 | "run mode dictates how many training/in-context examples are sampled to " 39 | "make a prediction, i.e. 1000 for 'fast', 5000 for 'normal', and 10000 " 40 | "for 'best'.") 41 | num_neighbors_doc = ( 42 | "The number of neighbors to sample for each hop to create subgraphs. For " 43 | "example, `[24, 12]` samples 24 neighbors in the first hop and 12 " 44 | "neighbors in the second hop. If `None` (recommended), will use two-hop " 45 | "sampling with 32 neighbors in 'fast' mode, and 64 neighbors otherwise in " 46 | "each hop. Up to 6-hop subgraphs are supported. Decreasing the number of " 47 | "neighbors per hop can prevent oversmoothing. Increasing the number of " 48 | "neighbors per hop allows the model to look at a larger historical time " 49 | "window. Increasing the number of hops can improve performance in case " 50 | "important signal is far away from the entity table, but can result in " 51 | "massive subgraphs. We advise to let the number of neighbors gradually " 52 | "shrink down in later hops to prevent recursive neighbor explosion, e.g., " 53 | "`num_neighbors=[32, 32, 4, 4, 2, 2]`, if more hops are required.") 54 | max_pq_iterations_doc = ( 55 | "The maximum number of iterations to perform to collect valid training/" 56 | "in-context examples. It is advised to increase the number of iterations " 57 | "in case the model fails to find the upper bound of supported training " 58 | "examples w.r.t. 
the run mode, *i.e.* 1000 for 'fast', 5000 for 'normal' " 59 | "and 10000 for 'best'.") 60 | metrics_doc = ( 61 | "The metrics to use for evaluation. If `None`, will use a pre-selection " 62 | "of metrics depending on the given predictive query. The following metrics" 63 | "are supported:\n" 64 | "Binary classification: 'acc', 'precision', 'recall', 'f1', 'auroc', " 65 | "'auprc', 'ap'\n" 66 | "Multi-class classification: 'acc', 'precision', 'recall', 'f1', 'mrr'\n" 67 | "Regression: 'mae', 'mape', 'mse', 'rmse', 'smape', 'r2'\n" 68 | "Temporal link prediction: 'map@k', 'ndcg@k', 'mrr@k', 'precision@k', " 69 | "'recall@k', 'f1@k', 'hit_ratio@k' where 'k' needs to be an integer " 70 | "between 1 and 100") 71 | 72 | 73 | async def predict( 74 | query: Annotated[str, query_doc], 75 | indices: Annotated[list[str] | list[float] | list[int], indices_doc], 76 | anchor_time: Annotated[ 77 | datetime | Literal['entity'] | None, 78 | Field(default=None, description=anchor_time_doc), 79 | ], 80 | run_mode: Annotated[ 81 | Literal['fast', 'normal', 'best'], 82 | Field(default='fast', description=run_mode_doc), 83 | ], 84 | num_neighbors: Annotated[ 85 | list[int] | None, 86 | Field( 87 | default=None, 88 | min_length=0, 89 | max_length=6, 90 | description=num_neighbors_doc, 91 | ), 92 | ], 93 | max_pq_iterations: Annotated[ 94 | int, 95 | Field(default=20, description=max_pq_iterations_doc), 96 | ], 97 | ) -> PredictResponse: 98 | """Execute a predictive query and return model predictions. 99 | 100 | The graph needs to be materialized and the session needs to be 101 | authenticated before the KumoRFM model can start generating predictions. 102 | 103 | The output prediction format depends on the given task type. 
104 | 105 | Binary classification: 106 | | ENTITY | ANCHOR_TIMESTAMP | TARGET_PRED | False_PROB | True_PROB | 107 | where 'ENTITY' holds the entity ID, 'ANCHOR_TIMESTAMP' holds the anchor 108 | time of the prediction in unix format, 'TARGET_PRED' holds the final 109 | prediction based on a threshold of 0.5, and 'False_PROB' and 'True_PROB' 110 | hold the probabilities. 111 | 112 | Multi-class classification: 113 | | ENTITY | ANCHOR_TIMESTAMP | CLASS | SCORE | PREDICTED | 114 | where 'ENTITY' holds the entity ID, 'ANCHOR_TIMESTAMP' holds the anchor 115 | time of the prediction in unix format. Each row corresponds to an (ENTITY, 116 | CLASS) pair (up to 10 classes are reported), where 'CLASS' holds the 117 | predicted value, 'SCORE' holds its probability, and 'PREDICTED' denotes 118 | whether the (ENTITY, CLASS) pair has the highest likelihood. 119 | 120 | Regression: 121 | | ENTITY | ANCHOR_TIMESTAMP | TARGET_PRED | 122 | where 'ENTITY' holds the entity ID, 'ANCHOR_TIMESTAMP' holds the anchor 123 | time of the prediction in unix format, and 'TARGET_PRED' holds the 124 | predicted numerical value. 125 | 126 | Temporal link prediction: 127 | | ENTITY | ANCHOR_TIMESTAMP | CLASS | SCORE | 128 | where 'ENTITY' holds the entity ID, 'ANCHOR_TIMESTAMP' holds the anchor 129 | time of the prediction in unix format. Each row corresponds to an (ENTITY, 130 | CLASS) pair, where 'CLASS' holds the recommended item and 'SCORE' holds its 131 | likelihood. 132 | 133 | Important: Before executing or suggesting any predictive queries, 134 | read the documentation first at 'kumo://docs/predictive-query'. 
 135 | """ 136 | model = SessionManager.get_default_session().model 137 | 138 | if anchor_time is not None and anchor_time != "entity": 139 | anchor_time = pd.Timestamp(anchor_time) 140 | 141 | def _predict() -> PredictResponse: 142 | logger = ProgressLogger(query) 143 | 144 | try: 145 | df = model.predict( 146 | query, 147 | indices=indices, 148 | anchor_time=anchor_time, 149 | run_mode=run_mode, 150 | num_neighbors=num_neighbors, 151 | max_pq_iterations=max_pq_iterations, 152 | verbose=logger, 153 | ) 154 | except Exception as e: 155 | raise ToolError(f"Prediction failed: {e}") from e 156 | 157 | logs = logger.logs 158 | if logger.start_time is not None: 159 | logs = logs + [f'Duration: {logger.duration:.2f}s'] 160 | 161 | return PredictResponse( 162 | predictions=df.to_dict(orient='records'), 163 | logs=logs, 164 | ) 165 | 166 | return await asyncio.to_thread(_predict) 167 | 168 | 169 | async def evaluate( 170 | query: Annotated[str, query_doc], 171 | metrics: Annotated[ 172 | list[str] | None, 173 | Field(default=None, description=metrics_doc), 174 | ], 175 | anchor_time: Annotated[ 176 | datetime | Literal['entity'] | None, 177 | Field(default=None, description=anchor_time_doc), 178 | ], 179 | run_mode: Annotated[ 180 | Literal['fast', 'normal', 'best'], 181 | Field(default='fast', description=run_mode_doc), 182 | ], 183 | num_neighbors: Annotated[ 184 | list[int] | None, 185 | Field( 186 | default=None, 187 | min_length=0, 188 | max_length=6, 189 | description=num_neighbors_doc, 190 | ), 191 | ], 192 | max_pq_iterations: Annotated[ 193 | int, 194 | Field(default=20, description=max_pq_iterations_doc), 195 | ], 196 | ) -> EvaluateResponse: 197 | """Evaluate a predictive query and return performance metrics that 198 | compare predictions against known ground-truth labels from historical 199 | examples. 200 | 201 | The graph needs to be materialized and the session needs to be 202 | authenticated before the KumoRFM model can start evaluating. 
 203 | 204 | Take the label distribution of the predictive query in the output logs into 205 | account when analyzing the returned metrics. 206 | 207 | Important: Before executing or suggesting any predictive queries, 208 | read the documentation first at 'kumo://docs/predictive-query'. 209 | """ 210 | model = SessionManager.get_default_session().model 211 | 212 | if anchor_time is not None and anchor_time != "entity": 213 | anchor_time = pd.Timestamp(anchor_time) 214 | 215 | def _evaluate() -> EvaluateResponse: 216 | logger = ProgressLogger(query) 217 | 218 | try: 219 | df = model.evaluate( 220 | query, 221 | metrics=metrics, 222 | anchor_time=anchor_time, 223 | run_mode=run_mode, 224 | num_neighbors=num_neighbors, 225 | max_pq_iterations=max_pq_iterations, 226 | verbose=logger, 227 | ) 228 | except Exception as e: 229 | raise ToolError(f"Evaluation failed: {e}") from e 230 | 231 | df = df.astype(object).where(df.notna(), None) 232 | 233 | logs = logger.logs 234 | if logger.start_time is not None: 235 | logs = logs + [f'Duration: {logger.duration:.2f}s'] 236 | 237 | return EvaluateResponse( 238 | metrics=df.set_index('metric')['value'].to_dict(), 239 | logs=logs, 240 | ) 241 | 242 | return await asyncio.to_thread(_evaluate) 243 | 244 | 245 | async def explain( 246 | query: Annotated[str, query_doc], 247 | index: Annotated[ 248 | str | float | int, 249 | "The primary key (entity index) of the prediction to explain", 250 | ], 251 | anchor_time: Annotated[ 252 | datetime | Literal['entity'] | None, 253 | Field(default=None, description=anchor_time_doc), 254 | ], 255 | num_neighbors: Annotated[ 256 | list[int] | None, 257 | Field( 258 | default=None, 259 | min_length=0, 260 | max_length=6, 261 | description=num_neighbors_doc, 262 | ), 263 | ], 264 | max_pq_iterations: Annotated[ 265 | int, 266 | Field(default=20, description=max_pq_iterations_doc), 267 | ], 268 | ) -> ExplanationResponse: 269 | """Execute a predictive query and explain the model prediction. 
 270 | 271 | The graph needs to be materialized and the session needs to be 272 | authenticated before the KumoRFM model can start generating an explanation 273 | for a prediction. 274 | 275 | Only a single entity prediction can be explained at a time. 276 | The `run_mode` will be fixed to `'fast'` mode for explainability. 277 | Note that the model prediction returned by the explanation might differ 278 | slightly from the result of the `predict` tool due to floating-point 279 | precision. Ignore such small differences. 280 | 281 | Important: Before executing or suggesting any predictive queries, 282 | read the documentation first at 'kumo://docs/predictive-query'. 283 | 284 | Important: Before analyzing the explanation output, read the documentation 285 | first at 'kumo://docs/explainability'. 286 | """ 287 | model = SessionManager.get_default_session().model 288 | 289 | if anchor_time is not None and anchor_time != "entity": 290 | anchor_time = pd.Timestamp(anchor_time) 291 | 292 | def _explain() -> ExplanationResponse: 293 | logger = ProgressLogger(query) 294 | 295 | try: 296 | out = model.predict( 297 | query, 298 | indices=[index], 299 | explain=dict(skip_summary=True), 300 | anchor_time=anchor_time, 301 | num_neighbors=num_neighbors, 302 | max_pq_iterations=max_pq_iterations, 303 | verbose=logger, 304 | ) 305 | except Exception as e: 306 | raise ToolError(f"Explanation failed: {e}") from e 307 | 308 | logs = logger.logs 309 | if logger.start_time is not None: 310 | logs = logs + [f'Duration: {logger.duration:.2f}s'] 311 | 312 | return ExplanationResponse( 313 | prediction=out.prediction.to_dict(orient='records')[0], 314 | explanation=out.details, 315 | logs=logs, 316 | ) 317 | 318 | return await asyncio.to_thread(_explain) 319 | 320 | 321 | def register_model_tools(mcp: FastMCP) -> None: 322 | """Register all model tools to the MCP server.""" 323 | mcp.tool(annotations=dict( 324 | title="🤖 Running predictive query…", 325 | readOnlyHint=True, 326 | 
destructiveHint=False, 327 | idempotentHint=True, 328 | openWorldHint=False, 329 | ))(predict) 330 | 331 | mcp.tool(annotations=dict( 332 | title="📊 Evaluating predictive query…", 333 | readOnlyHint=True, 334 | destructiveHint=False, 335 | idempotentHint=True, 336 | openWorldHint=False, 337 | ))(evaluate) 338 | 339 | mcp.tool(annotations=dict( 340 | title="🧠 Explaining prediction…", 341 | readOnlyHint=True, 342 | destructiveHint=False, 343 | idempotentHint=True, 344 | openWorldHint=False, 345 | ))(explain) 346 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/tools/graph.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os.path as osp 3 | from collections import defaultdict 4 | from typing import Annotated 5 | 6 | import pandas as pd 7 | from fastmcp import FastMCP 8 | from fastmcp.exceptions import ToolError 9 | from kumoai.experimental import rfm 10 | from kumoai.graph import Edge 11 | from kumoai.utils import ProgressLogger 12 | from kumoapi.typing import Dtype, Stype 13 | from pydantic import Field 14 | 15 | from kumo_rfm_mcp import ( 16 | GraphMetadata, 17 | LinkMetadata, 18 | MaterializedGraphInfo, 19 | SessionManager, 20 | TableMetadata, 21 | TableSourcePreview, 22 | UpdatedGraphMetadata, 23 | UpdateGraphMetadata, 24 | ) 25 | 26 | _materialize_lock = asyncio.Lock() 27 | 28 | 29 | def inspect_graph_metadata() -> GraphMetadata: 30 | """Inspect the current graph metadata. 31 | 32 | Confirming that the metadata is set up correctly is crucial for the RFM 33 | model to work properly. In particular, 34 | 35 | * primary keys and time columns need to be correctly specified for each 36 | table in case they exist; 37 | * columns need to point to a valid semantic type that describes their 38 | semantic meaning, or ``None`` if they have been discarded; 39 | * links need to point to valid foreign key-primary key relationships.
40 | """ 41 | session = SessionManager.get_default_session() 42 | 43 | tables: list[TableMetadata] = [] 44 | for table in session.graph.tables.values(): 45 | dtypes: dict[str, Dtype] = {} 46 | stypes: dict[str, Stype | None] = {} 47 | for column in table._data.columns: 48 | if column in table: 49 | dtypes[column] = table[column].dtype 50 | stypes[column] = table[column].stype 51 | else: 52 | dtypes[column] = rfm.utils.to_dtype(table._data[column]) 53 | stypes[column] = None 54 | tables.append( 55 | TableMetadata( 56 | path=table._path, 57 | name=table.name, 58 | num_rows=len(table._data), 59 | dtypes=dtypes, 60 | stypes=stypes, 61 | primary_key=table._primary_key, 62 | time_column=table._time_column, 63 | end_time_column=table._end_time_column, 64 | )) 65 | 66 | links: list[LinkMetadata] = [] 67 | for edge in session.graph.edges: 68 | links.append( 69 | LinkMetadata( 70 | source_table=edge.src_table, 71 | foreign_key=edge.fkey, 72 | destination_table=edge.dst_table, 73 | )) 74 | 75 | return GraphMetadata(tables=tables, links=links) 76 | 77 | 78 | def update_graph_metadata(update: UpdateGraphMetadata) -> UpdatedGraphMetadata: 79 | """Partially update the current graph metadata. 80 | 81 | Setting up the metadata is crucial for the RFM model to work properly. In 82 | particular, 83 | 84 | * primary keys and time columns need to be correctly specified for each 85 | table in case they exist; 86 | * columns need to point to a valid semantic type that describes their 87 | semantic meaning, or ``None`` if they should be discarded; 88 | * links need to point to valid foreign key-primary key relationships. 89 | 90 | Omitted fields will be untouched. 91 | 92 | For newly added tables, it is advised to double-check semantic types and 93 | modify in a follow-up step if necessary. 94 | 95 | Make sure that tables are correctly linked before proceeding.
96 | 97 | Note that all operations can be performed in a batch at once, *e.g.*, one 98 | can add new tables and directly link them together. 99 | 100 | Important: Before creating and updating graphs, read the documentation 101 | first at 'kumo://docs/graph-setup'. 102 | """ 103 | session = SessionManager.get_default_session() 104 | session._model = None # Need to reset the model if graph changes. 105 | graph = session.graph 106 | 107 | errors: list[str] = [] 108 | for table in update.tables_to_add: 109 | path = osp.expanduser(table.path) 110 | suffix = path.rsplit('.', maxsplit=1)[-1].lower() 111 | 112 | if table.name in graph and graph[table.name]._path == path: 113 | graph[table.name].primary_key = table.primary_key 114 | graph[table.name].time_column = table.time_column 115 | graph[table.name].end_time_column = table.end_time_column 116 | continue 117 | 118 | if suffix not in {'csv', 'parquet'}: 119 | errors.append(f"'{path}' is not a valid CSV or Parquet file") 120 | continue 121 | 122 | try: 123 | if suffix == 'csv': 124 | df = pd.read_csv(path) 125 | else: 126 | assert suffix == 'parquet' 127 | df = pd.read_parquet(path) 128 | except Exception as e: 129 | errors.append(f"Could not read file '{path}': {e}") 130 | continue 131 | 132 | try: 133 | local_table = rfm.LocalTable( 134 | df=df, 135 | name=table.name, 136 | primary_key=table.primary_key, 137 | time_column=table.time_column, 138 | end_time_column=table.end_time_column, 139 | ) 140 | local_table._path = path 141 | graph.add_table(local_table) 142 | except Exception as e: 143 | errors.append(f"Could not add table '{table.name}': {e}") 144 | continue 145 | 146 | # Only keep specified keys: 147 | update_dict = update.model_dump(exclude_unset=True) 148 | tables_to_update = update_dict.get('tables_to_update', {}) 149 | for table_name, table_update in tables_to_update.items(): 150 | try: 151 | stypes = table_update.get('stypes', {}) 152 | for column_name, stype in stypes.items(): 153 | if column_name not in
graph[table_name]: 154 | graph[table_name].add_column(column_name) 155 | if stype is None: 156 | del graph[table_name][column_name] 157 | else: 158 | graph[table_name][column_name].stype = stype 159 | if 'primary_key' in table_update: 160 | graph[table_name].primary_key = table_update['primary_key'] 161 | if 'time_column' in table_update: 162 | graph[table_name].time_column = table_update['time_column'] 163 | if 'end_time_column' in table_update: 164 | graph[table_name].end_time_column = table_update[ 165 | 'end_time_column'] 166 | except Exception as e: 167 | errors.append(f"Could not fully update table '{table_name}': {e}") 168 | continue 169 | 170 | for link in update.links_to_remove: 171 | try: 172 | graph.unlink( 173 | link.source_table, 174 | link.foreign_key, 175 | link.destination_table, 176 | ) 177 | except Exception: 178 | continue 179 | 180 | for link in update.links_to_add: 181 | if Edge( 182 | src_table=link.source_table, 183 | fkey=link.foreign_key, 184 | dst_table=link.destination_table, 185 | ) in graph.edges: 186 | continue 187 | 188 | try: 189 | graph.link( 190 | link.source_table, 191 | link.foreign_key, 192 | link.destination_table, 193 | ) 194 | except Exception as e: 195 | errors.append(f"Could not add link from source table " 196 | f"'{link.source_table}' to destination table " 197 | f"'{link.destination_table}' via the " 198 | f"'{link.foreign_key}' column: {e}") 199 | continue 200 | 201 | for table_name in update.tables_to_remove: 202 | try: 203 | del graph[table_name] 204 | except Exception: 205 | continue 206 | 207 | try: 208 | graph.validate() 209 | except Exception as e: 210 | errors.append(f"Final graph validation failed: {e}") 211 | 212 | return UpdatedGraphMetadata(graph=inspect_graph_metadata(), errors=errors) 213 | 214 | 215 | def get_mermaid( 216 | show_columns: Annotated[ 217 | bool, 218 | Field( 219 | default=True, 220 | description=("Controls whether all columns of a table are shown. 
" 221 | "If `False`, only the primary key, foreign keys and " 222 | "time column are displayed. Setting this to `False` " 223 | "is recommended for feature-rich tables to avoid " 224 | "cluttering the diagram with less relevant details."), 225 | ), 226 | ], 227 | ) -> str: 228 | """Return the graph as a Mermaid entity relationship diagram. 229 | 230 | Important: The returned Mermaid markup can be used to input into an 231 | artifact to render it visually on the client side. 232 | """ 233 | session = SessionManager.get_default_session() 234 | 235 | fkey_dict = defaultdict(list) 236 | for edge in session.graph.edges: 237 | fkey_dict[edge.src_table].append(edge.fkey) 238 | 239 | lines = ["erDiagram"] 240 | 241 | for table in session.graph.tables.values(): 242 | feat_columns = [] 243 | for column in table.columns: 244 | if (column.name != table._primary_key 245 | and column.name not in fkey_dict[table.name] 246 | and column.name != table._time_column 247 | and column.name != table._end_time_column): 248 | feat_columns.append(column) 249 | 250 | lines.append(f"{' ' * 4}{table.name} {{") 251 | if pkey := table.primary_key: 252 | lines.append(f"{' ' * 8}{pkey.stype} {pkey.name} PK") 253 | for fkey_name in fkey_dict[table.name]: 254 | fkey = table[fkey_name] 255 | lines.append(f"{' ' * 8}{fkey.stype} {fkey.name} FK") 256 | if time_col := table.time_column: 257 | lines.append(f"{' ' * 8}{time_col.stype} {time_col.name}") 258 | if end_time_col := table.end_time_column: 259 | lines.append(f"{' ' * 8}{end_time_col.stype} {end_time_col.name}") 260 | if show_columns: 261 | for col in feat_columns: 262 | lines.append(f"{' ' * 8}{col.stype} {col.name}") 263 | lines.append(f"{' ' * 4}}}") 264 | 265 | if len(session.graph.edges) > 0: 266 | lines.append("") 267 | 268 | for edge in session.graph.edges: 269 | lines.append(f"{' ' * 4}{edge.dst_table} o|--o{{ {edge.src_table} " 270 | f": {edge.fkey}") 271 | 272 | return '\n'.join(lines) 273 | 274 | 275 | async def materialize_graph() -> 
MaterializedGraphInfo: 276 | """Materialize the graph based on the current state of the graph metadata 277 | to make it available for inference operations (e.g., ``predict`` and 278 | ``evaluate``). 279 | 280 | Any updates to the graph metadata require re-materializing the graph before 281 | the KumoRFM model can start making predictions again. 282 | """ 283 | session = SessionManager.get_default_session() 284 | 285 | def _materialize_graph() -> rfm.KumoRFM: 286 | try: 287 | logger = ProgressLogger("Materializing graph") 288 | return rfm.KumoRFM(session.graph, verbose=logger) 289 | except Exception as e: 290 | raise ToolError(f"Failed to materialize graph: {e}") from e 291 | 292 | def _get_info(model: rfm.KumoRFM) -> MaterializedGraphInfo: 293 | store = model._graph_store 294 | num_nodes = sum(len(df) for df in store.df_dict.values()) 295 | num_edges = sum(len(row) for row in store.row_dict.values()) 296 | time_ranges = {} 297 | for table in session.graph.tables.values(): 298 | if table._time_column is None: 299 | continue 300 | time = store.df_dict[table.name][table._time_column] 301 | if table.name in store.mask_dict.keys(): 302 | time = time[store.mask_dict[table.name]] 303 | if len(time) == 0: 304 | continue 305 | time_ranges[table.name] = f"{time.min()} - {time.max()}" 306 | 307 | return MaterializedGraphInfo( 308 | num_nodes=num_nodes, 309 | num_edges=num_edges, 310 | time_ranges=time_ranges, 311 | ) 312 | 313 | if session._model is None: 314 | async with _materialize_lock: 315 | session._model = await asyncio.to_thread(_materialize_graph) 316 | 317 | return await asyncio.to_thread(_get_info, session._model) 318 | 319 | 320 | async def lookup_table_rows( 321 | table_name: Annotated[str, "Table name"], 322 | ids: Annotated[ 323 | list[str | int | float], 324 | Field( 325 | min_length=1, 326 | max_length=1000, 327 | description="Primary keys to read", 328 | ), 329 | ], 330 | ) -> TableSourcePreview: 331 | """Look up rows in the raw data frame of a table for a list of
primary 332 | keys. 333 | 334 | In contrast to the 'inspect_table_files' tool, this tool can be used to 335 | query specific rows in a registered table in the graph. 336 | It should not be used to understand and analyze table schema. 337 | 338 | Use this tool to look up detailed information about recommended items to 339 | provide richer, more meaningful recommendations to users. 340 | 341 | The table to read from needs to have a primary key, and the graph has to be 342 | materialized. 343 | """ 344 | model = SessionManager.get_default_session()._model 345 | 346 | if model is None: 347 | raise ToolError("Graph is not yet materialized") 348 | 349 | def _lookup_table_rows() -> TableSourcePreview: 350 | try: 351 | node_ids = model._graph_store.get_node_id( 352 | table_name=table_name, 353 | pkey=pd.Series(ids), 354 | ) 355 | df = model._graph_store.df_dict[table_name].iloc[node_ids] 356 | except Exception as e: 357 | raise ToolError(str(e)) from e 358 | 359 | df = df.astype(object).where(df.notna(), None) 360 | return TableSourcePreview(rows=df.to_dict(orient='records')) 361 | 362 | return await asyncio.to_thread(_lookup_table_rows) 363 | 364 | 365 | def register_graph_tools(mcp: FastMCP) -> None: 366 | """Register all graph tools to the MCP server.""" 367 | mcp.tool(annotations=dict( 368 | title="🗂️ Reviewing graph schema…", 369 | readOnlyHint=True, 370 | destructiveHint=False, 371 | idempotentHint=True, 372 | openWorldHint=False, 373 | ))(inspect_graph_metadata) 374 | 375 | mcp.tool(annotations=dict( 376 | title="🔄 Updating graph schema…", 377 | readOnlyHint=False, 378 | destructiveHint=False, 379 | idempotentHint=True, 380 | openWorldHint=False, 381 | ))(update_graph_metadata) 382 | 383 | mcp.tool(annotations=dict( 384 | title="🖼️ Creating graph diagram…", 385 | readOnlyHint=True, 386 | destructiveHint=False, 387 | idempotentHint=True, 388 | openWorldHint=False, 389 | ))(get_mermaid) 390 | 391 | mcp.tool(annotations=dict( 392 | title="🕸️ Assembling graph…", 393 | 
readOnlyHint=False, 394 | destructiveHint=False, 395 | idempotentHint=True, 396 | openWorldHint=False, 397 | ))(materialize_graph) 398 | 399 | mcp.tool(annotations=dict( 400 | title="📂 Retrieving table entries…", 401 | readOnlyHint=True, 402 | destructiveHint=False, 403 | idempotentHint=True, 404 | openWorldHint=False, 405 | ))(lookup_table_rows) 406 | -------------------------------------------------------------------------------- /kumo_rfm_mcp/resources/predictive-query.md: -------------------------------------------------------------------------------- 1 | # Predictive Query 2 | 3 | The Predictive Query Language (PQL) is a querying language that allows you to define relational machine learning tasks. 4 | PQL lets you define predictive problems by specifying: 5 | 6 | 1. **The target expression:** Declares the value or aggregate the model should predict 7 | 1. **The entity specification:** Specifies the entities to predict for 8 | 1. **Optional entity filters:** Filters which historical entities are used as in-context learning examples 9 | 10 | The basic structure of a predictive query is: 11 | 12 | ``` 13 | PREDICT <target> FOR EACH <entity> WHERE <entity_filter> 14 | ``` 15 | 16 | Every predictive query needs to contain the `PREDICT` and `FOR EACH` keywords. 17 | All references to columns within predictive queries must be fully qualified by table name and column name as `<table_name>.<column_name>`. 18 | 19 | In general, follow these steps to author a predictive query: 20 | 21 | 1. **Choose your entity** - a table and its primary key you predict for. 22 | 1. **Define the target** - a raw column or an aggregation over a future window. 23 | 1. **Refine the context** - if necessary, restrict which historical rows are used as in-context learning examples. 24 | 1. **Run & fetch** - run `predict` or `evaluate` on top. 25 | 26 | A predictive query uniquely defines a predictive machine learning task.
27 | As such, it also defines the procedure for obtaining ground-truth labels from historical snapshots of the data, which are used to generate context labels to perform in-context learning within KumoRFM. 28 | 29 | **Important:** PQL is not SQL. 30 | Standard SQL operations such as `JOIN`, `SELECT`, `UNION`, `GROUP BY`, and subqueries are not supported in PQL. 31 | PQL uses a simpler, more constrained syntax designed specifically for defining predictive machine learning tasks. 32 | PQL also doesn't support arithmetic operations like `+` or `-`. 33 | Do not make syntax up that is not listed in this document. 34 | 35 | ## Entity Specification 36 | 37 | Entities for each query can be specified via: 38 | 39 | ``` 40 | PREDICT ... FOR EACH users.user_id 41 | ``` 42 | 43 | Note that the entity table needs a primary key to uniquely determine the set of IDs to predict for. 44 | 45 | The actual entities to generate predictions for can be fully customized as part of the `predict` tool via the `indices` argument. 46 | Up to 1000 entities are supported for an individual query. 47 | Note that predictions will be generated for all indices, regardless of whether they match any entity filter constraints defined in the `WHERE` clause. 48 | 49 | ## Target Expression 50 | 51 | The target expression is the value or aggregate the model should predict. 52 | It can be a single value, an aggregate, a condition, or a set of logical operations. 53 | We differentiate between two types of queries: static and temporal queries. 54 | 55 | ### Static Predictive Queries 56 | 57 | Static predictive queries are used to impute missing values from an entity table. 58 | That is, the target column has to appear in the same table as the entity you are making a prediction for. 59 | KumoRFM will then mask out the target column and predict the value from related in-context examples.
60 | 61 | For example, you can predict the age of users via 62 | 63 | ``` 64 | PREDICT users.age FOR EACH users.user_id 65 | ``` 66 | 67 | You can impute missing values for all `"numerical"` and `"categorical"` columns. 68 | Currently, you cannot impute missing values for other semantic types such as `"timestamp"` or `"text"`. 69 | For `"numerical"` columns, the predictive query is interpreted as a regression task. 70 | For `"categorical"` columns, the predictive query is interpreted as a multi-class classification task. 71 | For binary classification tasks, you can add **conditions** to the target expression: 72 | 73 | ``` 74 | PREDICT users.age > 40 FOR EACH users.user_id 75 | ``` 76 | 77 | The following boolean operators are supported: 78 | 79 | - `=`: `<column> = <value>` - can be applied to any column type 80 | - `!=`: `<column> != <value>` - can be applied to any column type 81 | - `<`: `<column> < <value>` - can be applied to numerical and temporal columns only 82 | - `<=`: `<column> <= <value>` - can be applied to numerical and temporal columns only 83 | - `>`: `<column> > <value>` - can be applied to numerical and temporal columns only 84 | - `>=`: `<column> >= <value>` - can be applied to numerical and temporal columns only 85 | - `IN`: `<column> IN (<value1>, <value2>, <value3>)` - can be applied to any column type 86 | 87 | The `<value>` needs to be a constant, pre-defined value. 88 | It cannot be modeled as a target expression.
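For instance (using an illustrative `users.country` column), the `IN` operator turns set membership into a binary classification task:

```
PREDICT users.country IN ('US', 'UK', 'DE') FOR EACH users.user_id
```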
89 | When using boolean conditions, the value format must match the column's data type: 90 | 91 | ``` 92 | PREDICT users.location='US' FOR EACH users.user_id 93 | PREDICT users.birthday>1990-01-01 FOR EACH users.user_id 94 | ``` 95 | 96 | Multiple conditions can be logically combined via `AND`, `OR` and `NOT` to form complex predictive queries, e.g.: 97 | 98 | ``` 99 | PREDICT (users.age>40 OR users.location='US') AND (NOT users.gender='male') FOR EACH users.user_id 100 | ``` 101 | 102 | The following logical operations are supported: 103 | 104 | - `AND`: `<condition> AND <condition>` 105 | - `OR`: `<condition> OR <condition>` 106 | - `NOT`: `NOT <condition>` 107 | 108 | Use parentheses to group logical operations and control their order. 109 | 110 | ### Temporal Predictive Queries 111 | 112 | Temporal predictive queries predict some aggregation of values over time (e.g., purchases each customer will make over the next 7 days). 113 | The target table needs to be directly connected to the entity table via a foreign key-primary key relationship. 114 | 115 | An aggregation is defined by an aggregation operator over a **relative** period of time. 116 | You can specify an aggregation operator and the column in the target table representing the value you want to aggregate. 117 | The syntax is as follows: 118 | 119 | ``` 120 | <aggregation>(<table>.<column>, <start_offset>, <end_offset>, <time_unit>) 121 | ``` 122 | 123 | For example: 124 | 125 | ``` 126 | PREDICT SUM(orders.price, 0, 30, days) FOR EACH users.user_id 127 | ``` 128 | 129 | Here, `orders` is a table that is connected to `users` via a foreign key-primary key relationship (`orders.user_id <> users.user_id`). 130 | Within the aggregation function inputs, the `<start_offset>` (`0` in the example) and `<end_offset>` (`30` in the example) parameters refer to the time period you want to aggregate across, relative to a given anchor time. 131 | Both `<start_offset>` and `<end_offset>` should be non-negative, and `<end_offset>` values should be strictly greater than `<start_offset>` values.
132 | As such, the example query can be understood as: "Predict the sum of prices of all the orders a user will make in the next 30 days". 133 | 134 | Note that by default, the anchor time is set to the maximum timestamp present in your relational data, but it can be fully customized in the `predict` and `evaluate` tools. 135 | The `<start_offset>` value does not always have to be `0`. 136 | For example, a `<start_offset>` value of `10` and an `<end_offset>` value of `30` implies that you want to aggregate from 10 days later (excluding the 10th day) to 30 days later (including the 30th day). 137 | 138 | The following values for `<time_unit>` are supported: `seconds`, `minutes`, `hours`, `days`, `weeks`, `months`. 139 | The time unit of the aggregation defaults to `days` if none is specified. 140 | 141 | Similar to static predictive queries, you can add conditions and logical operations to temporal predictive queries to create binary classification tasks: 142 | 143 | ``` 144 | PREDICT SUM(transactions.price, 0, 30, days)=0 FOR EACH users.user_id 145 | ``` 146 | 147 | When using logical operations, it is allowed to aggregate from multiple different target tables: 148 | 149 | ``` 150 | PREDICT COUNT(session.*, 0, 7)>10 OR SUM(transaction.value, 0, 5)>100 FOR EACH user.user_id 151 | ``` 152 | 153 | #### Aggregation Operators 154 | 155 | The following aggregation operators are supported: 156 | 157 | - `SUM`: Calculates the total of values in a numerical column 158 | - `AVG`: Calculates the average of values in a numerical column 159 | - `MIN`: Finds the minimum value in a numerical column 160 | - `MAX`: Finds the maximum value in a numerical column 161 | - `COUNT`: Counts the number of rows/events. 162 | Use `COUNT(<table>.*, ...)` to count all events, or `COUNT(<table>.<column>, ...)` to count non-null values in any column type. 163 | The `COUNT` operator is the only operator where the special `*` syntax is allowed.
164 | - `LIST_DISTINCT`: Returns a distinct list of unique values from a foreign key column (used for recommendations) 165 | 166 | ##### Recommendation Tasks 167 | 168 | The `LIST_DISTINCT` operator is specifically designed for recommendation tasks. 169 | It predicts which foreign key values an entity will interact with in the future. 170 | The basic syntax is: 171 | 172 | ``` 173 | LIST_DISTINCT(<table>.<foreign_key>, <start_offset>, <end_offset>, <time_unit>) RANK TOP <k> FOR EACH ... 174 | ``` 175 | 176 | `LIST_DISTINCT` aggregations must be applied to foreign key columns (not regular columns). 177 | They cannot be combined with conditions or logical operations. 178 | They also must include `RANK TOP <k>` to specify how many recommendations to return, where `<k>` can range from 1 to 20 (maximum 20 recommendations per query). 179 | For example: 180 | 181 | ``` 182 | PREDICT LIST_DISTINCT(orders.item_id, 0, 7, days) RANK TOP 10 FOR EACH users.user_id 183 | ``` 184 | 185 | ##### Handling Inactive Entities in Temporal Aggregations 186 | 187 | In case there is no event for a given entity within the requested time window, the predictive query behaves differently depending on the aggregation operator and whether it has a neutral element. 188 | 189 | **Zero-Valued Aggregations**: For `SUM` and `COUNT` operations, entities with no activity will return zero values and will be included as in-context learning examples. 190 | 191 | **Undefined Aggregations**: For `AVG`, `MIN`, `MAX`, and `LIST_DISTINCT` operations, inactive entities produce undefined results and are excluded from in-context learning. 192 | 193 | **Important:** Make sure that treating inactive entities as zero is desirable. 194 | Always use temporal entity filters with `SUM` and `COUNT` aggregations to prevent learning from irrelevant and outdated examples (see below on how to define temporal entity filters).
195 | 196 | #### Target Filters 197 | 198 | Target filters allow you to further contextualize your predictive query by dropping certain target rows that do not meet a specific condition. 199 | By using a `WHERE` clause within the target expression (valid for all aggregation types), you can drop rows from being aggregated. 200 | For example: 201 | 202 | ``` 203 | PREDICT COUNT(transactions.* WHERE transactions.price > 10, 0, 7, days) FOR EACH users.user_id 204 | ``` 205 | 206 | Note that the `WHERE` clause of target filters needs to be part of the aggregation input. 207 | Target filters must be static and thus can **only** reference columns within the target table being aggregated. 208 | Cross-table references, subqueries, and joins are **not** supported. 209 | Do not make syntax up that is not listed in this document. 210 | 211 | ## Entity Filters 212 | 213 | KumoRFM makes entity-specific predictions based on in-context examples, collected from a historical snapshot of the relational data. 214 | Entity filters can be used to provide more control over how KumoRFM collects in-context examples. 215 | For example, to exclude `users` without recent activity from the context, you can write: 216 | 217 | ``` 218 | PREDICT COUNT(orders.*, 0, 30, days)>0 FOR EACH users.user_id WHERE COUNT(orders.*, -30, 0, days) > 0 219 | ``` 220 | 221 | This limits the in-context examples for predicting churn to active users only. 222 | Note that these filters are **not** applied to the provided entity list `indices` as part of the `predict` tool. 223 | 224 | Both static and temporal filters can be used as entity filters. 225 | If you use temporal entity filters, the `<start_offset>` and `<end_offset>` parameters need to be backward-looking, i.e., `<start_offset> < 0` and `<end_offset> <= 0`. 226 | Still, `<end_offset>` values need to be strictly greater than `<start_offset>` values. 227 | For temporal entity filters, `<start_offset>` can also be defined as `-INF` to include all historical data from the beginning of the dataset.
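For example (with a hypothetical `orders` table), the following restricts in-context examples to users that have placed at least one order at any point in the past:

```
PREDICT COUNT(orders.*, 0, 30, days)>0
FOR EACH users.user_id WHERE COUNT(orders.*, -INF, 0, days)>0
```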
228 | 229 | In order to investigate hypothetical scenarios and to evaluate the impact of your actions or decisions, you can use the `ASSUMING` keyword (instead of `WHERE`) to write forward-looking entity filters. 230 | For example, you may want to investigate how much a user will spend if you give them a certain coupon or notification. 231 | The `ASSUMING` keyword is followed by a future-looking assumption, which will be assumed to be true for the entity IDs you predict for. 232 | 233 | ``` 234 | PREDICT COUNT(orders.*, 0, 30, days)>0 FOR EACH users.user_id ASSUMING COUNT(notifications.*, 0, 7, days)>0 235 | ``` 236 | 237 | Standard SQL operations such as `JOIN`, `SELECT`, `UNION`, `GROUP BY`, and subqueries are not supported in PQL. 238 | Do not make syntax up that is not listed in this document. 239 | 240 | ## Task Types 241 | 242 | The predictive query uniquely determines the underlying machine learning task type based on your query structure and the underlying graph schema. 243 | The following machine learning tasks are supported: 244 | 245 | - **Binary classification:** When your target expression includes a condition that results in true/false 246 | - **Multi-class classification:** When predicting a categorical column with multiple possible values 247 | - **Regression:** When predicting a numerical value 248 | - **Recommendation/temporal link prediction:** When predicting a ranked list of items using `LIST_DISTINCT` 249 | 250 | Note that you don't need to specify the task type. 251 | PQL automatically detects it based on whether you are predicting a condition (binary), categories (multi-class), numbers (regression), or ranked lists (recommendation). 252 | 253 | ## Best Practices 254 | 255 | - Use target filters to filter which events to aggregate. 256 | - Use entity filters to filter which historical examples to learn from. 257 | - Make sure to include temporal entity filters in zero-valued aggregations such as `SUM` or `COUNT`.
258 | - Ensure value formats match column data types in conditions (e.g., `'US'` for strings, `1990-01-01` for dates). 259 | - It might be non-trivial to pick appropriate `<start_offset>` and `<end_offset>` values. 260 | Choose meaningful time windows that align with domain knowledge and account for event frequency. 261 | For example, in an e-commerce dataset, predicting churn based on the next seven days might be unrealistic. 262 | Play around with different time windows and see how it affects the prediction. 263 | - Analyze the label distribution of in-context learning examples in the `predict` and `evaluate` tool logs to understand if your query needs any adjustments, e.g., more or less strict temporal entity filters. 264 | - Take the label distribution of the predictive query into account when analyzing output metrics of the `evaluate` tool. 265 | - When running a predictive query via the `predict` or `evaluate` tools, use `run_mode="fast"` for initial exploration, and reserve `run_mode="best"` for final production queries. 266 | - Choose anchor times that represent realistic prediction scenarios. 267 | Use `anchor_time=None` to make predictions based on the most recent data. 268 | Use `anchor_time='entity'` for static predictions to prevent temporal leakage if entities denote temporal facts. 269 | - Tune the `max_pq_iterations` argument if you see that the model fails to find a sufficient number of in-context examples w.r.t. the `run_mode`, i.e., 1000 for `'fast'`, 5000 for `'normal'`, and 10000 for `'best'`. 270 | 271 | ## Common Mistakes 272 | 273 | - Ensure that `<start_offset>` is always less than `<end_offset>`. 274 | - Ensure that `<end_offset>` is less than or equal to `0` in temporal entity filters. 275 | - PQL doesn't support arithmetic operations. 276 | - PQL is not SQL - use only supported operators and conditions. 277 | - `SUM` and `COUNT` queries without temporal entity filters include inactive/irrelevant examples. 278 | Always use temporal entity filters with `SUM` and `COUNT` to focus on relevant examples.
279 | - Incorrect semantic types may lead to wrong task formulations. 280 | Carefully review and correct semantic types during graph setup. 281 | - `LIST_DISTINCT` only works on foreign key columns. 282 | - Using `anchor_time='entity'` for temporal queries with aggregations is **not** supported. 283 | 284 | ## Examples 285 | 286 | 1. **Recommend movies to users:** 287 | 288 | ``` 289 | PREDICT LIST_DISTINCT(ratings.movie_id, 0, 14, days) RANK TOP 20 290 | FOR EACH users.user_id 291 | ``` 292 | 293 | 1. **Predict inactive users:** 294 | 295 | ``` 296 | PREDICT COUNT(sessions.*, 0, 14)=0 297 | FOR EACH users.user_id WHERE COUNT(sessions.*,-7,0)>0 298 | ``` 299 | 300 | 1. **Predict 5-star reviews:** 301 | 302 | ``` 303 | PREDICT COUNT(ratings.* WHERE ratings.rating = 5, 0, 30)>0 304 | FOR EACH products.product_id 305 | ``` 306 | 307 | 1. **Predict customer churn:** 308 | 309 | ``` 310 | PREDICT COUNT(transactions.price, 0, 3, months)>0 311 | FOR EACH customers.customer_id 312 | WHERE SUM(transactions.price, -2, 0, months)>0.05 313 | ``` 314 | 315 | 1. **Find next best articles:** 316 | 317 | ``` 318 | PREDICT LIST_DISTINCT(transactions.article_id, 0, 90) RANK TOP 20 319 | FOR EACH customers.customer_id 320 | ``` 321 | --------------------------------------------------------------------------------
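The relative time window semantics used throughout this document (offsets measured from an anchor time, with the start boundary excluded and the end boundary included) can be sketched in plain Python; this is a mental model of the documented behavior, not the KumoRFM implementation:

```python
from datetime import date, timedelta

def in_window(anchor: date, event_day: date, start: int, end: int) -> bool:
    """True if the event falls into (anchor + start, anchor + end] days."""
    return anchor + timedelta(days=start) < event_day <= anchor + timedelta(days=end)

# SUM(orders.price, 10, 30, days) relative to an anchor of 2024-01-01:
anchor = date(2024, 1, 1)
assert not in_window(anchor, anchor + timedelta(days=10), 10, 30)  # day 10 excluded
assert in_window(anchor, anchor + timedelta(days=30), 10, 30)      # day 30 included
```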