├── tests ├── __init__.py ├── test_tool_metadata.py ├── test_additional_features.py ├── test_transcript.py ├── README.md ├── test_server_structured.py └── test_clusters.py ├── databricks_mcp ├── __init__.py ├── api │ ├── __init__.py │ ├── libraries.py │ ├── repos.py │ ├── unity_catalog.py │ ├── clusters.py │ ├── sql.py │ ├── dbfs.py │ ├── jobs.py │ └── notebooks.py ├── cli │ ├── __init__.py │ └── commands.py ├── core │ ├── __init__.py │ ├── logging_utils.py │ ├── auth.py │ ├── models.py │ ├── config.py │ └── utils.py ├── server │ ├── __init__.py │ ├── __main__.py │ ├── tool_helpers.py │ ├── app.py │ └── databricks_mcp_server.py ├── __main__.py └── main.py ├── .env.example ├── .gitignore ├── pyproject.toml ├── AGENTS.md ├── README.md └── ARCHITECTURE.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /databricks_mcp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /databricks_mcp/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /databricks_mcp/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /databricks_mcp/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /databricks_mcp/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /databricks_mcp/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main entry point for running the databricks-mcp-server package. 3 | This allows the package to be run with 'python -m databricks_mcp. or 'uv run databricks_mcp'. 4 | """ 5 | 6 | import asyncio 7 | from databricks_mcp.main import main 8 | 9 | if __name__ == "__main__": 10 | asyncio.run(main()) 11 | -------------------------------------------------------------------------------- /databricks_mcp/server/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main entry point for running the server module directly. 3 | This allows the module to be run with 'python -m databricks_mcp.server' or 'uv run databricks_mcp.server'. 
4 | """ 5 | 6 | from databricks_mcp.server.databricks_mcp_server import main 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Databricks Configuration 2 | DATABRICKS_HOST=https://your-workspace.databricks.com 3 | DATABRICKS_TOKEN=dapi_your_token_here 4 | DATABRICKS_WAREHOUSE_ID=sql_warehouse_12345 5 | 6 | # Server Configuration (Optional) 7 | SERVER_HOST=0.0.0.0 8 | SERVER_PORT=8000 9 | DEBUG=false 10 | 11 | # Logging (Optional) 12 | LOG_LEVEL=INFO 13 | -------------------------------------------------------------------------------- /tests/test_tool_metadata.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | import pytest 4 | 5 | if importlib.util.find_spec("mcp") is None: # pragma: no cover - environment guard 6 | pytest.skip("mcp package not available", allow_module_level=True) 7 | 8 | from databricks_mcp.server.databricks_mcp_server import DatabricksMCPServer 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_list_tools_has_schemas(): 13 | server = DatabricksMCPServer() 14 | tools = await server.list_tools() 15 | assert any(tool.name == "list_clusters" for tool in tools) 16 | list_clusters_tool = next(tool for tool in tools if tool.name == "list_clusters") 17 | assert "properties" in list_clusters_tool.inputSchema 18 | if list_clusters_tool.outputSchema is not None: 19 | assert "type" in list_clusters_tool.outputSchema 20 | -------------------------------------------------------------------------------- /tests/test_additional_features.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock 3 | 4 | from databricks_mcp.api import libraries, repos, unity_catalog 5 | 6 | 7 | @pytest.mark.asyncio 8 | async def test_install_library(): 9 | libraries.install_library = AsyncMock(return_value={}) 10 | resp = await libraries.install_library("cluster", []) 11 | assert resp == {} 12 | libraries.install_library.assert_called_once() 13 | 14 | 15 | @pytest.mark.asyncio 16 | async def test_create_repo(): 17 | repos.create_repo = AsyncMock(return_value={"id": 1}) 18 | resp = await repos.create_repo("https://example.com", "git") 19 | assert resp["id"] == 1 20 | repos.create_repo.assert_called_once() 21 | 22 | 23 | @pytest.mark.asyncio 24 | async def test_list_catalogs(): 25 | unity_catalog.list_catalogs = AsyncMock(return_value={"catalogs": []}) 26 | resp = await unity_catalog.list_catalogs() 27 | assert resp["catalogs"] == [] 28 | unity_catalog.list_catalogs.assert_called_once() 29 | 30 | -------------------------------------------------------------------------------- /databricks_mcp/api/libraries.py: -------------------------------------------------------------------------------- 1 | """API for managing cluster libraries.""" 2 | 3 | import logging 4 | from typing import Any, Dict, List 5 | 6 | from databricks_mcp.core.utils import make_api_request, DatabricksAPIError 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | async def install_library(cluster_id: str, libraries: List[Dict[str, Any]]) -> Dict[str, Any]: 12 | """Install libraries on a cluster.""" 13 | logger.info(f"Installing libraries on cluster {cluster_id}") 14 | payload = {"cluster_id": cluster_id, "libraries": libraries} 15 | return await make_api_request("POST", "/api/2.0/libraries/install", 
data=payload) 16 | 17 | 18 | async def uninstall_library(cluster_id: str, libraries: List[Dict[str, Any]]) -> Dict[str, Any]: 19 | """Uninstall libraries from a cluster.""" 20 | logger.info(f"Uninstalling libraries on cluster {cluster_id}") 21 | payload = {"cluster_id": cluster_id, "libraries": libraries} 22 | return await make_api_request("POST", "/api/2.0/libraries/uninstall", data=payload) 23 | 24 | 25 | async def list_cluster_libraries(cluster_id: str) -> Dict[str, Any]: 26 | """List library status for a cluster.""" 27 | logger.info(f"Listing libraries for cluster {cluster_id}") 28 | return await make_api_request("GET", "/api/2.0/libraries/cluster-status", params={"cluster_id": cluster_id}) 29 | -------------------------------------------------------------------------------- /tests/test_transcript.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | 4 | import pytest 5 | 6 | if importlib.util.find_spec("mcp") is None: # pragma: no cover 7 | pytest.skip("mcp package not available", allow_module_level=True) 8 | 9 | from databricks_mcp.server.databricks_mcp_server import DatabricksMCPServer 10 | import databricks_mcp.api.clusters as clusters_api 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_list_clusters_transcript(monkeypatch): 15 | async def fake_list_clusters(): 16 | return {"clusters": [{"cluster_id": "transcript", "state": "RUNNING"}]} 17 | 18 | monkeypatch.setattr(clusters_api, "list_clusters", fake_list_clusters) 19 | 20 | server = DatabricksMCPServer() 21 | result = await server.call_tool("list_clusters", {}) 22 | 23 | transcript = { 24 | "request": {"name": "list_clusters", "arguments": {}}, 25 | "response": { 26 | "isError": result.isError, 27 | "structuredContent": result.structuredContent, 28 | }, 29 | } 30 | 31 | expected = { 32 | "request": {"name": "list_clusters", "arguments": {}}, 33 | "response": { 34 | "isError": False, 35 | "structuredContent": {"clusters": [{"cluster_id": "transcript", "state": "RUNNING"}]}, 36 | }, 37 | } 38 | 39 | assert transcript == expected 40 | # Ensure transcript is JSON-serializable for golden comparisons. 41 | json.dumps(transcript) 42 | -------------------------------------------------------------------------------- /databricks_mcp/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main entry point for the Databricks MCP server. 
3 | """ 4 | 5 | import argparse 6 | import asyncio 7 | import logging 8 | from typing import Optional 9 | 10 | from databricks_mcp.core.config import settings 11 | from databricks_mcp.core.logging_utils import configure_logging 12 | from databricks_mcp.server.databricks_mcp_server import DatabricksMCPServer 13 | 14 | 15 | async def start_mcp_server() -> None: 16 | """Start the MCP server via the FastMCP stdio helper.""" 17 | server = DatabricksMCPServer() 18 | await server.run_stdio_async() 19 | 20 | 21 | def setup_logging(log_level: Optional[str] = None) -> None: 22 | """Set up centralized logging before any server work begins.""" 23 | level = (log_level or settings.LOG_LEVEL).upper() 24 | configure_logging(level=level, log_file="databricks_mcp.log") 25 | 26 | 27 | async def main(log_level: Optional[str] = None) -> None: 28 | """Main asynchronous entry point.""" 29 | setup_logging(log_level) 30 | logger = logging.getLogger(__name__) 31 | logger.info("Starting Databricks MCP server v%s", settings.VERSION) 32 | logger.info("Databricks host resolved") 33 | await start_mcp_server() 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="Databricks MCP Server") 38 | parser.add_argument( 39 | "--log-level", 40 | choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], 41 | help="Override default log level", 42 | ) 43 | args = parser.parse_args() 44 | asyncio.run(main(args.log_level)) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python virtual environments 2 | venv/ 3 | .venv/ 4 | env/ 5 | ENV/ 6 | 7 | # Python bytecode 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # Distribution / packaging 13 | dist/ 14 | build/ 15 | *.egg-info/ 16 | 17 | # Local development settings 18 | .env 19 | .env.local 20 | .pypirc 21 | 22 | # IDE settings 23 | .idea/ 24 | .vscode/ 25 | *.swp 26 | *.swo 27 | 28 | # OS specific files 29 | .DS_Store 30 | Thumbs.db 31 | 32 | # Logs 33 | *.log 34 | logs/ 35 | IMPLEMENTATION_COMPLETE.md 36 | 37 | # Temporary files 38 | tmp/ 39 | temp/ 40 | 41 | # uv package manager files 42 | .uv/ 43 | uv.lock 44 | 45 | # Databricks-specific 46 | *.dbfs 47 | 48 | # C extensions 49 | *.so 50 | 51 | # Distribution / packaging 52 | .Python 53 | develop-eggs/ 54 | downloads/ 55 | eggs/ 56 | .eggs/ 57 | lib/ 58 | lib64/ 59 | parts/ 60 | sdist/ 61 | var/ 62 | wheels/ 63 | .installed.cfg 64 | *.egg 65 | MANIFEST 66 | 67 | # PyInstaller 68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .coverage 79 | .coverage.* 80 | .cache 81 | nosetests.xml 82 | coverage.xml 83 | *.cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | test_example_file.py 87 | test_new_features.py 88 | debug_api.py 89 | 90 | # Environments 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # IDEs and editors 95 | *.swp 96 | *.swo 97 | *~ 98 | 99 | # OS generated files 100 | .DS_Store? 
101 | ._* 102 | .Spotlight-V100 103 | .Trashes 104 | ehthumbs.db 105 | Thumbs.db 106 | 107 | # Log files 108 | *.log 109 | -------------------------------------------------------------------------------- /databricks_mcp/api/repos.py: -------------------------------------------------------------------------------- 1 | """API for Databricks Repos.""" 2 | 3 | import logging 4 | from typing import Any, Dict, Optional 5 | 6 | from databricks_mcp.core.utils import DatabricksAPIError, make_api_request 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | async def create_repo(url: str, provider: str, branch: Optional[str] = None, path: Optional[str] = None) -> Dict[str, Any]: 12 | """Create or clone a repo.""" 13 | payload = {"url": url, "provider": provider} 14 | if branch: 15 | payload["branch"] = branch 16 | if path: 17 | payload["path"] = path 18 | return await make_api_request("POST", "/api/2.0/repos", data=payload) 19 | 20 | 21 | async def update_repo(repo_id: int, branch: Optional[str] = None, tag: Optional[str] = None) -> Dict[str, Any]: 22 | """Update repo branch or pull latest.""" 23 | payload: Dict[str, Any] = {} 24 | if branch: 25 | payload["branch"] = branch 26 | if tag: 27 | payload["tag"] = tag 28 | return await make_api_request("PATCH", f"/api/2.0/repos/{repo_id}", data=payload) 29 | 30 | 31 | async def list_repos(path_prefix: Optional[str] = None) -> Dict[str, Any]: 32 | """List repos, optionally filtered by path prefix.""" 33 | params = {"path_prefix": path_prefix} if path_prefix else None 34 | return await make_api_request("GET", "/api/2.0/repos", params=params) 35 | 36 | 37 | async def pull_repo(repo_id: int) -> Dict[str, Any]: 38 | """Pull the latest code for a repository. 39 | 40 | Args: 41 | repo_id: ID of the repository to pull 42 | 43 | Returns: 44 | Response from the Databricks API 45 | 46 | Raises: 47 | DatabricksAPIError: If the API request fails 48 | """ 49 | logger.info(f"Pulling repo {repo_id}") 50 | endpoint = f"/api/2.0/repos/{repo_id}/pull" 51 | return await make_api_request("POST", endpoint) 52 | -------------------------------------------------------------------------------- /databricks_mcp/api/unity_catalog.py: -------------------------------------------------------------------------------- 1 | """API for Unity Catalog.""" 2 | 3 | import logging 4 | from typing import Any, Dict, Optional 5 | 6 | from databricks_mcp.core.utils import make_api_request 7 | from databricks_mcp.api import sql 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | async def list_catalogs() -> Dict[str, Any]: 13 | logger.info("Listing catalogs") 14 | return await make_api_request("GET", "/api/2.1/unity-catalog/catalogs") 15 | 16 | 17 | async def create_catalog(name: str, comment: Optional[str] = None) -> Dict[str, Any]: 18 | payload = {"name": name} 19 | if comment: 20 | payload["comment"] = comment 21 | return await make_api_request("POST", "/api/2.1/unity-catalog/catalogs", data=payload) 22 | 23 | 24 | async def list_schemas(catalog_name: str) -> Dict[str, Any]: 25 | return await make_api_request("GET", "/api/2.1/unity-catalog/schemas", params={"catalog_name": catalog_name}) 26 | 27 | 28 | async def create_schema(catalog_name: str, name: str, comment: Optional[str] = None) -> Dict[str, Any]: 29 | payload = {"catalog_name": catalog_name, "name": name} 30 | if comment: 31 | payload["comment"] = comment 32 | return await make_api_request("POST", "/api/2.1/unity-catalog/schemas", data=payload) 33 | 34 | 35 | async def list_tables(catalog_name: str, schema_name: str) -> 
Dict[str, Any]: 36 | params = {"catalog_name": catalog_name, "schema_name": schema_name} 37 | return await make_api_request("GET", "/api/2.1/unity-catalog/tables", params=params) 38 | 39 | 40 | async def create_table(warehouse_id: str, statement: str) -> Dict[str, Any]: 41 | """Execute a CREATE TABLE statement using the SQL API.""" 42 | return await sql.execute_statement(statement, warehouse_id=warehouse_id) 43 | 44 | 45 | async def get_table_lineage(full_name: str) -> Dict[str, Any]: 46 | endpoint = f"/api/2.1/unity-catalog/lineage-tracking/table-lineage/{full_name}" 47 | return await make_api_request("GET", endpoint) 48 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests for Databricks MCP Server 2 | 3 | This directory contains automated tests for the Databricks MCP server. The 4 | suite is written in `pytest` and relies on `uv` for dependency and virtualenv 5 | management. 6 | 7 | ## Layout 8 | 9 | - `test_additional_features.py` - smoke tests for auxiliary Databricks features 10 | (repos, workspace listings, etc.). 11 | - `test_clusters.py` - CRUD and lifecycle coverage for cluster-oriented tools. 12 | - `test_tool_metadata.py` - asserts that every registered tool exposes the 13 | expected description, schema metadata, and argument signatures. 14 | - `test_server_structured.py` - validates that tool responses populate 15 | `structuredContent`, include human-readable text summaries, and surface 16 | resource links for large artifacts. 17 | - `test_transcript.py` - golden transcript of a `tools/list` and representative 18 | `tools/call` interaction to guard against protocol regressions. 19 | 20 | All tests are async-friendly and do not require live Databricks credentials; 21 | HTTP calls are mocked. 22 | 23 | ## Running the Test Suite 24 | 25 | From the repository root: 26 | 27 | ```bash 28 | uv run pytest 29 | ``` 30 | 31 | The command above automatically creates an ephemeral virtual environment (if 32 | needed), installs the `dev` extras, and executes every test module in this 33 | directory. The suite completes in under a second on a typical laptop. 34 | 35 | ## Adding New Tests 36 | 37 | 1. Create a new `test_*.py` file in this directory and use `pytest` naming 38 | conventions for functions/classes. 39 | 2. Add any reusable fixtures to a new `conftest.py` (or the specific module 40 | under test) so they are automatically discoverable by `pytest`. 41 | 3. Keep protocol-level assertions (structured content shape, resource links, 42 | progress notifications) close to the server modules they cover. 43 | 4. Run `uv run pytest` locally before opening a pull request and update this 44 | README if you add significant new suites or fixtures. 45 | -------------------------------------------------------------------------------- /databricks_mcp/core/logging_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shared logging configuration utilities for the Databricks MCP server. 
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import json 8 | import logging 9 | import sys 10 | from typing import Optional 11 | 12 | 13 | class JsonFormatter(logging.Formatter): 14 | """Format log records as single-line JSON objects.""" 15 | 16 | def format(self, record: logging.LogRecord) -> str: # noqa: D401 - inherited docstring covers behavior 17 | payload = { 18 | "name": record.name, 19 | "level": record.levelname, 20 | "message": record.getMessage(), 21 | "timestamp": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S%z"), 22 | } 23 | if record.exc_info: 24 | payload["exc_info"] = self.formatException(record.exc_info) 25 | if record.stack_info: 26 | payload["stack_info"] = record.stack_info 27 | return json.dumps(payload, ensure_ascii=False) 28 | 29 | 30 | def configure_logging(level: str = "INFO", log_file: Optional[str] = None) -> None: 31 | """ 32 | Configure application-wide logging, emitting JSON lines to stderr and an optional file. 33 | 34 | Args: 35 | level: Root log level name. 36 | log_file: Optional path for a synchronized file handler. 37 | """ 38 | root = logging.getLogger() 39 | if getattr(root, "_mcp_configured", False): # type: ignore[attr-defined] 40 | root.setLevel(level) 41 | return 42 | 43 | root.setLevel(level) 44 | handler = logging.StreamHandler(sys.stderr) 45 | handler.setFormatter(JsonFormatter()) 46 | root.handlers.clear() 47 | root.addHandler(handler) 48 | 49 | if log_file: 50 | file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8") 51 | file_handler.setFormatter(JsonFormatter()) 52 | root.addHandler(file_handler) 53 | 54 | # Mark configuration to avoid duplicate handlers on repeated calls. 55 | setattr(root, "_mcp_configured", True) # type: ignore[attr-defined] 56 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "databricks-mcp-server" 7 | version = "0.4.4" 8 | description = "A Model Context Protocol (MCP) server for Databricks" 9 | authors = [ 10 | {name = "Olivier Debeuf De Rijcker", email = "olivier@markov.bot"} 11 | ] 12 | requires-python = ">=3.10" 13 | readme = "README.md" 14 | license = {text = "MIT"} 15 | keywords = ["databricks", "mcp", "model-context-protocol", "llm", "ai", "cursor"] 16 | homepage = "https://github.com/markov-kernel/databricks-mcp" 17 | repository = "https://github.com/markov-kernel/databricks-mcp" 18 | classifiers = [ 19 | "Development Status :: 4 - Beta", 20 | "Intended Audience :: Developers", 21 | "Topic :: Software Development :: Libraries :: Python Modules", 22 | "Topic :: Internet :: WWW/HTTP :: HTTP Servers", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "License :: OSI Approved :: MIT License", 28 | "Operating System :: OS Independent", 29 | ] 30 | dependencies = [ 31 | "mcp[cli]>=1.2.0", 32 | "httpx", 33 | "databricks-sdk", 34 | ] 35 | 36 | [project.optional-dependencies] 37 | cli = [ 38 | "click", 39 | ] 40 | dev = [ 41 | "black", 42 | "pylint", 43 | "pytest", 44 | "pytest-asyncio", 45 | "fastapi", 46 | "anyio", 47 | ] 48 | 49 | [project.scripts] 50 | databricks-mcp-server = "databricks_mcp.server.databricks_mcp_server:main" 51 | databricks-mcp = "databricks_mcp.cli.commands:main" 52 | 53 | 
[tool.hatch.build.targets.wheel] 54 | packages = ["databricks_mcp"] 55 | 56 | [tool.pytest.ini_options] 57 | asyncio_mode = "auto" 58 | asyncio_default_fixture_loop_scope = "function" 59 | testpaths = ["tests"] 60 | python_files = ["test_*.py"] 61 | python_functions = ["test_*"] 62 | addopts = "-v --tb=short" 63 | 64 | [dependency-groups] 65 | dev = [ 66 | "build>=1.2.2.post1", 67 | "twine>=6.1.0", 68 | ] 69 | -------------------------------------------------------------------------------- /databricks_mcp/server/tool_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper utilities for building consistent MCP tool responses. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Any, Dict, Optional, Sequence 8 | 9 | from mcp.types import CallToolResult, TextContent 10 | 11 | 12 | def _coerce_structured(data: Any) -> Dict[str, Any]: 13 | """Ensure structured payload is JSON-serializable as a dict.""" 14 | if isinstance(data, dict): 15 | return data 16 | if hasattr(data, "model_dump"): 17 | return data.model_dump(mode="json") # type: ignore[attr-defined] 18 | return {"result": data} 19 | 20 | 21 | def success_result( 22 | summary: str, 23 | data: Any, 24 | *, 25 | meta: Optional[Dict[str, Any]] = None, 26 | resource_links: Optional[Sequence[Dict[str, Any]]] = None, 27 | ) -> CallToolResult: 28 | """ 29 | Build a standardized success payload with structured content. 30 | 31 | Args: 32 | summary: Short human-readable description. 33 | data: Structured payload (or object convertible to dict via `.model_dump()`). 34 | """ 35 | result = CallToolResult( 36 | content=[TextContent(type="text", text=summary)], 37 | structuredContent=_coerce_structured(data), 38 | isError=False, 39 | ) 40 | if meta: 41 | result.meta = meta 42 | if resource_links: 43 | # Append resource_link content blocks (per MCP spec) 44 | for link in resource_links: 45 | result.content.append( 46 | { # type: ignore[dict-item] 47 | "type": "resource_link", 48 | **link, 49 | } 50 | ) 51 | return result 52 | 53 | 54 | def error_result(message: str, *, details: Optional[Any] = None, status_code: Optional[int] = None) -> CallToolResult: 55 | """ 56 | Build a standardized error payload. 57 | 58 | Args: 59 | message: Human-readable error summary. 60 | details: Optional structured detail block. 61 | status_code: Optional HTTP status code from upstream. 62 | """ 63 | payload: Dict[str, Any] = {"message": message} 64 | if status_code is not None: 65 | payload["status_code"] = status_code 66 | if details is not None: 67 | payload["details"] = details if isinstance(details, dict) else {"raw": details} 68 | 69 | return CallToolResult( 70 | content=[TextContent(type="text", text=message)], 71 | structuredContent=payload, 72 | isError=True, 73 | ) 74 | -------------------------------------------------------------------------------- /databricks_mcp/server/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | FastAPI application for Databricks API. 3 | 4 | This is a stub module that provides compatibility with existing tests. 5 | The actual implementation uses the MCP protocol directly. 6 | """ 7 | 8 | from fastapi import FastAPI 9 | 10 | from databricks_mcp.api import clusters, dbfs, jobs, notebooks, sql 11 | from databricks_mcp.core.config import settings 12 | 13 | 14 | def create_app() -> FastAPI: 15 | """ 16 | Create and configure the FastAPI application. 
17 | 18 | Returns: 19 | FastAPI: The configured FastAPI application 20 | """ 21 | app = FastAPI( 22 | title="Databricks API", 23 | description="API for interacting with Databricks services", 24 | version=settings.VERSION, 25 | ) 26 | 27 | # Add routes 28 | @app.get("/api/2.0/clusters/list") 29 | async def list_clusters(): 30 | """List all clusters.""" 31 | result = await clusters.list_clusters() 32 | return result 33 | 34 | @app.get("/api/2.0/clusters/get/{cluster_id}") 35 | async def get_cluster(cluster_id: str): 36 | """Get cluster details.""" 37 | result = await clusters.get_cluster(cluster_id) 38 | return result 39 | 40 | @app.post("/api/2.0/clusters/create") 41 | async def create_cluster(request_data: dict): 42 | """Create a new cluster.""" 43 | result = await clusters.create_cluster(request_data) 44 | return result 45 | 46 | @app.post("/api/2.0/clusters/delete") 47 | async def terminate_cluster(request_data: dict): 48 | """Terminate a cluster.""" 49 | result = await clusters.terminate_cluster(request_data.get("cluster_id")) 50 | return result 51 | 52 | @app.post("/api/2.0/clusters/start") 53 | async def start_cluster(request_data: dict): 54 | """Start a cluster.""" 55 | result = await clusters.start_cluster(request_data.get("cluster_id")) 56 | return result 57 | 58 | @app.post("/api/2.0/clusters/resize") 59 | async def resize_cluster(request_data: dict): 60 | """Resize a cluster.""" 61 | result = await clusters.resize_cluster( 62 | request_data.get("cluster_id"), 63 | request_data.get("num_workers") 64 | ) 65 | return result 66 | 67 | @app.post("/api/2.0/clusters/restart") 68 | async def restart_cluster(request_data: dict): 69 | """Restart a cluster.""" 70 | result = await clusters.restart_cluster(request_data.get("cluster_id")) 71 | return result 72 | 73 | return app -------------------------------------------------------------------------------- /databricks_mcp/core/auth.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authentication functionality for the Databricks MCP server. 3 | """ 4 | 5 | import logging 6 | from typing import Dict, Optional 7 | 8 | from fastapi import Depends, HTTPException, Security, status 9 | from fastapi.security import APIKeyHeader 10 | 11 | from databricks_mcp.core.config import settings 12 | 13 | # Configure logging 14 | logger = logging.getLogger(__name__) 15 | 16 | # API key header scheme 17 | API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False) 18 | 19 | 20 | async def validate_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)) -> Dict[str, str]: 21 | """ 22 | Validate API key for protected endpoints. 
23 | 24 | Args: 25 | api_key: The API key from the request header 26 | 27 | Returns: 28 | Dictionary with authentication info 29 | 30 | Raises: 31 | HTTPException: If authentication fails 32 | """ 33 | # For now, we're using a simple token comparison 34 | # In a production environment, you might want to use a database or more secure method 35 | 36 | # Check if API key is required in the current environment 37 | if not settings.DEBUG: 38 | if not api_key: 39 | logger.warning("Authentication failed: Missing API key") 40 | raise HTTPException( 41 | status_code=status.HTTP_401_UNAUTHORIZED, 42 | detail="Missing API key", 43 | headers={"WWW-Authenticate": "ApiKey"}, 44 | ) 45 | 46 | # In a real scenario, you would validate against a secure storage 47 | # For demo purposes, we'll just check against an environment variable 48 | # NEVER do this in production - use a proper authentication system! 49 | valid_keys = ["test-api-key"] # Replace with actual implementation 50 | 51 | if api_key not in valid_keys: 52 | logger.warning("Authentication failed: Invalid API key") 53 | raise HTTPException( 54 | status_code=status.HTTP_401_UNAUTHORIZED, 55 | detail="Invalid API key", 56 | headers={"WWW-Authenticate": "ApiKey"}, 57 | ) 58 | 59 | # Return authentication info 60 | return {"authenticated": True} 61 | 62 | 63 | def get_current_user(): 64 | """ 65 | Dependency to get current user. 66 | 67 | For future implementation of user-specific functionality. 68 | Currently returns a placeholder. 69 | """ 70 | # This would be expanded in a real application with actual user information 71 | return {"username": "admin"} -------------------------------------------------------------------------------- /tests/test_server_structured.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import importlib.util 3 | 4 | import pytest 5 | 6 | if importlib.util.find_spec("mcp") is None: # pragma: no cover - environment guard 7 | pytest.skip("mcp package not available", allow_module_level=True) 8 | 9 | from databricks_mcp.server.databricks_mcp_server import DatabricksMCPServer 10 | import databricks_mcp.api.clusters as clusters_api 11 | import databricks_mcp.api.notebooks as notebooks_api 12 | from databricks_mcp.core.utils import DatabricksAPIError 13 | 14 | 15 | @pytest.mark.asyncio 16 | async def test_list_clusters_structured(monkeypatch): 17 | async def fake_list_clusters(): 18 | return {"clusters": [{"cluster_id": "test", "state": "RUNNING"}]} 19 | 20 | monkeypatch.setattr(clusters_api, "list_clusters", fake_list_clusters) 21 | 22 | server = DatabricksMCPServer() 23 | result = await server.call_tool("list_clusters", {}) 24 | 25 | assert result.isError is False 26 | data = result.structuredContent 27 | assert data == {"clusters": [{"cluster_id": "test", "state": "RUNNING"}]} 28 | assert "_request_id" in (result.meta or {}) 29 | assert result.content and "Found 1 clusters" in result.content[0].text 30 | 31 | 32 | @pytest.mark.asyncio 33 | async def test_export_notebook_returns_resource_link(monkeypatch): 34 | payload = { 35 | "content": base64.b64encode(b"print('hello world')").decode("utf-8"), 36 | "format": "SOURCE", 37 | } 38 | 39 | async def fake_export_notebook(path: str, format: str = "SOURCE"): 40 | return payload 41 | 42 | monkeypatch.setattr(notebooks_api, "export_notebook", fake_export_notebook) 43 | 44 | server = DatabricksMCPServer() 45 | result = await server.call_tool("export_notebook", {"path": "/Users/demo", "format": "SOURCE"}) 46 | 47 | assert 
result.isError is False 48 | assert any(block.get("type") == "resource_link" for block in result.content if isinstance(block, dict)) 49 | data = result.structuredContent 50 | assert data["content"] == payload["content"] 51 | assert data.get("resource_uri", "").startswith("databricks://exports/") 52 | 53 | 54 | @pytest.mark.asyncio 55 | async def test_error_wrapped(monkeypatch): 56 | async def fake_list_clusters(): 57 | raise DatabricksAPIError("Boom", status_code=500, response={"error": "boom"}) 58 | 59 | monkeypatch.setattr(clusters_api, "list_clusters", fake_list_clusters) 60 | 61 | server = DatabricksMCPServer() 62 | result = await server.call_tool("list_clusters", {}) 63 | 64 | assert result.isError is True 65 | data = result.structuredContent 66 | assert data["message"].startswith("list_clusters failed") 67 | assert data["status_code"] == 500 68 | -------------------------------------------------------------------------------- /databricks_mcp/core/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, List, Optional 4 | 5 | from pydantic import BaseModel 6 | 7 | 8 | class JobTask(BaseModel): 9 | """Represents a Databricks job task.""" 10 | 11 | task_key: str 12 | notebook_task: Optional[Dict[str, Any]] = None 13 | existing_cluster_id: Optional[str] = None 14 | new_cluster: Optional[Dict[str, Any]] = None 15 | 16 | 17 | class Job(BaseModel): 18 | """Simplified Databricks Job model used for job creation.""" 19 | 20 | name: str 21 | tasks: List[JobTask] 22 | existing_cluster_id: Optional[str] = None 23 | new_cluster: Optional[Dict[str, Any]] = None 24 | 25 | 26 | class Run(BaseModel): 27 | """Represents a Databricks job run.""" 28 | 29 | run_id: int 30 | job_id: int 31 | state: Dict[str, Any] 32 | 33 | 34 | class WorkspaceObject(BaseModel): 35 | """Workspace object such as a notebook or directory.""" 36 | 37 | path: str 38 | object_type: str 39 | language: Optional[str] = None 40 | 41 | 42 | class DbfsItem(BaseModel): 43 | """File or directory within DBFS.""" 44 | 45 | path: str 46 | is_dir: bool 47 | file_size: Optional[int] = None 48 | 49 | 50 | class ClusterConfig(BaseModel): 51 | """Subset of Databricks cluster configuration supported by the MCP tools.""" 52 | 53 | cluster_name: str 54 | spark_version: str 55 | node_type_id: str 56 | num_workers: Optional[int] = None 57 | autotermination_minutes: Optional[int] = None 58 | autoscale: Optional[Dict[str, int]] = None 59 | spark_conf: Optional[Dict[str, Any]] = None 60 | custom_tags: Optional[Dict[str, str]] = None 61 | 62 | 63 | class Library(BaseModel): 64 | """Specification of a library to install on a cluster.""" 65 | 66 | pypi: Optional[Dict[str, str]] = None 67 | maven: Optional[Dict[str, Any]] = None 68 | egg: Optional[str] = None 69 | whl: Optional[str] = None 70 | 71 | 72 | class Repo(BaseModel): 73 | """Represents a Databricks repo.""" 74 | 75 | id: Optional[int] = None 76 | url: str 77 | provider: str 78 | branch: Optional[str] = None 79 | path: Optional[str] = None 80 | 81 | 82 | class Catalog(BaseModel): 83 | """Unity Catalog catalog.""" 84 | 85 | name: str 86 | comment: Optional[str] = None 87 | 88 | 89 | class Schema(BaseModel): 90 | """Unity Catalog schema.""" 91 | 92 | name: str 93 | catalog_name: str 94 | comment: Optional[str] = None 95 | 96 | 97 | class Table(BaseModel): 98 | """Unity Catalog table.""" 99 | 100 | name: str 101 | schema_name: str 102 | catalog_name: str 103 | comment: Optional[str] = None 
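# Illustrative usage sketch (field values below are placeholders): these models
# mirror the JSON payloads accepted by the API helpers; for example, a
# ClusterConfig can be serialized into the dict expected by
# databricks_mcp.api.clusters.create_cluster:
#
#     config = ClusterConfig(
#         cluster_name="demo-cluster",
#         spark_version="13.3.x-scala2.12",
#         node_type_id="i3.xlarge",
#         num_workers=2,
#     )
#     payload = config.model_dump(exclude_none=True)
#     # await clusters.create_cluster(payload)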
104 | -------------------------------------------------------------------------------- /databricks_mcp/core/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration settings for the Databricks MCP server. 3 | """ 4 | 5 | import os 6 | from typing import Any, Dict, Optional 7 | 8 | # Import dotenv if available, but don't require it. 9 | # Only load dotenv if not running via Cursor MCP (which provides env vars directly). 10 | if not os.environ.get("RUNNING_VIA_CURSOR_MCP"): 11 | try: 12 | from dotenv import load_dotenv 13 | 14 | load_dotenv() 15 | except ImportError: 16 | pass 17 | 18 | from pydantic import field_validator 19 | from pydantic_settings import BaseSettings, SettingsConfigDict 20 | 21 | # Version 22 | VERSION = "0.4.4" 23 | 24 | 25 | class Settings(BaseSettings): 26 | """Base settings for the application.""" 27 | model_config = SettingsConfigDict(env_file=".env", case_sensitive=True, extra="ignore") 28 | 29 | # Databricks API configuration 30 | DATABRICKS_HOST: str = os.environ.get("DATABRICKS_HOST", "https://example.databricks.net") 31 | DATABRICKS_TOKEN: str = os.environ.get("DATABRICKS_TOKEN", "dapi_token_placeholder") 32 | DATABRICKS_WAREHOUSE_ID: Optional[str] = os.environ.get("DATABRICKS_WAREHOUSE_ID") 33 | 34 | # Server configuration 35 | SERVER_HOST: str = os.environ.get("SERVER_HOST", "0.0.0.0") 36 | SERVER_PORT: int = int(os.environ.get("SERVER_PORT", "8000")) 37 | DEBUG: bool = os.environ.get("DEBUG", "False").lower() == "true" 38 | 39 | # Logging 40 | LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO") 41 | 42 | # Runtime controls 43 | TOOL_TIMEOUT_SECONDS: int = int(os.environ.get("TOOL_TIMEOUT_SECONDS", "300")) 44 | MAX_CONCURRENT_REQUESTS: int = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "8")) 45 | 46 | # HTTP / retry behavior 47 | HTTP_TIMEOUT_SECONDS: float = float(os.environ.get("HTTP_TIMEOUT_SECONDS", "60")) 48 | API_MAX_RETRIES: int = int(os.environ.get("API_MAX_RETRIES", "3")) 49 | API_RETRY_BACKOFF_SECONDS: float = float(os.environ.get("API_RETRY_BACKOFF_SECONDS", "0.5")) 50 | 51 | # Version 52 | VERSION: str = VERSION 53 | 54 | @field_validator("DATABRICKS_HOST") 55 | def validate_databricks_host(cls, v: str) -> str: 56 | """Validate Databricks host URL.""" 57 | if not v.startswith(("https://", "http://")): 58 | raise ValueError("DATABRICKS_HOST must start with http:// or https://") 59 | return v 60 | 61 | @field_validator("DATABRICKS_WAREHOUSE_ID") 62 | def validate_warehouse_id(cls, v: Optional[str]) -> Optional[str]: 63 | """Validate warehouse ID format if provided.""" 64 | if v and len(v) < 10: 65 | import logging 66 | logger = logging.getLogger(__name__) 67 | logger.warning(f"Warehouse ID '{v}' seems unusually short") 68 | return v 69 | 70 | # Create global settings instance 71 | settings = Settings() 72 | 73 | 74 | def get_api_headers() -> Dict[str, str]: 75 | """Get headers for Databricks API requests.""" 76 | return { 77 | "Authorization": f"Bearer {settings.DATABRICKS_TOKEN}", 78 | "Content-Type": "application/json", 79 | } 80 | 81 | 82 | def get_databricks_api_url(endpoint: str) -> str: 83 | """ 84 | Construct the full Databricks API URL. 
85 | 86 | Args: 87 | endpoint: The API endpoint path, e.g., "/api/2.0/clusters/list" 88 | 89 | Returns: 90 | Full URL to the Databricks API endpoint 91 | """ 92 | # Ensure endpoint starts with a slash 93 | if not endpoint.startswith("/"): 94 | endpoint = f"/{endpoint}" 95 | 96 | # Remove trailing slash from host if present 97 | host = settings.DATABRICKS_HOST.rstrip("/") 98 | 99 | return f"{host}{endpoint}" 100 | -------------------------------------------------------------------------------- /databricks_mcp/api/clusters.py: -------------------------------------------------------------------------------- 1 | """ 2 | API for managing Databricks clusters. 3 | """ 4 | 5 | import logging 6 | from typing import Any, Dict, List, Optional 7 | 8 | from databricks_mcp.core.utils import DatabricksAPIError, make_api_request 9 | 10 | # Configure logging 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | async def create_cluster(cluster_config: Dict[str, Any]) -> Dict[str, Any]: 15 | """ 16 | Create a new Databricks cluster. 17 | 18 | Args: 19 | cluster_config: Cluster configuration 20 | 21 | Returns: 22 | Response containing the cluster ID 23 | 24 | Raises: 25 | DatabricksAPIError: If the API request fails 26 | """ 27 | logger.info("Creating new cluster") 28 | return await make_api_request("POST", "/api/2.0/clusters/create", data=cluster_config) 29 | 30 | 31 | async def terminate_cluster(cluster_id: str) -> Dict[str, Any]: 32 | """ 33 | Terminate a Databricks cluster. 34 | 35 | Args: 36 | cluster_id: ID of the cluster to terminate 37 | 38 | Returns: 39 | Empty response on success 40 | 41 | Raises: 42 | DatabricksAPIError: If the API request fails 43 | """ 44 | logger.info(f"Terminating cluster: {cluster_id}") 45 | return await make_api_request("POST", "/api/2.0/clusters/delete", data={"cluster_id": cluster_id}) 46 | 47 | 48 | async def list_clusters() -> Dict[str, Any]: 49 | """ 50 | List all Databricks clusters. 51 | 52 | Returns: 53 | Response containing a list of clusters 54 | 55 | Raises: 56 | DatabricksAPIError: If the API request fails 57 | """ 58 | logger.info("Listing all clusters") 59 | return await make_api_request("GET", "/api/2.0/clusters/list") 60 | 61 | 62 | async def get_cluster(cluster_id: str) -> Dict[str, Any]: 63 | """ 64 | Get information about a specific cluster. 65 | 66 | Args: 67 | cluster_id: ID of the cluster 68 | 69 | Returns: 70 | Response containing cluster information 71 | 72 | Raises: 73 | DatabricksAPIError: If the API request fails 74 | """ 75 | logger.info(f"Getting information for cluster: {cluster_id}") 76 | return await make_api_request("GET", "/api/2.0/clusters/get", params={"cluster_id": cluster_id}) 77 | 78 | 79 | async def start_cluster(cluster_id: str) -> Dict[str, Any]: 80 | """ 81 | Start a terminated Databricks cluster. 82 | 83 | Args: 84 | cluster_id: ID of the cluster to start 85 | 86 | Returns: 87 | Empty response on success 88 | 89 | Raises: 90 | DatabricksAPIError: If the API request fails 91 | """ 92 | logger.info(f"Starting cluster: {cluster_id}") 93 | return await make_api_request("POST", "/api/2.0/clusters/start", data={"cluster_id": cluster_id}) 94 | 95 | 96 | async def resize_cluster(cluster_id: str, num_workers: int) -> Dict[str, Any]: 97 | """ 98 | Resize a cluster by changing the number of workers. 
99 | 100 | Args: 101 | cluster_id: ID of the cluster to resize 102 | num_workers: New number of workers 103 | 104 | Returns: 105 | Empty response on success 106 | 107 | Raises: 108 | DatabricksAPIError: If the API request fails 109 | """ 110 | logger.info(f"Resizing cluster {cluster_id} to {num_workers} workers") 111 | return await make_api_request( 112 | "POST", 113 | "/api/2.0/clusters/resize", 114 | data={"cluster_id": cluster_id, "num_workers": num_workers} 115 | ) 116 | 117 | 118 | async def restart_cluster(cluster_id: str) -> Dict[str, Any]: 119 | """ 120 | Restart a Databricks cluster. 121 | 122 | Args: 123 | cluster_id: ID of the cluster to restart 124 | 125 | Returns: 126 | Empty response on success 127 | 128 | Raises: 129 | DatabricksAPIError: If the API request fails 130 | """ 131 | logger.info(f"Restarting cluster: {cluster_id}") 132 | return await make_api_request("POST", "/api/2.0/clusters/restart", data={"cluster_id": cluster_id}) 133 | -------------------------------------------------------------------------------- /databricks_mcp/cli/commands.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command-line interface for the Databricks MCP server. 3 | 4 | This module provides command-line functionality for interacting with the Databricks MCP server. 5 | """ 6 | 7 | import argparse 8 | import asyncio 9 | import json 10 | import logging 11 | from typing import List, Optional 12 | 13 | from databricks_mcp.core.logging_utils import configure_logging 14 | from databricks_mcp.server.databricks_mcp_server import DatabricksMCPServer, main as server_main 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: 20 | """Parse command-line arguments.""" 21 | parser = argparse.ArgumentParser(description="Databricks MCP Server CLI") 22 | 23 | # Create subparsers for different commands 24 | subparsers = parser.add_subparsers(dest="command", help="Command to run") 25 | 26 | # Start server command 27 | start_parser = subparsers.add_parser("start", help="Start the MCP server") 28 | start_parser.add_argument( 29 | "--debug", action="store_true", help="Enable debug logging" 30 | ) 31 | 32 | # List tools command 33 | list_parser = subparsers.add_parser("list-tools", help="List available tools") 34 | 35 | # Version command 36 | subparsers.add_parser("version", help="Show server version") 37 | 38 | # Sync repo and run notebook command 39 | sync_parser = subparsers.add_parser( 40 | "sync-run", 41 | help="Pull a repo and run a notebook", 42 | ) 43 | sync_parser.add_argument("--repo-id", type=int, required=True, help="Repo ID") 44 | sync_parser.add_argument("--notebook-path", required=True, help="Notebook path") 45 | sync_parser.add_argument("--cluster-id", help="Existing cluster ID") 46 | 47 | return parser.parse_args(args) 48 | 49 | 50 | async def list_tools() -> None: 51 | """List all available tools in the server.""" 52 | server = DatabricksMCPServer() 53 | tools = await server.list_tools() 54 | 55 | print("\nAvailable tools:") 56 | for tool in tools: 57 | print(f" - {tool.name}: {tool.description}") 58 | 59 | 60 | def show_version() -> None: 61 | """Show the server version.""" 62 | server = DatabricksMCPServer() 63 | print(f"\nDatabricks MCP Server v{server.version}") 64 | 65 | 66 | async def sync_run(repo_id: int, notebook_path: str, cluster_id: Optional[str]) -> None: 67 | """Convenience wrapper for the sync_repo_and_run_notebook tool.""" 68 | server = 
DatabricksMCPServer() 69 | params = { 70 | "repo_id": repo_id, 71 | "notebook_path": notebook_path, 72 | } 73 | if cluster_id: 74 | params["existing_cluster_id"] = cluster_id 75 | result = await server.call_tool("sync_repo_and_run_notebook", params) 76 | if result.isError: 77 | print(f"\nError: {result.content[0].text if result.content else 'Unknown failure'}") 78 | if result.meta and result.meta.get("data"): 79 | print(json.dumps(result.meta["data"], indent=2)) 80 | return 81 | 82 | if result.content: 83 | summary = next((block.text for block in result.content if hasattr(block, "text")), None) 84 | if summary: 85 | print(f"\n{summary}") 86 | 87 | if result.meta and result.meta.get("data"): 88 | print(json.dumps(result.meta["data"], indent=2)) 89 | 90 | 91 | def main(args: Optional[List[str]] = None) -> int: 92 | """Main entry point for the CLI.""" 93 | configure_logging() 94 | parsed_args = parse_args(args) 95 | 96 | # Set log level 97 | if hasattr(parsed_args, "debug") and parsed_args.debug: 98 | logging.getLogger().setLevel(logging.DEBUG) 99 | 100 | # Execute the appropriate command 101 | if parsed_args.command == "start": 102 | logger.info("Starting Databricks MCP server") 103 | server_main() 104 | elif parsed_args.command == "list-tools": 105 | asyncio.run(list_tools()) 106 | elif parsed_args.command == "version": 107 | show_version() 108 | elif parsed_args.command == "sync-run": 109 | asyncio.run(sync_run(parsed_args.repo_id, parsed_args.notebook_path, parsed_args.cluster_id)) 110 | else: 111 | # If no command is provided, show help 112 | parse_args(["--help"]) 113 | return 1 114 | 115 | return 0 116 | 117 | 118 | if __name__ == "__main__": 119 | raise SystemExit(main()) 120 | -------------------------------------------------------------------------------- /tests/test_clusters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the clusters API.
3 | """ 4 | 5 | import json 6 | import os 7 | from unittest.mock import AsyncMock, MagicMock, patch 8 | 9 | import pytest 10 | from fastapi import status 11 | from fastapi.testclient import TestClient 12 | 13 | from databricks_mcp.api import clusters 14 | from databricks_mcp.server.app import create_app 15 | 16 | 17 | @pytest.fixture 18 | def client(): 19 | """Create a test client for the API.""" 20 | app = create_app() 21 | return TestClient(app) 22 | 23 | 24 | @pytest.fixture 25 | def mock_cluster_response(): 26 | """Mock response for cluster operations.""" 27 | return { 28 | "cluster_id": "1234-567890-abcdef", 29 | "cluster_name": "Test Cluster", 30 | "spark_version": "10.4.x-scala2.12", 31 | "node_type_id": "Standard_D3_v2", 32 | "num_workers": 2, 33 | "state": "RUNNING", 34 | "creator_user_name": "test@example.com", 35 | } 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_create_cluster(): 40 | """Test creating a cluster.""" 41 | # Mock the API call 42 | clusters.create_cluster = AsyncMock(return_value={"cluster_id": "1234-567890-abcdef"}) 43 | 44 | # Create cluster config 45 | cluster_config = { 46 | "cluster_name": "Test Cluster", 47 | "spark_version": "10.4.x-scala2.12", 48 | "node_type_id": "Standard_D3_v2", 49 | "num_workers": 2, 50 | } 51 | 52 | # Call the function 53 | response = await clusters.create_cluster(cluster_config) 54 | 55 | # Check the response 56 | assert response["cluster_id"] == "1234-567890-abcdef" 57 | 58 | # Verify the mock was called with the correct arguments 59 | clusters.create_cluster.assert_called_once_with(cluster_config) 60 | 61 | 62 | @pytest.mark.asyncio 63 | async def test_list_clusters(): 64 | """Test listing clusters.""" 65 | # Mock the API call 66 | mock_response = { 67 | "clusters": [ 68 | { 69 | "cluster_id": "1234-567890-abcdef", 70 | "cluster_name": "Test Cluster 1", 71 | "state": "RUNNING", 72 | }, 73 | { 74 | "cluster_id": "9876-543210-fedcba", 75 | "cluster_name": "Test Cluster 2", 76 | "state": "TERMINATED", 77 | }, 78 | ] 79 | } 80 | clusters.list_clusters = AsyncMock(return_value=mock_response) 81 | 82 | # Call the function 83 | response = await clusters.list_clusters() 84 | 85 | # Check the response 86 | assert len(response["clusters"]) == 2 87 | assert response["clusters"][0]["cluster_id"] == "1234-567890-abcdef" 88 | assert response["clusters"][1]["cluster_id"] == "9876-543210-fedcba" 89 | 90 | # Verify the mock was called 91 | clusters.list_clusters.assert_called_once() 92 | 93 | 94 | @pytest.mark.asyncio 95 | async def test_get_cluster(): 96 | """Test getting cluster information.""" 97 | # Mock the API call 98 | mock_response = { 99 | "cluster_id": "1234-567890-abcdef", 100 | "cluster_name": "Test Cluster", 101 | "state": "RUNNING", 102 | } 103 | clusters.get_cluster = AsyncMock(return_value=mock_response) 104 | 105 | # Call the function 106 | response = await clusters.get_cluster("1234-567890-abcdef") 107 | 108 | # Check the response 109 | assert response["cluster_id"] == "1234-567890-abcdef" 110 | assert response["state"] == "RUNNING" 111 | 112 | # Verify the mock was called with the correct arguments 113 | clusters.get_cluster.assert_called_once_with("1234-567890-abcdef") 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_terminate_cluster(): 118 | """Test terminating a cluster.""" 119 | # Mock the API call 120 | clusters.terminate_cluster = AsyncMock(return_value={}) 121 | 122 | # Call the function 123 | response = await clusters.terminate_cluster("1234-567890-abcdef") 124 | 125 | # Check the response 126 | assert 
response == {} 127 | 128 | # Verify the mock was called with the correct arguments 129 | clusters.terminate_cluster.assert_called_once_with("1234-567890-abcdef") 130 | 131 | 132 | @pytest.mark.asyncio 133 | async def test_start_cluster(): 134 | """Test starting a cluster.""" 135 | # Mock the API call 136 | clusters.start_cluster = AsyncMock(return_value={}) 137 | 138 | # Call the function 139 | response = await clusters.start_cluster("1234-567890-abcdef") 140 | 141 | # Check the response 142 | assert response == {} 143 | 144 | # Verify the mock was called with the correct arguments 145 | clusters.start_cluster.assert_called_once_with("1234-567890-abcdef") 146 | 147 | 148 | @pytest.mark.asyncio 149 | async def test_resize_cluster(): 150 | """Test resizing a cluster.""" 151 | # Mock the API call 152 | clusters.resize_cluster = AsyncMock(return_value={}) 153 | 154 | # Call the function 155 | response = await clusters.resize_cluster("1234-567890-abcdef", 4) 156 | 157 | # Check the response 158 | assert response == {} 159 | 160 | # Verify the mock was called with the correct arguments 161 | clusters.resize_cluster.assert_called_once_with("1234-567890-abcdef", 4) 162 | 163 | 164 | @pytest.mark.asyncio 165 | async def test_restart_cluster(): 166 | """Test restarting a cluster.""" 167 | # Mock the API call 168 | clusters.restart_cluster = AsyncMock(return_value={}) 169 | 170 | # Call the function 171 | response = await clusters.restart_cluster("1234-567890-abcdef") 172 | 173 | # Check the response 174 | assert response == {} 175 | 176 | # Verify the mock was called with the correct arguments 177 | clusters.restart_cluster.assert_called_once_with("1234-567890-abcdef") -------------------------------------------------------------------------------- /databricks_mcp/core/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for the Databricks MCP server. 3 | """ 4 | 5 | import asyncio 6 | import logging 7 | import random 8 | from contextvars import ContextVar 9 | from typing import Any, Dict, List, Optional, Union 10 | 11 | import httpx 12 | from httpx import HTTPError, HTTPStatusError 13 | 14 | from databricks_mcp.core.config import ( 15 | get_api_headers, 16 | get_databricks_api_url, 17 | settings, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | # Context for propagating correlation IDs into API calls. 23 | request_context_id: ContextVar[Optional[str]] = ContextVar("databricks_mcp_request_id", default=None) 24 | 25 | 26 | class DatabricksAPIError(Exception): 27 | """Exception raised for errors in the Databricks API.""" 28 | 29 | def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Any] = None): 30 | self.message = message 31 | self.status_code = status_code 32 | self.response = response 33 | super().__init__(self.message) 34 | 35 | 36 | async def make_api_request( 37 | method: str, 38 | endpoint: str, 39 | data: Optional[Dict[str, Any]] = None, 40 | params: Optional[Dict[str, Any]] = None, 41 | files: Optional[Dict[str, Any]] = None, 42 | ) -> Dict[str, Any]: 43 | """ 44 | Make a request to the Databricks API. 
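For example, awaiting make_api_request("GET", "/api/2.0/clusters/list") returns the parsed JSON body as a dict (or {} when the response body is empty) and raises DatabricksAPIError once the configured retries are exhausted.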
45 | 46 | Args: 47 | method: HTTP method ("GET", "POST", "PUT", "DELETE") 48 | endpoint: API endpoint path 49 | data: Request body data 50 | params: Query parameters 51 | files: Files to upload 52 | 53 | Returns: 54 | Response data as a dictionary 55 | 56 | Raises: 57 | DatabricksAPIError: If the API request fails 58 | """ 59 | url = get_databricks_api_url(endpoint) 60 | headers = get_api_headers().copy() 61 | request_id = request_context_id.get() 62 | if request_id: 63 | headers["X-Databricks-MCP-Request-ID"] = request_id 64 | retries = settings.API_MAX_RETRIES 65 | backoff_base = settings.API_RETRY_BACKOFF_SECONDS 66 | retry_statuses = {408, 425, 429, 500, 502, 503, 504} 67 | 68 | safe_data = "**REDACTED**" if data else None 69 | logger.debug("API Request: %s %s Params: %s Data: %s", method, url, params, safe_data) 70 | 71 | attempt = 0 72 | while True: 73 | try: 74 | timeout = httpx.Timeout(settings.HTTP_TIMEOUT_SECONDS) 75 | async with httpx.AsyncClient(timeout=timeout) as client: 76 | response = await client.request( 77 | method=method, 78 | url=url, 79 | headers=headers, 80 | params=params, 81 | json=data if not files else None, 82 | data=data if files else None, 83 | files=files, 84 | ) 85 | 86 | response.raise_for_status() 87 | 88 | if response.content: 89 | return response.json() 90 | return {} 91 | 92 | except HTTPStatusError as e: 93 | status_code = e.response.status_code if e.response else None 94 | error_response = None 95 | error_msg = f"API request failed: {e!s}" 96 | if e.response is not None: 97 | try: 98 | error_response = e.response.json() 99 | error_text = error_response.get("error") or error_response.get("message") 100 | if error_text: 101 | error_msg = f"{error_msg} - {error_text}" 102 | except ValueError: 103 | error_response = e.response.text 104 | 105 | if status_code in retry_statuses and attempt < retries: 106 | wait = backoff_base * (2 ** attempt) + random.uniform(0, backoff_base) 107 | logger.warning( 108 | "Retryable Databricks API error (%s). Retrying in %.2fs", 109 | status_code, 110 | wait, 111 | ) 112 | attempt += 1 113 | await asyncio.sleep(wait) 114 | continue 115 | 116 | logger.error("API Error: %s", error_msg, exc_info=True) 117 | raise DatabricksAPIError(error_msg, status_code, error_response) from e 118 | 119 | except HTTPError as e: 120 | status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None 121 | error_msg = f"API request failed: {e!s}" 122 | 123 | if status_code in retry_statuses and attempt < retries: 124 | wait = backoff_base * (2 ** attempt) + random.uniform(0, backoff_base) 125 | logger.warning( 126 | "HTTP transport error (status=%s). Retrying in %.2fs", 127 | status_code, 128 | wait, 129 | ) 130 | attempt += 1 131 | await asyncio.sleep(wait) 132 | continue 133 | 134 | error_response = None 135 | if hasattr(e, "response") and e.response is not None: 136 | try: 137 | error_response = e.response.json() 138 | except ValueError: 139 | error_response = e.response.text 140 | 141 | logger.error("API Error: %s", error_msg, exc_info=True) 142 | raise DatabricksAPIError(error_msg, status_code, error_response) from e 143 | 144 | 145 | def format_response( 146 | success: bool, 147 | data: Optional[Union[Dict[str, Any], List[Any]]] = None, 148 | error: Optional[str] = None, 149 | status_code: int = 200 150 | ) -> Dict[str, Any]: 151 | """ 152 | Format a standardized response. 
153 | 154 | Args: 155 | success: Whether the operation was successful 156 | data: Response data 157 | error: Error message if not successful 158 | status_code: HTTP status code 159 | 160 | Returns: 161 | Formatted response dictionary 162 | """ 163 | response = { 164 | "success": success, 165 | "status_code": status_code, 166 | } 167 | 168 | if data is not None: 169 | response["data"] = data 170 | 171 | if error: 172 | response["error"] = error 173 | 174 | return response 175 | -------------------------------------------------------------------------------- /databricks_mcp/api/sql.py: -------------------------------------------------------------------------------- 1 | """ 2 | API for executing SQL statements on Databricks. 3 | """ 4 | 5 | import logging 6 | from typing import Any, Dict, List, Optional 7 | 8 | from databricks_mcp.core.utils import DatabricksAPIError, make_api_request 9 | from databricks_mcp.core.config import settings 10 | 11 | # Configure logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | async def execute_statement( 16 | statement: str, 17 | warehouse_id: Optional[str] = None, 18 | catalog: Optional[str] = None, 19 | schema: Optional[str] = None, 20 | parameters: Optional[Dict[str, Any]] = None, 21 | row_limit: int = 10000, 22 | byte_limit: int = 100000000, # 100MB 23 | ) -> Dict[str, Any]: 24 | """ 25 | Execute a SQL statement. 26 | 27 | Args: 28 | statement: The SQL statement to execute 29 | warehouse_id: ID of the SQL warehouse to use (optional if DATABRICKS_WAREHOUSE_ID is set) 30 | catalog: Optional catalog to use 31 | schema: Optional schema to use 32 | parameters: Optional statement parameters 33 | row_limit: Maximum number of rows to return 34 | byte_limit: Maximum number of bytes to return 35 | 36 | Returns: 37 | Response containing query results 38 | 39 | Raises: 40 | DatabricksAPIError: If the API request fails 41 | ValueError: If no warehouse_id is provided and DATABRICKS_WAREHOUSE_ID is not set 42 | """ 43 | logger.info(f"Executing SQL statement: {statement[:100]}...") 44 | 45 | # Use provided warehouse_id or fall back to environment variable 46 | effective_warehouse_id = warehouse_id or settings.DATABRICKS_WAREHOUSE_ID 47 | 48 | if not effective_warehouse_id: 49 | raise ValueError( 50 | "warehouse_id must be provided either as parameter or " 51 | "set DATABRICKS_WAREHOUSE_ID environment variable" 52 | ) 53 | 54 | request_data = { 55 | "statement": statement, 56 | "warehouse_id": effective_warehouse_id, 57 | "wait_timeout": "10s", 58 | "format": "JSON_ARRAY", 59 | "disposition": "INLINE", 60 | "row_limit": row_limit, 61 | "byte_limit": byte_limit, 62 | } 63 | 64 | if catalog: 65 | request_data["catalog"] = catalog 66 | 67 | if schema: 68 | request_data["schema"] = schema 69 | 70 | if parameters: 71 | request_data["parameters"] = parameters 72 | 73 | return await make_api_request("POST", "/api/2.0/sql/statements", data=request_data) 74 | 75 | 76 | async def execute_and_wait( 77 | statement: str, 78 | warehouse_id: Optional[str] = None, 79 | catalog: Optional[str] = None, 80 | schema: Optional[str] = None, 81 | parameters: Optional[Dict[str, Any]] = None, 82 | timeout_seconds: int = 300, # 5 minutes 83 | poll_interval_seconds: int = 1, 84 | ) -> Dict[str, Any]: 85 | """ 86 | Execute a SQL statement and wait for completion.
87 | 88 | Args: 89 | statement: The SQL statement to execute 90 | warehouse_id: ID of the SQL warehouse to use (optional if DATABRICKS_WAREHOUSE_ID is set) 91 | catalog: Optional catalog to use 92 | schema: Optional schema to use 93 | parameters: Optional statement parameters 94 | timeout_seconds: Maximum time to wait for completion 95 | poll_interval_seconds: How often to poll for status 96 | 97 | Returns: 98 | Response containing query results 99 | 100 | Raises: 101 | DatabricksAPIError: If the API request fails 102 | TimeoutError: If query execution times out 103 | """ 104 | import asyncio 105 | import time 106 | 107 | logger.info(f"Executing SQL statement with waiting: {statement[:100]}...") 108 | 109 | # Start execution 110 | response = await execute_statement( 111 | statement=statement, 112 | warehouse_id=warehouse_id, 113 | catalog=catalog, 114 | schema=schema, 115 | parameters=parameters, 116 | ) 117 | 118 | statement_id = response.get("statement_id") 119 | if not statement_id: 120 | raise ValueError("No statement_id returned from execution") 121 | 122 | # Poll for completion 123 | start_time = time.time() 124 | status = response.get("status", {}).get("state", "") 125 | 126 | while status in ["PENDING", "RUNNING"]: 127 | # Check timeout 128 | if time.time() - start_time > timeout_seconds: 129 | raise TimeoutError(f"Query execution timed out after {timeout_seconds} seconds") 130 | 131 | # Wait before polling again 132 | await asyncio.sleep(poll_interval_seconds) 133 | 134 | # Check status 135 | status_response = await get_statement_status(statement_id) 136 | status = status_response.get("status", {}).get("state", "") 137 | 138 | if status == "SUCCEEDED": 139 | return status_response 140 | elif status in ["FAILED", "CANCELED", "CLOSED"]: 141 | error_message = status_response.get("status", {}).get("error", {}).get("message", "Unknown error") 142 | raise DatabricksAPIError(f"Query execution failed: {error_message}", response=status_response) 143 | 144 | return response 145 | 146 | 147 | async def get_statement_status(statement_id: str) -> Dict[str, Any]: 148 | """ 149 | Get the status of a SQL statement. 150 | 151 | Args: 152 | statement_id: ID of the statement to check 153 | 154 | Returns: 155 | Response containing statement status 156 | 157 | Raises: 158 | DatabricksAPIError: If the API request fails 159 | """ 160 | logger.info(f"Getting status of SQL statement: {statement_id}") 161 | return await make_api_request("GET", f"/api/2.0/sql/statements/{statement_id}", params={}) 162 | 163 | 164 | async def cancel_statement(statement_id: str) -> Dict[str, Any]: 165 | """ 166 | Cancel a running SQL statement. 167 | 168 | Args: 169 | statement_id: ID of the statement to cancel 170 | 171 | Returns: 172 | Empty response on success 173 | 174 | Raises: 175 | DatabricksAPIError: If the API request fails 176 | """ 177 | logger.info(f"Cancelling SQL statement: {statement_id}") 178 | return await make_api_request("POST", f"/api/2.0/sql/statements/{statement_id}/cancel", data={}) 179 | -------------------------------------------------------------------------------- /databricks_mcp/api/dbfs.py: -------------------------------------------------------------------------------- 1 | """ 2 | API for managing Databricks File System (DBFS). 
3 | """ 4 | 5 | import base64 6 | import logging 7 | import os 8 | from typing import Any, Dict, List, Optional, BinaryIO 9 | 10 | from databricks_mcp.core.utils import DatabricksAPIError, make_api_request 11 | 12 | # Configure logging 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | async def put_file( 17 | dbfs_path: str, 18 | file_content: bytes, 19 | overwrite: bool = True, 20 | ) -> Dict[str, Any]: 21 | """ 22 | Upload a file to DBFS. 23 | 24 | Args: 25 | dbfs_path: The path where the file should be stored in DBFS 26 | file_content: The content of the file as bytes 27 | overwrite: Whether to overwrite an existing file 28 | 29 | Returns: 30 | Empty response on success 31 | 32 | Raises: 33 | DatabricksAPIError: If the API request fails 34 | """ 35 | logger.info(f"Uploading file to DBFS path: {dbfs_path}") 36 | 37 | # Convert bytes to base64 38 | content_base64 = base64.b64encode(file_content).decode("utf-8") 39 | 40 | return await make_api_request( 41 | "POST", 42 | "/api/2.0/dbfs/put", 43 | data={ 44 | "path": dbfs_path, 45 | "contents": content_base64, 46 | "overwrite": overwrite, 47 | }, 48 | ) 49 | 50 | 51 | async def upload_large_file( 52 | dbfs_path: str, 53 | local_file_path: str, 54 | overwrite: bool = True, 55 | buffer_size: int = 1024 * 1024, # 1MB chunks 56 | ) -> Dict[str, Any]: 57 | """ 58 | Upload a large file to DBFS in chunks. 59 | 60 | Args: 61 | dbfs_path: The path where the file should be stored in DBFS 62 | local_file_path: Local path to the file to upload 63 | overwrite: Whether to overwrite an existing file 64 | buffer_size: Size of chunks to upload 65 | 66 | Returns: 67 | Empty response on success 68 | 69 | Raises: 70 | DatabricksAPIError: If the API request fails 71 | FileNotFoundError: If the local file does not exist 72 | """ 73 | logger.info(f"Uploading large file from {local_file_path} to DBFS path: {dbfs_path}") 74 | 75 | if not os.path.exists(local_file_path): 76 | raise FileNotFoundError(f"Local file not found: {local_file_path}") 77 | 78 | # Create a handle for the upload 79 | create_response = await make_api_request( 80 | "POST", 81 | "/api/2.0/dbfs/create", 82 | data={ 83 | "path": dbfs_path, 84 | "overwrite": overwrite, 85 | }, 86 | ) 87 | 88 | handle = create_response.get("handle") 89 | 90 | try: 91 | with open(local_file_path, "rb") as f: 92 | chunk_index = 0 93 | while True: 94 | chunk = f.read(buffer_size) 95 | if not chunk: 96 | break 97 | 98 | # Convert chunk to base64 99 | chunk_base64 = base64.b64encode(chunk).decode("utf-8") 100 | 101 | # Add to handle 102 | await make_api_request( 103 | "POST", 104 | "/api/2.0/dbfs/add-block", 105 | data={ 106 | "handle": handle, 107 | "data": chunk_base64, 108 | }, 109 | ) 110 | 111 | chunk_index += 1 112 | logger.debug(f"Uploaded chunk {chunk_index}") 113 | 114 | # Close the handle 115 | return await make_api_request( 116 | "POST", 117 | "/api/2.0/dbfs/close", 118 | data={"handle": handle}, 119 | ) 120 | 121 | except Exception as e: 122 | # Attempt to abort the upload on error 123 | try: 124 | await make_api_request( 125 | "POST", 126 | "/api/2.0/dbfs/close", 127 | data={"handle": handle}, 128 | ) 129 | except Exception: 130 | pass 131 | 132 | logger.error(f"Error uploading file: {str(e)}") 133 | raise 134 | 135 | 136 | async def get_file( 137 | dbfs_path: str, 138 | offset: int = 0, 139 | length: int = 1024 * 1024, # Default to 1MB 140 | ) -> Dict[str, Any]: 141 | """ 142 | Get the contents of a file from DBFS. 
143 | 144 | Args: 145 | dbfs_path: The path of the file in DBFS 146 | offset: Starting byte position 147 | length: Number of bytes to read 148 | 149 | Returns: 150 | Response containing the file content 151 | 152 | Raises: 153 | DatabricksAPIError: If the API request fails 154 | """ 155 | logger.info(f"Reading file from DBFS path: {dbfs_path}") 156 | 157 | response = await make_api_request( 158 | "GET", 159 | "/api/2.0/dbfs/read", 160 | params={ 161 | "path": dbfs_path, 162 | "offset": offset, 163 | "length": length, 164 | }, 165 | ) 166 | 167 | # Decode base64 content 168 | if "data" in response: 169 | try: 170 | response["decoded_data"] = base64.b64decode(response["data"]) 171 | except Exception as e: 172 | logger.warning(f"Failed to decode file content: {str(e)}") 173 | 174 | return response 175 | 176 | 177 | async def list_files(dbfs_path: str) -> Dict[str, Any]: 178 | """ 179 | List files and directories in a DBFS path. 180 | 181 | Args: 182 | dbfs_path: The path to list 183 | 184 | Returns: 185 | Response containing the directory listing 186 | 187 | Raises: 188 | DatabricksAPIError: If the API request fails 189 | """ 190 | logger.info(f"Listing files in DBFS path: {dbfs_path}") 191 | return await make_api_request("GET", "/api/2.0/dbfs/list", params={"path": dbfs_path}) 192 | 193 | 194 | async def delete_file( 195 | dbfs_path: str, 196 | recursive: bool = False, 197 | ) -> Dict[str, Any]: 198 | """ 199 | Delete a file or directory from DBFS. 200 | 201 | Args: 202 | dbfs_path: The path to delete 203 | recursive: Whether to recursively delete directories 204 | 205 | Returns: 206 | Empty response on success 207 | 208 | Raises: 209 | DatabricksAPIError: If the API request fails 210 | """ 211 | logger.info(f"Deleting DBFS path: {dbfs_path}") 212 | return await make_api_request( 213 | "POST", 214 | "/api/2.0/dbfs/delete", 215 | data={ 216 | "path": dbfs_path, 217 | "recursive": recursive, 218 | }, 219 | ) 220 | 221 | 222 | async def get_status(dbfs_path: str) -> Dict[str, Any]: 223 | """ 224 | Get the status of a file or directory. 225 | 226 | Args: 227 | dbfs_path: The path to check 228 | 229 | Returns: 230 | Response containing file status 231 | 232 | Raises: 233 | DatabricksAPIError: If the API request fails 234 | """ 235 | logger.info(f"Getting status of DBFS path: {dbfs_path}") 236 | return await make_api_request("GET", "/api/2.0/dbfs/get-status", params={"path": dbfs_path}) 237 | 238 | 239 | async def create_directory(dbfs_path: str) -> Dict[str, Any]: 240 | """ 241 | Create a directory in DBFS. 242 | 243 | Args: 244 | dbfs_path: The path to create 245 | 246 | Returns: 247 | Empty response on success 248 | 249 | Raises: 250 | DatabricksAPIError: If the API request fails 251 | """ 252 | logger.info(f"Creating DBFS directory: {dbfs_path}") 253 | return await make_api_request("POST", "/api/2.0/dbfs/mkdirs", data={"path": dbfs_path}) 254 | -------------------------------------------------------------------------------- /databricks_mcp/api/jobs.py: -------------------------------------------------------------------------------- 1 | """ 2 | API for managing Databricks jobs. 
3 | """ 4 | 5 | import logging 6 | import asyncio 7 | import time 8 | from typing import Any, Dict, List, Optional, Union 9 | 10 | from databricks_mcp.core.models import Job 11 | from databricks_mcp.core.utils import DatabricksAPIError, make_api_request 12 | 13 | # Configure logging 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | async def create_job(job_config: Union[Job, Dict[str, Any]]) -> Dict[str, Any]: 18 | """ 19 | Create a new Databricks job. 20 | 21 | Args: 22 | job_config: Job configuration 23 | 24 | Returns: 25 | Response containing the job ID 26 | 27 | Raises: 28 | DatabricksAPIError: If the API request fails 29 | """ 30 | logger.info("Creating new job") 31 | 32 | if isinstance(job_config, Job): 33 | payload = job_config.model_dump(exclude_none=True) 34 | else: 35 | payload = job_config 36 | 37 | return await make_api_request("POST", "/api/2.2/jobs/create", data=payload) 38 | 39 | 40 | async def run_job(job_id: int, notebook_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 41 | """ 42 | Run a job now. 43 | 44 | Args: 45 | job_id: ID of the job to run 46 | notebook_params: Optional parameters for the notebook 47 | 48 | Returns: 49 | Response containing the run ID 50 | 51 | Raises: 52 | DatabricksAPIError: If the API request fails 53 | """ 54 | logger.info(f"Running job: {job_id}") 55 | 56 | run_params = {"job_id": job_id} 57 | if notebook_params: 58 | run_params["notebook_params"] = notebook_params 59 | 60 | return await make_api_request("POST", "/api/2.0/jobs/run-now", data=run_params) 61 | 62 | 63 | async def list_jobs() -> Dict[str, Any]: 64 | """ 65 | List all jobs. 66 | 67 | Returns: 68 | Response containing a list of jobs 69 | 70 | Raises: 71 | DatabricksAPIError: If the API request fails 72 | """ 73 | logger.info("Listing all jobs") 74 | return await make_api_request("GET", "/api/2.0/jobs/list") 75 | 76 | 77 | async def get_job(job_id: int) -> Dict[str, Any]: 78 | """ 79 | Get information about a specific job. 80 | 81 | Args: 82 | job_id: ID of the job 83 | 84 | Returns: 85 | Response containing job information 86 | 87 | Raises: 88 | DatabricksAPIError: If the API request fails 89 | """ 90 | logger.info(f"Getting information for job: {job_id}") 91 | return await make_api_request("GET", "/api/2.0/jobs/get", params={"job_id": job_id}) 92 | 93 | 94 | async def update_job(job_id: int, new_settings: Dict[str, Any]) -> Dict[str, Any]: 95 | """ 96 | Update an existing job. 97 | 98 | Args: 99 | job_id: ID of the job to update 100 | new_settings: New job settings 101 | 102 | Returns: 103 | Empty response on success 104 | 105 | Raises: 106 | DatabricksAPIError: If the API request fails 107 | """ 108 | logger.info(f"Updating job: {job_id}") 109 | 110 | update_data = { 111 | "job_id": job_id, 112 | "new_settings": new_settings 113 | } 114 | 115 | return await make_api_request("POST", "/api/2.0/jobs/update", data=update_data) 116 | 117 | 118 | async def delete_job(job_id: int) -> Dict[str, Any]: 119 | """ 120 | Delete a job. 121 | 122 | Args: 123 | job_id: ID of the job to delete 124 | 125 | Returns: 126 | Empty response on success 127 | 128 | Raises: 129 | DatabricksAPIError: If the API request fails 130 | """ 131 | logger.info(f"Deleting job: {job_id}") 132 | return await make_api_request("POST", "/api/2.2/jobs/delete", data={"job_id": job_id}) 133 | 134 | 135 | async def get_run(run_id: int) -> Dict[str, Any]: 136 | """ 137 | Get information about a specific job run. 
138 | 139 | Args: 140 | run_id: ID of the run 141 | 142 | Returns: 143 | Response containing run information 144 | 145 | Raises: 146 | DatabricksAPIError: If the API request fails 147 | """ 148 | logger.info(f"Getting information for run: {run_id}") 149 | return await make_api_request("GET", "/api/2.1/jobs/runs/get", params={"run_id": run_id}) 150 | 151 | 152 | async def list_runs(job_id: Optional[int] = None, limit: int = 20) -> Dict[str, Any]: 153 | """List job runs.""" 154 | logger.info("Listing job runs") 155 | params: Dict[str, Any] = {"limit": limit} 156 | if job_id is not None: 157 | params["job_id"] = job_id 158 | return await make_api_request("GET", "/api/2.1/jobs/runs/list", params=params) 159 | 160 | 161 | async def get_run_status(run_id: int) -> Dict[str, Any]: 162 | """Get concise status information for a run.""" 163 | info = await get_run(run_id) 164 | state = info.get("state", {}) 165 | return { 166 | "state": state.get("result_state") or state.get("life_cycle_state"), 167 | "life_cycle": state.get("life_cycle_state"), 168 | "run_id": run_id, 169 | } 170 | 171 | 172 | async def cancel_run(run_id: int) -> Dict[str, Any]: 173 | """ 174 | Cancel a job run. 175 | 176 | Args: 177 | run_id: ID of the run to cancel 178 | 179 | Returns: 180 | Empty response on success 181 | 182 | Raises: 183 | DatabricksAPIError: If the API request fails 184 | """ 185 | logger.info(f"Cancelling run: {run_id}") 186 | return await make_api_request("POST", "/api/2.1/jobs/runs/cancel", data={"run_id": run_id}) 187 | 188 | 189 | async def submit_run(run_config: Dict[str, Any]) -> Dict[str, Any]: 190 | """Submit a one-time run. 191 | 192 | Args: 193 | run_config: Configuration for the run 194 | 195 | Returns: 196 | Response containing the run ID 197 | """ 198 | logger.info("Submitting one-time run") 199 | return await make_api_request("POST", "/api/2.0/jobs/runs/submit", data=run_config) 200 | 201 | 202 | async def get_run_output(run_id: int) -> Dict[str, Any]: 203 | """Get the output of a run.""" 204 | logger.info(f"Fetching output for run {run_id}") 205 | return await make_api_request("GET", "/api/2.0/jobs/runs/get-output", params={"run_id": run_id}) 206 | 207 | 208 | async def await_until_state( 209 | run_id: int, 210 | desired_state: str = "TERMINATED", 211 | timeout_seconds: int = 600, 212 | poll_interval_seconds: int = 5, 213 | ) -> Dict[str, Any]: 214 | """Wait for a run to reach a desired state.""" 215 | start = time.monotonic() 216 | while True: 217 | run_info = await get_run(run_id) 218 | state = run_info.get("state", {}).get("life_cycle_state") 219 | if state == desired_state: 220 | return run_info 221 | if time.monotonic() - start > timeout_seconds: 222 | raise TimeoutError(f"Run {run_id} did not reach state {desired_state} within {timeout_seconds}s") 223 | await asyncio.sleep(poll_interval_seconds) 224 | 225 | 226 | async def run_notebook( 227 | notebook_path: str, 228 | existing_cluster_id: Optional[str] = None, 229 | base_parameters: Optional[Dict[str, Any]] = None, 230 | timeout_seconds: int = 600, 231 | poll_interval_seconds: int = 5, 232 | ) -> Dict[str, Any]: 233 | """Submit a one-time run for a notebook and wait for completion.""" 234 | task = { 235 | "task_key": "run_notebook", 236 | "notebook_task": {"notebook_path": notebook_path}, 237 | } 238 | if base_parameters: 239 | task["notebook_task"]["base_parameters"] = base_parameters 240 | if existing_cluster_id: 241 | task["existing_cluster_id"] = existing_cluster_id 242 | 243 | run_conf = {"tasks": [task]} 244 | submit_response = await 
submit_run(run_conf) 245 | run_id = submit_response.get("run_id") 246 | if not run_id: 247 | raise ValueError("submit_run did not return a run_id") 248 | 249 | await await_until_state(run_id, timeout_seconds=timeout_seconds, poll_interval_seconds=poll_interval_seconds) 250 | output = await get_run_output(run_id) 251 | output["run_id"] = run_id 252 | return output 253 | -------------------------------------------------------------------------------- /AGENTS.md: -------------------------------------------------------------------------------- 1 | # Repository Guidelines 2 | 3 | ## 1. Core Philosophy 4 | 5 | Our development is guided by a few key principles. Always consider these when making design decisions. 6 | 7 | - **Simplicity and Pragmatism**: Prefer simple, explicit, and readable solutions. If a complex design pattern isn't immediately necessary, don't implement it. Code should be easy to delete. 8 | - **Compose, Don't Duplicate**: Encapsulate reusable logic into `components/` modules. Build complex features by composing these smaller, independent units. This improves testability and maintainability. 9 | - **Stateless by Default**: Design tools and components to be stateless whenever possible. State adds significant complexity to testing, reasoning, and scaling. Pass state explicitly as parameters rather than relying on side effects or implicit context. 10 | - **Fail Fast and Loud**: Don't swallow errors. Let exceptions propagate up to a centralized handler. This makes bugs immediately obvious and debugging far simpler. Avoid defensive `try/except` blocks deep within application logic. 11 | 12 | ## 2. Code Conventions & Style 13 | 14 | Consistency is key to a readable and maintainable codebase. 15 | 16 | - **File and Directory Naming**: Use `kebab-case` for all files and directories (e.g., `databricks-api.py`, `utils/`). 17 | - **Function Definitions**: Always prefer `def` functions over `lambda`. `def` provides a name for debugging, supports docstrings, and allows for proper type annotation. 18 | - **Type Hinting**: 19 | - Annotate all function parameters and return types. Use Python 3.10+ syntax (e.g., `list[str]` instead of `List[str]`). 20 | - Never cast to `Any`. If you must use it, use `typing.cast` with a clear comment explaining why no stricter type is available. 21 | - Represent collections with standard types like `list[T]`, `dict[K, V]`, and `set[T]`. 22 | - **Input Validation**: Use `pydantic` models for validating the structure and types of external inputs (e.g., API requests, tool arguments). 23 | - Name schemas with a `Schema` suffix (e.g., `CreateClusterSchema`). 24 | - Keep Pydantic models as pure data contracts. Avoid adding business logic or complex methods to them. 25 | - **Imports**: Use absolute imports for project modules. Never use dynamic imports like `importlib.import_module` or `await import(...)` unless it's a core requirement of a plug-in system. 26 | - **Formatting**: Follow PEP 8, enforced by Black with a line length of 88 characters and 4-space indentation. 27 | 28 | ## 3. Project Structure & Modularity 29 | 30 | A clear structure makes the system easier to navigate and understand. 31 | 32 | - **`databricks_mcp/`**: The primary application package. 33 | - **`core/`**: Shared Pydantic models, custom exception classes, and authentication logic. This code should be generic and application-agnostic. 34 | - **`api/`**: Thin, stateless clients for external REST APIs (e.g., Databricks). 
This layer is responsible for HTTP requests, authentication, and translating API errors into well-defined exceptions. It should not contain any business logic. 35 | - **`components/`**: Reusable modules that encapsulate specific business logic (e.g., parsing a notebook, calculating cluster costs). Components should be independent and easily testable in isolation. 36 | - **`server/` or `tools/`**: The layer that composes components into agent-callable tools. It defines tool inputs (`SomethingSchema`), orchestrates calls to components, and formats the final `CallToolResult`. 37 | - **`cli/`**: Command-line entry points and interfaces. 38 | - **`tests/`**: Contains all tests, mirroring the application structure. 39 | - **File Length**: Aim to keep files under 400 lines. A long file is often a sign that logic can be refactored into a separate, more focused component in `components/` or a helper in a `utils/` module. 40 | 41 | ## 4. Agent Tool Design 42 | 43 | Well-designed tools are the foundation of a capable agent. 44 | 45 | - **Single Responsibility**: Each tool should do one thing and do it well. A tool to `list_clusters` should not also have an option to delete one. 46 | - **Clear Inputs & Outputs**: 47 | - Define inputs with a dedicated Pydantic `...Schema` model. This provides automatic validation and clear documentation. 48 | - Return a `CallToolResult` object. 49 | - The `TextContent` should be a concise, human-readable summary for the LLM. 50 | - Store structured, machine-readable data in `_meta['data']`. This is what downstream tools or programmatic clients will use. 51 | - For large artifacts (e.g., files, query results), use server cache helpers and place a reference handle in `_meta['resources']`. 52 | - **Idempotency**: Strive to make tools idempotent where possible. A tool that creates a resource should be safely runnable multiple times without creating duplicate resources. 53 | - **Naming**: Tool identifiers should be `snake_case` and clearly describe their action (e.g., `get_cluster_details`, `submit_spark_job`). 54 | 55 | ## 5. Error Handling & Logging 56 | 57 | A robust system anticipates and reports failures clearly. 58 | 59 | - **Error Handling**: 60 | - Use custom, specific exceptions for predictable application errors (e.g., `ClusterNotFoundError`, `InvalidConfigurationError`). 61 | - Use `try/except` blocks sparingly, primarily at system boundaries (e.g., in the `api/` layer to catch network errors, or in the main server loop to catch unhandled exceptions). 62 | - Avoid broad `except Exception:` blocks that hide bugs. If you must catch a broad exception, re-raise it or log it with a full traceback. 63 | - **Logging**: 64 | - Use the standard `logging` module. 65 | - Log at appropriate levels: 66 | - `INFO`: Key lifecycle events (e.g., "Server starting," "Tool 'list_clusters' invoked"). 67 | - `DEBUG`: Verbose, detailed information useful for tracing execution flow (e.g., API request bodies, intermediate values). 68 | - `WARNING`: A potential issue was encountered but the operation succeeded (e.g., "API deprecated, using fallback"). 69 | - `ERROR`: An operation failed due to a handled exception. 70 | - Logs write to `databricks_mcp.log`. Be mindful of logging sensitive information. 71 | 72 | ## 6. Development & Testing Workflow 73 | 74 | Follow these steps to ensure code quality and a smooth contribution process. 75 | 76 | - **Environment Setup**: Manage environments exclusively with `uv`. 77 | - Initial setup: `uv pip install -e . 
&& uv pip install -e ".[dev]"` 78 | - **Local Development & Testing**: 79 | 1. Write your code and corresponding tests. 80 | 2. Run tests locally: `uv run pytest` 81 | 3. Format and lint your code: `uv run black .` and `uv run pylint databricks_mcp tests` 82 | 4. Manually test CLI or server functionality if needed: `uv run databricks-mcp -- --help` 83 | - **Testing Guidelines**: 84 | - Use `pytest` with `pytest-asyncio` for asynchronous code. 85 | - Mock external dependencies at the `api/` boundary. Tests must be fast and runnable offline. 86 | - **Test behavior, not implementation.** Assert that a function produces the correct output for a given input, not that it calls internal helpers in a specific way. 87 | - For tool tests, assert the structure and content of `_meta['data']` and `_meta['resources']`. 88 | 89 | ## 7. Commits & Pull Requests 90 | 91 | Clear communication helps the team review and integrate your work efficiently. 92 | 93 | - **Commit Messages**: 94 | - Use descriptive, sentence-case titles (e.g., "Add tool to retrieve Spark job status"). 95 | - Reference issues in the body with `Fixes #` or `Refs #`. 96 | - Keep body text wrapped at 72 columns for readability. 97 | - **Pull Requests (PRs)**: 98 | - Write a clear summary of the changes and the problem being solved. 99 | - Use a checklist for complex changes. 100 | - Include screenshots, logs, or command-line captures to demonstrate new behavior. 101 | - Ensure all automated checks (CI) are passing before requesting a review. 102 | - Surface at least two implementation options for complex features, invite critique, and justify your chosen approach. 103 | 104 | ## 8. Security & Configuration 105 | 106 | - **Secrets Management**: Store secrets and environment-specific configuration in a `.env` file (see `.env.example`). This file is git-ignored. 107 | - **Required Variables**: `DATABRICKS_HOST`, `DATABRICKS_TOKEN`. 108 | - **Log Rotation**: Before sharing log artifacts, ensure `databricks_mcp.log` has been rotated or cleared of sensitive information. -------------------------------------------------------------------------------- /databricks_mcp/api/notebooks.py: -------------------------------------------------------------------------------- 1 | """ 2 | API for managing Databricks notebooks. 3 | """ 4 | 5 | import base64 6 | import logging 7 | from typing import Any, Dict, List, Optional 8 | 9 | from databricks_mcp.core.utils import DatabricksAPIError, make_api_request 10 | 11 | # Configure logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | async def import_notebook( 16 | path: str, 17 | content: str, 18 | format: str = "SOURCE", 19 | language: Optional[str] = None, 20 | overwrite: bool = False, 21 | ) -> Dict[str, Any]: 22 | """ 23 | Import a notebook into the workspace. 
24 | 25 | Args: 26 | path: The path where the notebook should be stored 27 | content: The content of the notebook (base64 encoded) 28 | format: The format of the notebook (SOURCE, HTML, JUPYTER, DBC) 29 | language: The language of the notebook (SCALA, PYTHON, SQL, R) 30 | overwrite: Whether to overwrite an existing notebook 31 | 32 | Returns: 33 | Empty response on success 34 | 35 | Raises: 36 | DatabricksAPIError: If the API request fails 37 | """ 38 | logger.info(f"Importing notebook to path: {path}") 39 | 40 | # Ensure content is base64 encoded 41 | if not is_base64(content): 42 | content = base64.b64encode(content.encode("utf-8")).decode("utf-8") 43 | 44 | import_data = { 45 | "path": path, 46 | "format": format, 47 | "content": content, 48 | "overwrite": overwrite, 49 | } 50 | 51 | if language: 52 | import_data["language"] = language 53 | 54 | return await make_api_request("POST", "/api/2.0/workspace/import", data=import_data) 55 | 56 | 57 | async def export_notebook( 58 | path: str, 59 | format: str = "SOURCE", 60 | ) -> Dict[str, Any]: 61 | """ 62 | Export a notebook from the workspace. 63 | 64 | Args: 65 | path: The path of the notebook to export 66 | format: The format to export (SOURCE, HTML, JUPYTER, DBC) 67 | 68 | Returns: 69 | Response containing the notebook content 70 | 71 | Raises: 72 | DatabricksAPIError: If the API request fails 73 | """ 74 | logger.info(f"Exporting notebook from path: {path}") 75 | 76 | params = { 77 | "path": path, 78 | "format": format, 79 | } 80 | 81 | response = await make_api_request("GET", "/api/2.0/workspace/export", params=params) 82 | 83 | # Optionally decode base64 content 84 | if "content" in response and format in ["SOURCE", "JUPYTER"]: 85 | try: 86 | response["decoded_content"] = base64.b64decode(response["content"]).decode("utf-8") 87 | except Exception as e: 88 | logger.warning(f"Failed to decode notebook content: {str(e)}") 89 | 90 | return response 91 | 92 | 93 | async def list_notebooks(path: str) -> Dict[str, Any]: 94 | """ 95 | List notebooks in a workspace directory. 96 | 97 | Args: 98 | path: The path to list 99 | 100 | Returns: 101 | Response containing the directory listing 102 | 103 | Raises: 104 | DatabricksAPIError: If the API request fails 105 | """ 106 | logger.info(f"Listing notebooks in path: {path}") 107 | return await make_api_request("GET", "/api/2.0/workspace/list", params={"path": path}) 108 | 109 | 110 | async def delete_notebook(path: str, recursive: bool = False) -> Dict[str, Any]: 111 | """ 112 | Delete a notebook or directory. 113 | 114 | Args: 115 | path: The path to delete 116 | recursive: Whether to recursively delete directories 117 | 118 | Returns: 119 | Empty response on success 120 | 121 | Raises: 122 | DatabricksAPIError: If the API request fails 123 | """ 124 | logger.info(f"Deleting path: {path}") 125 | return await make_api_request( 126 | "POST", 127 | "/api/2.0/workspace/delete", 128 | data={"path": path, "recursive": recursive} 129 | ) 130 | 131 | 132 | async def create_directory(path: str) -> Dict[str, Any]: 133 | """ 134 | Create a directory in the workspace. 
135 | 136 | Args: 137 | path: The path to create 138 | 139 | Returns: 140 | Empty response on success 141 | 142 | Raises: 143 | DatabricksAPIError: If the API request fails 144 | """ 145 | logger.info(f"Creating directory: {path}") 146 | return await make_api_request("POST", "/api/2.0/workspace/mkdirs", data={"path": path}) 147 | 148 | 149 | async def export_workspace_file( 150 | path: str, 151 | format: str = "SOURCE", 152 | ) -> Dict[str, Any]: 153 | """ 154 | Export any file from the workspace (not just notebooks). 155 | 156 | Args: 157 | path: The workspace path of the file to export (e.g., /Users/user@domain.com/file.json) 158 | format: The format to export (SOURCE, HTML, JUPYTER, DBC) 159 | 160 | Returns: 161 | Response containing the file content 162 | 163 | Raises: 164 | DatabricksAPIError: If the API request fails 165 | """ 166 | logger.info(f"Exporting workspace file from path: {path}") 167 | 168 | params = { 169 | "path": path, 170 | "format": format, 171 | } 172 | 173 | response = await make_api_request("GET", "/api/2.0/workspace/export", params=params) 174 | 175 | # Always try to decode base64 content for SOURCE format 176 | if "content" in response and format == "SOURCE": 177 | try: 178 | decoded_content = base64.b64decode(response["content"]).decode("utf-8") 179 | response["decoded_content"] = decoded_content 180 | response["content_type"] = "text" 181 | 182 | # Try to detect if it's JSON 183 | try: 184 | import json 185 | json.loads(decoded_content) # Validate JSON 186 | response["content_type"] = "json" 187 | except json.JSONDecodeError: 188 | pass # Keep as text 189 | 190 | except UnicodeDecodeError as e: 191 | logger.warning(f"Failed to decode file content as UTF-8: {str(e)}") 192 | # Try different encodings 193 | try: 194 | decoded_bytes = base64.b64decode(response["content"]) 195 | # Return as text with error replacement 196 | response["decoded_content"] = decoded_bytes.decode("utf-8", errors="replace") 197 | response["content_type"] = "text" 198 | response["encoding_warning"] = "Some characters may not display correctly" 199 | except Exception as e2: 200 | logger.warning(f"Failed to decode content with any encoding: {str(e2)}") 201 | response["content_type"] = "binary" 202 | response["note"] = "Content could not be decoded as text" 203 | 204 | return response 205 | 206 | 207 | async def get_workspace_file_info(path: str) -> Dict[str, Any]: 208 | """ 209 | Get information about a workspace file without downloading content. 210 | 211 | Args: 212 | path: The workspace path to check 213 | 214 | Returns: 215 | Response containing file information 216 | 217 | Raises: 218 | DatabricksAPIError: If the API request fails 219 | """ 220 | logger.info(f"Getting workspace file info for path: {path}") 221 | 222 | # Use the workspace list API to get file metadata 223 | # Split the path to get directory and filename 224 | import os 225 | directory = os.path.dirname(path) 226 | filename = os.path.basename(path) 227 | 228 | if not directory: 229 | directory = "/" 230 | 231 | # List the directory to find the file 232 | response = await make_api_request("GET", "/api/2.0/workspace/list", params={"path": directory}) 233 | 234 | # Find the specific file in the listing 235 | if "objects" in response: 236 | for obj in response["objects"]: 237 | if obj.get("path") == path: 238 | return obj 239 | 240 | raise DatabricksAPIError(f"File not found: {path}") 241 | 242 | 243 | def is_base64(content: str) -> bool: 244 | """ 245 | Check if a string is already base64 encoded. 
246 | 247 | Args: 248 | content: The string to check 249 | 250 | Returns: 251 | True if the string is base64 encoded, False otherwise 252 | """ 253 | try: 254 | return base64.b64encode(base64.b64decode(content)) == content.encode('utf-8') 255 | except Exception: 256 | return False -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Databricks MCP Server 2 | 3 | A production-ready **Model Context Protocol (MCP)** server that exposes Databricks REST capabilities to MCP-compatible agents and tooling. Version **0.4.4** introduces structured responses, resource caching, retry-aware networking, and end-to-end resilience improvements. 4 | 5 | --- 6 | 7 | ## Table of Contents 8 | 1. [Key Capabilities](#key-capabilities) 9 | 2. [Architecture Highlights](#architecture-highlights) 10 | 3. [Installation](#installation) 11 | 4. [Configuration](#configuration) 12 | 5. [Running the Server](#running-the-server) 13 | 6. [Integrating with MCP Clients](#integrating-with-mcp-clients) 14 | 7. [Working with Tool Responses](#working-with-tool-responses) 15 | 8. [Available Tools](#available-tools) 16 | 9. [Development Workflow](#development-workflow) 17 | 10. [Testing](#testing) 18 | 11. [Publishing Builds](#publishing-builds) 19 | 12. [Support & Contact](#support--contact) 20 | 13. [License](#license) 21 | 22 | --- 23 | 24 | ## Key Capabilities 25 | - **Structured MCP Responses** - Each tool returns a `CallToolResult` with a human-readable summary in `content` and machine-readable payloads in `structuredContent` that conform to the tool’s `outputSchema`. 26 | - **Resource Caching** - Large notebook/workspace exports are cached once and returned as `resource_link` content blocks with URIs such as `resource://databricks/exports/{id}` (also reflected in metadata for convenience). 27 | - **Progress & Metrics** - Long-running actions stream MCP progress notifications and track per-tool success/error/timeout/cancel metrics. 28 | - **Resilient Networking** - Shared HTTP client injects request IDs, enforces timeouts, and retries retryable Databricks responses (408/429/5xx) with exponential backoff. 29 | - **Async Runtime** - Built on `mcp.server.FastMCP` with centralized JSON logging and concurrency guards for predictable stdio behaviour. 30 | 31 | ## Architecture Highlights 32 | - `databricks_mcp/server/databricks_mcp_server.py` - FastMCP server with tool registration, progress handling, metrics, and resource caching. 33 | - `databricks_mcp/core/utils.py` - HTTP utilities with correlation IDs, retries, and error mapping to `DatabricksAPIError`. 34 | - `databricks_mcp/core/logging_utils.py` - JSON logging configuration for stderr/file outputs. 35 | - `databricks_mcp/core/models.py` - Pydantic models (e.g., `ClusterConfig`) used by tool schemas. 36 | - Tests under `tests/` mock Databricks APIs to validate orchestration, structured responses, and schema metadata without shell scripts. 37 | 38 | For an in-depth tour of data flow and design decisions, see [ARCHITECTURE.md](ARCHITECTURE.md). 39 | 40 | ## Installation 41 | 42 | ### Prerequisites 43 | - Python 3.10+ 44 | - [`uv`](https://github.com/astral-sh/uv) for dependency management and publishing 45 | 46 | ### Quick Install (recommended) 47 | Register the server with Cursor using the deeplink below - it resolves to `uvx databricks-mcp-server@latest` and picks up future updates automatically. 
48 | 49 | ```text 50 | cursor://anysphere.cursor-deeplink/mcp/install?name=databricks-mcp&config=eyJjb21tYW5kIjoidXZ4IiwiYXJncyI6WyJkYXRhYnJpY2tzLW1jcC1zZXJ2ZXIiXSwiZW52Ijp7IkRBVEFCUklDS1NfSE9TVCI6IiR7REFUQUJSSUNLU19IT1NUfSIsIkRBVEFCUklDS1NfVE9LRU4iOiIke0RBVEFCUklDS1NfVE9LRU59IiwiREFUQUJSSUNLU19XQVJFSE9VU0VfSUQiOiIke0RBVEFCUklDS1NfV0FSRUhPVVNFX0lEfSJ9fQ== 51 | ``` 52 | 53 | ### Manual Installation 54 | ```bash 55 | # Clone and enter the repository 56 | git clone https://github.com/markov-kernel/databricks-mcp.git 57 | cd databricks-mcp 58 | 59 | # Create an isolated environment (optional but recommended) 60 | uv venv 61 | source .venv/bin/activate # Linux/Mac 62 | # .\.venv\Scripts\activate # Windows PowerShell 63 | 64 | # Install package and development dependencies 65 | uv pip install -e . 66 | uv pip install -e ".[dev]" 67 | ``` 68 | 69 | ## Configuration 70 | Set the following environment variables (or populate `.env` from `.env.example`). 71 | 72 | ```bash 73 | export DATABRICKS_HOST="https://your-workspace.databricks.com" 74 | export DATABRICKS_TOKEN="dapiXXXXXXXXXXXXXXXX" 75 | export DATABRICKS_WAREHOUSE_ID="sql_warehouse_12345" # optional default 76 | export TOOL_TIMEOUT_SECONDS=300 77 | export MAX_CONCURRENT_REQUESTS=8 78 | export HTTP_TIMEOUT_SECONDS=60 79 | export API_MAX_RETRIES=3 80 | export API_RETRY_BACKOFF_SECONDS=0.5 81 | ``` 82 | 83 | ## Running the Server 84 | ```bash 85 | uvx databricks-mcp-server@latest 86 | ``` 87 | > Tip: append `--refresh` (e.g., `uvx databricks-mcp-server@latest --refresh`) to force `uv` to resolve the latest PyPI release after publishing. Logs are emitted as JSON lines to stderr and persisted to `databricks_mcp.log` in the working directory. 88 | 89 | To adjust logging: 90 | ```bash 91 | uvx databricks-mcp-server@latest -- --log-level DEBUG 92 | ``` 93 | 94 | ## Integrating with MCP Clients 95 | 96 | ### Codex CLI (STDIO) 97 | Register the server and inject credentials via the CLI: 98 | 99 | ```bash 100 | codex mcp add databricks --env DATABRICKS_HOST="https://your-workspace.databricks.com" --env DATABRICKS_TOKEN="dapi_XXXXXXXXXXXXXXXX" --env DATABRICKS_WAREHOUSE_ID="sql_warehouse_12345" -- uvx databricks-mcp-server@latest 101 | # Add --refresh immediately after a publish to invalidate the uv cache 102 | ``` 103 | 104 | Or edit `~/.codex/config.toml`: 105 | 106 | ```toml 107 | [mcp_servers.databricks] 108 | command = "uvx" 109 | args = ["databricks-mcp-server@latest"] 110 | env = { 111 | DATABRICKS_HOST = "https://your-workspace.databricks.com", 112 | DATABRICKS_TOKEN = "dapi_XXXXXXXXXXXXXXXX", 113 | DATABRICKS_WAREHOUSE_ID = "sql_warehouse_12345" 114 | } 115 | startup_timeout_sec = 15 116 | tool_timeout_sec = 300 117 | ``` 118 | 119 | > Planning an HTTP deployment? Codex also supports `url = "https://…"` plus 120 | > `bearer_token_env_var = "DATABRICKS_TOKEN"` or `codex mcp login` (with 121 | > `experimental_use_rmcp_client = true`). 122 | 123 | ### Cursor 124 | ```jsonc 125 | { 126 | "mcpServers": { 127 | "databricks-mcp-local": { 128 | "command": "uvx", 129 | "args": ["databricks-mcp-server@latest"], 130 | "env": { 131 | "DATABRICKS_HOST": "https://your-workspace.databricks.com", 132 | "DATABRICKS_TOKEN": "dapiXXXXXXXXXXXXXXXX", 133 | "DATABRICKS_WAREHOUSE_ID": "sql_warehouse_12345", 134 | "RUNNING_VIA_CURSOR_MCP": "true" 135 | } 136 | } 137 | } 138 | } 139 | ``` 140 | Restart Cursor after saving and invoke tools as `databricks-mcp-local:`. 
141 | 142 | ### Claude CLI 143 | ```bash 144 | claude mcp add databricks-mcp-local -s user -e DATABRICKS_HOST="https://your-workspace.databricks.com" -e DATABRICKS_TOKEN="dapiXXXXXXXXXXXXXXXX" -e DATABRICKS_WAREHOUSE_ID="sql_warehouse_12345" -- uvx databricks-mcp-server@latest 145 | ``` 146 | 147 | ## Working with Tool Responses 148 | `structuredContent` carries machine-readable payloads. Large artifacts are returned as `resource_link` content blocks using URIs like `resource://databricks/exports/{id}` and can be fetched via the MCP resources API. 149 | 150 | ```python 151 | result = await session.call_tool("list_clusters", {}) 152 | summary = next((block.text for block in result.content if getattr(block, "type", "") == "text"), "") 153 | clusters = (result.structuredContent or {}).get("clusters", []) 154 | resource_links = [block for block in result.content if isinstance(block, dict) and block.get("type") == "resource_link"] 155 | ``` 156 | 157 | Progress notifications follow MCP’s progress token mechanism; Codex surfaces these messages in the UI while a tool runs. 158 | 159 | ### Example - SQL Query 160 | ```python 161 | result = await session.call_tool("execute_sql", {"statement": "SELECT * FROM samples LIMIT 10"}) 162 | print(result.content[0].text) 163 | rows = (result.structuredContent or {}).get("result", []) 164 | ``` 165 | 166 | ### Example - Workspace File Export 167 | ```python 168 | result = await session.call_tool("get_workspace_file_content", { 169 | "path": "/Users/user@domain.com/report.ipynb", 170 | "format": "SOURCE" 171 | }) 172 | resource_link = next((block for block in result.content if isinstance(block, dict) and block.get("type") == "resource_link"), None) 173 | if resource_link: 174 | contents = await session.read_resource(resource_link["uri"]) 175 | ``` 176 | 177 | ## Available Tools 178 | | Category | Tool | Description | 179 | | --- | --- | --- | 180 | | Clusters | `list_clusters`, `create_cluster`, `terminate_cluster`, `get_cluster`, `start_cluster`, `resize_cluster`, `restart_cluster` | Manage interactive clusters | 181 | | Jobs | `list_jobs`, `create_job`, `delete_job`, `run_job`, `run_notebook`, `sync_repo_and_run_notebook`, `get_run_status`, `list_job_runs`, `cancel_run` | Manage scheduled and ad-hoc jobs | 182 | | Workspace | `list_notebooks`, `export_notebook`, `import_notebook`, `delete_workspace_object`, `get_workspace_file_content`, `get_workspace_file_info` | Inspect and manage workspace assets | 183 | | DBFS | `list_files`, `dbfs_put`, `dbfs_delete` | Explore DBFS and manage files | 184 | | SQL | `execute_sql` | Submit SQL statements with optional `warehouse_id`, `catalog`, `schema_name` | 185 | | Libraries | `install_library`, `uninstall_library`, `list_cluster_libraries` | Manage cluster libraries | 186 | | Repos | `create_repo`, `update_repo`, `list_repos`, `pull_repo` | Manage Databricks repos | 187 | | Unity Catalog | `list_catalogs`, `create_catalog`, `list_schemas`, `create_schema`, `list_tables`, `create_table`, `get_table_lineage` | Unity Catalog operations | 188 | 189 | ## Development Workflow 190 | ```bash 191 | uv run black databricks_mcp tests 192 | uv run pylint databricks_mcp tests 193 | uv run pytest 194 | uv build 195 | uv publish --token "$PYPI_TOKEN" 196 | ``` 197 | 198 | ## Testing 199 | ```bash 200 | uv run pytest 201 | ``` 202 | Pytest suites mock Databricks APIs, providing deterministic structured outputs and transcript tests. 
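The same pattern can be reproduced in a minimal offline test. The sketch below mocks at the `api/` boundary; the test name and mocked return value are illustrative, not part of the shipped suite.

```python
import pytest
from unittest.mock import AsyncMock

from databricks_mcp.api import clusters


@pytest.mark.asyncio
async def test_list_clusters_offline(monkeypatch):
    # Patch at the api/ boundary so the test never performs real HTTP calls.
    monkeypatch.setattr(clusters, "list_clusters", AsyncMock(return_value={"clusters": []}))

    response = await clusters.list_clusters()

    assert response == {"clusters": []}
```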
203 | 204 | ## Publishing Builds 205 | Ensure `PYPI_TOKEN` is available (via `.env` or environment) before publishing: 206 | ```bash 207 | uv build 208 | uv publish --token "$PYPI_TOKEN" 209 | ``` 210 | 211 | ## Support & Contact 212 | - Maintainer: Olivier Debeuf De Rijcker (olivier@markov.bot) 213 | - Issues: [GitHub Issues](https://github.com/markov-kernel/databricks-mcp/issues) 214 | - Architecture deep dive: [ARCHITECTURE.md](ARCHITECTURE.md) 215 | 216 | ## License 217 | 218 | Released under the MIT License. See [LICENSE](LICENSE). 219 | -------------------------------------------------------------------------------- /ARCHITECTURE.md: -------------------------------------------------------------------------------- 1 | # Databricks MCP Server — Architecture and Deep Dive 2 | 3 | This document provides a comprehensive, highly detailed, end‑to‑end overview of the Databricks MCP Server contained in this repository. It covers the project structure, runtime architecture, MCP tools and their parameters, data flow and error handling, configuration, testing, and known caveats. 4 | 5 | > Package: `databricks-mcp-server` (v0.4.4 in packaging metadata) 6 | 7 | 8 | ## 1) Repository at a Glance 9 | 10 | ``` 11 | . 12 | ├─ AGENTS.md 13 | ├─ ARCHITECTURE.md 14 | ├─ README.md 15 | ├─ databricks_mcp/ 16 | │ ├─ __init__.py 17 | │ ├─ __main__.py 18 | │ ├─ main.py 19 | │ ├─ api/ 20 | │ │ ├─ clusters.py 21 | │ │ ├─ dbfs.py 22 | │ │ ├─ jobs.py 23 | │ │ ├─ libraries.py 24 | │ │ ├─ notebooks.py 25 | │ │ ├─ repos.py 26 | │ │ └─ unity_catalog.py 27 | │ ├─ cli/ 28 | │ │ └─ commands.py 29 | │ ├─ core/ 30 | │ │ ├─ auth.py 31 | │ │ ├─ config.py 32 | │ │ ├─ logging_utils.py 33 | │ │ ├─ models.py 34 | │ │ └─ utils.py 35 | │ └─ server/ 36 | │ ├─ __init__.py 37 | │ ├─ __main__.py 38 | │ ├─ app.py 39 | │ ├─ databricks_mcp_server.py 40 | │ └─ tool_helpers.py 41 | ├─ tests/ 42 | │ ├─ test_additional_features.py 43 | │ ├─ test_clusters.py 44 | │ ├─ test_server_structured.py 45 | │ ├─ test_tool_metadata.py 46 | │ └─ test_transcript.py 47 | ├─ .env.example 48 | ├─ pyproject.toml 49 | └─ uv.lock 50 | ``` 51 | 52 | 53 | ## 2) Build, Packaging, and Entry Points 54 | 55 | - Packaging is configured via Hatch (`hatchling`). 56 | - Python ≥ 3.10. 57 | - Key dependencies: `mcp[cli]` (1.2.0+), `httpx`, `databricks-sdk`. Dev extras add `pytest`, `pytest-asyncio`, `fastapi`, `anyio` for local HTTP testing and async test support. 58 | - Console scripts are declared in packaging metadata (`pyproject.toml`): 59 | - `databricks-mcp-server` → `databricks_mcp.server.databricks_mcp_server:main` 60 | - `databricks-mcp` → `databricks_mcp.cli.commands:main` 61 | 62 | Module entrypoints for `python -m` execution: 63 | - `databricks_mcp/__main__.py` delegates to `databricks_mcp.main:main`. 64 | - `databricks_mcp/server/__main__.py` invokes `server.databricks_mcp_server:main()` directly. 65 | 66 | 67 | ## 3) Configuration & Environment 68 | 69 | File: `databricks_mcp/core/config.py` 70 | - `.env` loading is silent (no stdout noise) unless Cursor provides env via `RUNNING_VIA_CURSOR_MCP`. 71 | - Pydantic `Settings` surface: 72 | - Core Databricks auth: `DATABRICKS_HOST`, `DATABRICKS_TOKEN`, optional `DATABRICKS_WAREHOUSE_ID`. 73 | - Logging/runtime: `LOG_LEVEL`, plus `TOOL_TIMEOUT_SECONDS`, `MAX_CONCURRENT_REQUESTS` controlling server execution safeguards. 74 | - HTTP behaviour: `HTTP_TIMEOUT_SECONDS`, `API_MAX_RETRIES`, `API_RETRY_BACKOFF_SECONDS` used by `core.utils` for exponential backoff. 
75 | - Helpers: 76 | - `get_api_headers()` returns Authorization and JSON headers. 77 | - `get_databricks_api_url(endpoint)` joins host + endpoint, trimming extra slashes. 78 | 79 | Example `.env` (see `.env.example`): 80 | ``` 81 | DATABRICKS_HOST=https://your-workspace.databricks.com 82 | DATABRICKS_TOKEN=dapi_your_token_here 83 | DATABRICKS_WAREHOUSE_ID=sql_warehouse_12345 84 | SERVER_HOST=0.0.0.0 85 | SERVER_PORT=8000 86 | TOOL_TIMEOUT_SECONDS=300 87 | MAX_CONCURRENT_REQUESTS=8 88 | HTTP_TIMEOUT_SECONDS=60 89 | API_MAX_RETRIES=3 90 | API_RETRY_BACKOFF_SECONDS=0.5 91 | ``` 92 | 93 | Note: Package and runtime versions are unified via `settings.VERSION` (currently `0.4.4`). 94 | 95 | 96 | ## 4) Core Utilities and Error Handling 97 | 98 | File: `databricks_mcp/core/utils.py` 99 | - `DatabricksAPIError` captures message, status code, and raw response. 100 | - `request_context_id` (`ContextVar[str | None]`) propagates per-request correlation IDs from the MCP layer into HTTP headers (`X-Databricks-MCP-Request-ID`). 101 | - `make_api_request(...)` now includes: 102 | - `httpx.AsyncClient` with configurable timeout from settings. 103 | - Exponential backoff (`API_MAX_RETRIES`, `API_RETRY_BACKOFF_SECONDS`) for retryable status codes (408/429/5xx) and transport hiccups. 104 | - Redaction of payload logs and structured logging on failures before raising `DatabricksAPIError`. 105 | - `format_response(...)` remains available for legacy helpers but primary paths return raw dicts that the server wraps into `CallToolResult` metadata. 106 | 107 | 108 | ## 5) Domain Models 109 | 110 | File: `databricks_mcp/core/models.py` 111 | - Lightweight Pydantic models for common structures: `ClusterConfig`, `JobTask`, `Job`, `Run`, `WorkspaceObject`, `DbfsItem`, `Library`, `Repo`, `Catalog`, `Schema`, `Table`. 112 | - `ClusterConfig` backs the structured `create_cluster` tool signature so MCP schemas expose the Databricks create API fields. 113 | 114 | 115 | ## 6) Server Architecture (MCP) 116 | 117 | File: `databricks_mcp/server/databricks_mcp_server.py:1` 118 | - Implements an MCP server using `mcp.server.FastMCP`. 119 | - On construction, logs environment, registers all tools, and serves over stdio using `FastMCP.run()` in `main()`. 120 | - Logging targets `databricks_mcp.log` with level from `LOG_LEVEL`. 121 | 122 | ### 6.1 Parameter Handling and Client Compatibility 123 | - Tools expose explicit, flat parameters via FastMCP's schema generation, so clients see the canonical JSON shape (e.g., `{ "cluster_id": "..." }`). 124 | - Legacy `{ "params": { ... } }` envelopes were removed in favour of consistent argument validation. 125 | 126 | ### 6.2 Content Shape and Structured Results 127 | - Tool handlers return `CallToolResult` objects with a short human summary (`TextContent`) and the full Databricks payload in `structuredContent` (validated by each tool's `outputSchema`). 128 | - Each response annotates `_meta['_request_id']` for correlation and attaches cached resource references for large exports. 129 | - Tests such as `tests/test_server_structured.py` assert the presence of structured JSON and resource metadata. 130 | 131 | ### 6.3 Progress & Metrics 132 | - `_report_progress` invokes `ctx.report_progress(...)` so clients receive start/finish updates (with midpoints for multi-phase tools like repo sync + notebook run). 133 | - A `Counter` tracks success/error/timeout/cancelled tallies per tool, retrievable via `get_metrics_snapshot()`. 
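As a concrete illustration of the response shape described in 6.2, a tool handler assembles its result roughly as sketched below. This is a simplified sketch rather than the server's actual helper: field names follow `mcp.types`, it assumes an SDK release that exposes `structuredContent`, and it omits the `_meta['_request_id']` stamping and cached resource links the real server adds.

```python
from mcp.types import CallToolResult, TextContent


def build_list_clusters_result(clusters: list[dict]) -> CallToolResult:
    # Short human-readable summary for the client/LLM.
    summary = f"Found {len(clusters)} cluster(s)"
    return CallToolResult(
        content=[TextContent(type="text", text=summary)],
        # Full machine-readable payload, validated against the tool's outputSchema.
        structuredContent={"clusters": clusters},
    )
```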
134 | 135 | ### 6.4 Startup 136 | - `main()` reconfigures stdout line buffering (useful for stdio-based protocols) and calls `server.run()`. 137 | - The server does not expose HTTP routes by default; HTTP is provided by a separate FastAPI stub for testing. 138 | 139 | 140 | ## 7) HTTP Stub (FastAPI) for Tests 141 | 142 | File: `databricks_mcp/server/app.py:1` 143 | - Minimal FastAPI app that routes a subset of cluster and workspace operations directly to the async API layer. 144 | - Intended only for test compatibility and not used by the MCP runtime. 145 | 146 | 147 | ## 8) CLI 148 | 149 | File: `databricks_mcp/cli/commands.py` 150 | - Subcommands: 151 | - `start` — runs the MCP server (stdio entrypoint). 152 | - `list-tools` — prints tool name + description via `FastMCP.list_tools()`. 153 | - `version` — instantiates the server to display `server.version` and warn about missing env vars. 154 | - `sync-run` — wraps `sync_repo_and_run_notebook`, printing the summary text block and pretty-printing `structuredContent` on success/errors. 155 | 156 | Examples: 157 | ``` 158 | # Start server (stdio MCP host must spawn this) 159 | databricks-mcp start 160 | 161 | # List tools 162 | databricks-mcp list-tools 163 | 164 | # Version 165 | databricks-mcp version 166 | 167 | # Pull repo and run notebook 168 | databricks-mcp sync-run --repo-id 42 --notebook-path /Shared/foo --cluster-id 1234-abc 169 | ``` 170 | 171 | 172 | ## 9) Databricks API Modules 173 | 174 | All modules delegate HTTP calls to `core.utils.make_api_request` and are fully async. 175 | 176 | ### 9.1 Clusters — `databricks_mcp/api/clusters.py:1` 177 | - `create_cluster(cluster_config)` → `POST /api/2.0/clusters/create` 178 | - `terminate_cluster(cluster_id)` → `POST /api/2.0/clusters/delete` 179 | - `list_clusters()` → `GET /api/2.0/clusters/list` 180 | - `get_cluster(cluster_id)` → `GET /api/2.0/clusters/get` 181 | - `start_cluster(cluster_id)` → `POST /api/2.0/clusters/start` 182 | - `resize_cluster(cluster_id, num_workers)` → `POST /api/2.0/clusters/resize` 183 | - `restart_cluster(cluster_id)` → `POST /api/2.0/clusters/restart` 184 | 185 | ### 9.2 Jobs — `databricks_mcp/api/jobs.py:1` 186 | - CRUD & execution: 187 | - `create_job(job_config)` → `POST /api/2.2/jobs/create` 188 | - `run_job(job_id, notebook_params=None)` → `POST /api/2.0/jobs/run-now` 189 | - `list_jobs()` → `GET /api/2.0/jobs/list` 190 | - `get_job(job_id)` → `GET /api/2.0/jobs/get` 191 | - `update_job(job_id, new_settings)` → `POST /api/2.0/jobs/update` 192 | - `delete_job(job_id)` → `POST /api/2.2/jobs/delete` 193 | - Runs & polling: 194 | - `submit_run(run_config)` → `POST /api/2.0/jobs/runs/submit` 195 | - `get_run(run_id)` → `GET /api/2.1/jobs/runs/get` 196 | - `list_runs(job_id=None, limit=20)` → `GET /api/2.1/jobs/runs/list` 197 | - `get_run_status(run_id)` → extracts concise `state` and `life_cycle` fields 198 | - `cancel_run(run_id)` → `POST /api/2.1/jobs/runs/cancel` 199 | - `get_run_output(run_id)` → `GET /api/2.0/jobs/runs/get-output` 200 | - Notebook one-off execution helper: 201 | - `run_notebook(notebook_path, existing_cluster_id=None, base_parameters=None, ...)` 202 | - Builds a transient run task and waits until termination (`await_until_state`) before fetching output. 203 | 204 | ### 9.3 Notebooks & Workspace — `databricks_mcp/api/notebooks.py:1` 205 | - `import_notebook(path, content, format='SOURCE', language=None, overwrite=False)` → `POST /api/2.0/workspace/import` 206 | - If `content` is not base64, it will be encoded. 
207 | - `export_notebook(path, format='SOURCE')` → `GET /api/2.0/workspace/export` 208 | - Decodes base64 when possible, attaching `decoded_content` and `content_type`. 209 | - `list_notebooks(path)` → `GET /api/2.0/workspace/list` 210 | - `delete_notebook(path, recursive=False)` → `POST /api/2.0/workspace/delete` 211 | - `create_directory(path)` → `POST /api/2.0/workspace/mkdirs` 212 | - `export_workspace_file(path, format='SOURCE')` → general-purpose export for non-notebook files 213 | - `get_workspace_file_info(path)` → directory listing lookup to return metadata for a specific file 214 | 215 | ### 9.4 DBFS — `databricks_mcp/api/dbfs.py:1` 216 | - Small uploads: `put_file(dbfs_path, file_content_bytes, overwrite=True)` → `POST /api/2.0/dbfs/put` 217 | - Large uploads: `upload_large_file(dbfs_path, local_file_path, overwrite=True, buffer_size=1MB)` 218 | - Orchestrates `create` → repeated `add-block` → `close` with base64 chunks. 219 | - Attempts a best-effort `close` on error for cleanup. 220 | - Reads: `get_file(dbfs_path, offset=0, length=1MB)` → `GET /api/2.0/dbfs/read` (decodes base64 into `decoded_data`). 221 | - Listings & metadata: `list_files(dbfs_path)`, `get_status(dbfs_path)`, `create_directory(dbfs_path)`, `delete_file(dbfs_path, recursive=False)`. 222 | 223 | ### 9.5 SQL Warehouses — `databricks_mcp/api/sql.py:1` 224 | - `execute_statement(statement, warehouse_id=None, catalog=None, schema=None, parameters=None, row_limit=10000, byte_limit=100MB)` → `POST /api/2.0/sql/statements` 225 | - Falls back to `settings.DATABRICKS_WAREHOUSE_ID` if `warehouse_id` not provided. 226 | - Uses `format=JSON_ARRAY`, `disposition=INLINE`, `wait_timeout=10s`. 227 | - `execute_and_wait(...)` → kicks off `execute_statement`, then polls `get_statement_status(statement_id)` until `SUCCEEDED` or failure/timeout. 228 | - `get_statement_status(statement_id)`, `cancel_statement(statement_id)`. 
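As a rough illustration of the statement body those defaults imply (a sketch only — `build_statement_payload` is not part of the module, and the real code reads the fallback warehouse from `settings` rather than the raw environment):

```python
import os
from typing import Any, Dict, Optional

def build_statement_payload(
    statement: str,
    warehouse_id: Optional[str] = None,
    catalog: Optional[str] = None,
    schema: Optional[str] = None,
) -> Dict[str, Any]:
    payload: Dict[str, Any] = {
        "statement": statement,
        # Fall back to the configured default warehouse when none is passed
        "warehouse_id": warehouse_id or os.environ.get("DATABRICKS_WAREHOUSE_ID", ""),
        "format": "JSON_ARRAY",
        "disposition": "INLINE",
        "wait_timeout": "10s",
    }
    # Optional scoping fields are only included when provided
    if catalog:
        payload["catalog"] = catalog
    if schema:
        payload["schema"] = schema
    return payload

print(build_statement_payload("SELECT 1", warehouse_id="sql_warehouse_12345"))
```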
229 | 230 | ### 9.6 Cluster Libraries — `databricks_mcp/api/libraries.py:1` 231 | - `install_library(cluster_id, libraries)` → `POST /api/2.0/libraries/install` 232 | - `uninstall_library(cluster_id, libraries)` → `POST /api/2.0/libraries/uninstall` 233 | - `list_cluster_libraries(cluster_id)` → `GET /api/2.0/libraries/cluster-status` 234 | 235 | ### 9.7 Repos — `databricks_mcp/api/repos.py:1` 236 | - `create_repo(url, provider, branch=None, path=None)` → `POST /api/2.0/repos` 237 | - `update_repo(repo_id, branch=None, tag=None)` → `PATCH /api/2.0/repos/{id}` 238 | - `list_repos(path_prefix=None)` → `GET /api/2.0/repos` 239 | - `pull_repo(repo_id)` → `POST /api/2.0/repos/{id}/pull` 240 | 241 | ### 9.8 Unity Catalog — `databricks_mcp/api/unity_catalog.py:1` 242 | - Catalogs: `list_catalogs()`, `create_catalog(name, comment=None)` 243 | - Schemas: `list_schemas(catalog_name)`, `create_schema(catalog_name, name, comment=None)` 244 | - Tables: `list_tables(catalog_name, schema_name)`, `create_table(warehouse_id, statement)` (via SQL API) 245 | - Lineage: `get_table_lineage(full_name)` → `GET /api/2.1/unity-catalog/lineage-tracking/table-lineage/{full_name}` 246 | 247 | 248 | ## 10) MCP Tool Inventory (Names, Purpose, Parameters) 249 | 250 | All registered in `databricks_mcp/server/databricks_mcp_server.py`: 251 | 252 | - Clusters: 253 | - `list_clusters` 254 | - `create_cluster` — params mirror Databricks create API (name, spark_version, node_type_id, …) 255 | - `terminate_cluster` — `cluster_id` 256 | - `get_cluster` — `cluster_id` 257 | - `start_cluster` — `cluster_id` 258 | - Jobs: 259 | - `list_jobs` 260 | - `create_job` — `{ name, tasks, … }` 261 | - `delete_job` — `job_id` 262 | - `run_job` — `job_id`, optional `notebook_params` 263 | - `run_notebook` — `notebook_path`, optional `existing_cluster_id`, `base_parameters` 264 | - `sync_repo_and_run_notebook` — `repo_id`, `notebook_path`, optional cluster/parameters 265 | - `get_run_status` — `run_id` 266 | - `list_job_runs` — `job_id` 267 | - `cancel_run` — `run_id` 268 | - Workspace/Notebooks: 269 | - `list_notebooks` — `path` 270 | - `export_notebook` — `path`, optional `format` 271 | - `import_notebook` — `path`, `content` (base64 or text), optional `format` 272 | - `delete_workspace_object` — `path`, optional `recursive` 273 | - `get_workspace_file_content` — `path`, optional `format` 274 | - `get_workspace_file_info` — `path` 275 | - DBFS: 276 | - `list_files` — `path` 277 | - `dbfs_put` — `path`, `content` (UTF-8 string) 278 | - `dbfs_delete` — `path`, optional `recursive` 279 | - SQL: 280 | - `execute_sql` — `statement`, optional `warehouse_id`, `catalog`, `schema_name` 281 | - Cluster Libraries: 282 | - `install_library` — `cluster_id`, `libraries` 283 | - `uninstall_library` — `cluster_id`, `libraries` 284 | - `list_cluster_libraries` — `cluster_id` 285 | - Repos: 286 | - `create_repo` — `url`, `provider`, optional `branch`, `path` 287 | - `update_repo` — `repo_id`, `branch` or `tag` 288 | - `list_repos` — optional `path_prefix` 289 | - `pull_repo` — `repo_id` 290 | - Unity Catalog: 291 | - `list_catalogs`, `create_catalog` 292 | - `list_schemas`, `create_schema` 293 | - `list_tables`, `create_table` 294 | - `get_table_lineage` 295 | 296 | 297 | ## 11) Data Flow (Typical Lifecycles) 298 | 299 | ### 11.1 MCP Tool Invocation → Databricks 300 | 1. MCP clients invoke tool `X` with flat JSON arguments generated from FastMCP's auto-synthesised `inputSchema`. 301 | 2. 
Server validates/coerces arguments via Pydantic-typed metadata and dispatches to the async API module.
302 | 3. API utilities issue REST calls with exponential retry, correlation headers, and bounded concurrency.
303 | 4. Tool handler wraps the API payload in a `CallToolResult`, emitting a concise text summary and attaching the raw JSON to `structuredContent` (with `_meta['_request_id']`).
304 | 5. For large artifacts (notebook exports, workspace files), the handler caches the payload and emits `resource_link` content blocks using URIs such as `resource://databricks/exports/{id}`, allowing clients to fetch the data through the MCP resources API.
305 | 
306 | ### 11.2 SQL Execution
307 | 1. `execute_sql` builds a statement payload; uses explicit `warehouse_id` or `settings.DATABRICKS_WAREHOUSE_ID` and forwards `catalog` / `schema_name` when provided.
308 | 2. Returns inline results (JSON array format). For long-running queries, use `execute_and_wait`.
309 | 
310 | ### 11.3 Notebook One-Off Run
311 | 1. `run_notebook` constructs a submit run with `notebook_task`.
312 | 2. Waits until lifecycle reaches target state; fetches output via `get_run_output`.
313 | 
314 | ### 11.4 DBFS Large Upload
315 | 1. `upload_large_file` issues `create` to get a handle.
316 | 2. Splits local file into 1MB chunks, base64 encodes, and `add-block` for each.
317 | 3. `close` finalizes the upload; attempts cleanup on failure.
318 | 
319 | 
320 | ## 12) Error Handling, Progress, and Metrics
321 | 
322 | - HTTP layer retries transient failures and raises `DatabricksAPIError` with structured response payloads when available.
323 | - `_run_tool` wraps calls in `asyncio.wait_for`, tracks success/error/timeout/cancelled counters, and injects `_meta['_request_id']` for every response.
324 | - On failure, `error_result(...)` places details in `structuredContent`; clients can inspect `isError` and use `_meta` solely for request metadata.
325 | - Progress updates are reported through `Context.report_progress`, emitting start/mid/end notifications for long-running actions (repo sync + notebook run, SQL execution, etc.).
326 | - Logging is centralized via `core.logging_utils.configure_logging`, emitting JSON lines to stderr (and `databricks_mcp.log` when configured) with correlation IDs.
327 | 
328 | 
329 | ## 13) Security Considerations
330 | 
331 | - MCP server communicates over stdio to its client (no socket binding), so access control is pushed to the embedding tool.
332 | - The FastAPI stub includes a very basic API key mechanism (intended for local demos only) and should not be used in production (`databricks_mcp/core/auth.py`).
333 | - Secrets are read from environment; avoid committing `.env`.
334 | 
335 | 
336 | ## 14) Testing Overview
337 | 
338 | Configuration (`pyproject.toml`) enables async tests with concise output and short tracebacks. The current suite is fully self-contained and does not depend on the removed PowerShell harnesses.
339 | 
340 | Representative suites:
341 | - `tests/test_clusters.py` / `tests/test_additional_features.py` — patch API modules to validate tool orchestration logic.
342 | - `tests/test_server_structured.py` — exercises structured `CallToolResult` payloads, ensuring `structuredContent` and cached resource URIs behave correctly.
343 | - `tests/test_tool_metadata.py` — verifies FastMCP emits input/output schemas for registered tools.
344 | - `tests/test_transcript.py` — captures a deterministic request/response transcript for regression detection.
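In outline, the offline pattern these suites share looks like the sketch below (test name, stubbed payload, and assertions are illustrative, not copied from the actual suites):

```python
import pytest
from unittest.mock import AsyncMock

from databricks_mcp.api import clusters
from databricks_mcp.server.databricks_mcp_server import DatabricksMCPServer


@pytest.mark.asyncio
async def test_list_clusters_offline(monkeypatch):
    # Stub the async API layer so no real workspace is contacted
    monkeypatch.setattr(clusters, "list_clusters", AsyncMock(return_value={"clusters": []}))

    server = DatabricksMCPServer()
    result = await server.call_tool("list_clusters", {})

    assert not result.isError
    # Per section 6.2 the full Databricks payload should surface in structuredContent
    assert result.structuredContent is not None
```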
345 | 346 | All tests run offline by monkeypatching the Databricks API modules; real credentials are only required when manually invoking tools against live workspaces. 347 | 348 | 349 | ## 15) Known Caveats, Inconsistencies, and Suggested Fixes 350 | 351 | - **Resource cache lifecycle**: cached exports accumulate in-memory without eviction. Consider TTL-based purging or exposing a `clear_cache` tool for long-lived processes. 352 | - **Cancellation semantics**: incoming cancellation stops local awaiting (`asyncio.CancelledError`), but outstanding Databricks jobs/statements are not actively cancelled. A future iteration could call the corresponding Databricks cancel endpoints when feasible. 353 | - **Progress granularity**: current progress notifications cover major phases only. Additional instrumentation (e.g., chunk counts for large uploads) may enhance UX. 354 | - **FastAPI stub**: remains demo-only, unauthenticated aside from a simple API key helper. Production deployments should rely on the stdio MCP transport. 355 | 356 | 357 | ## 16) Usage Examples 358 | 359 | ### 16.1 Running via CLI 360 | ``` 361 | # Ensure env vars are set (see .env.example) 362 | export DATABRICKS_HOST=... 363 | export DATABRICKS_TOKEN=... 364 | export DATABRICKS_WAREHOUSE_ID=... 365 | 366 | # List tools 367 | uvx databricks-mcp list-tools 368 | 369 | # Start (typically the MCP client launches this via stdio) 370 | uvx databricks-mcp start 371 | ``` 372 | 373 | ### 16.2 Calling MCP tools (conceptual) 374 | Pseudocode using an MCP client: 375 | ```python 376 | # After session.initialize() 377 | result = await session.call_tool("list_clusters", {}) 378 | 379 | summary = next((block.text for block in result.content if getattr(block, "type", "") == "text"), "") 380 | data = result.structuredContent or {} 381 | 382 | print(summary) 383 | print(data.get("clusters", [])) 384 | 385 | resource_links = [block for block in result.content if isinstance(block, dict) and block.get("type") == "resource_link"] 386 | print(resource_links) 387 | ``` 388 | 389 | ### 16.3 DBFS upload (small) 390 | ```json 391 | { 392 | "tool": "dbfs_put", 393 | "params": { 394 | "path": "/FileStore/samples/foo.txt", 395 | "content": "Hello, Databricks!" 396 | } 397 | } 398 | ``` 399 | 400 | ### 16.4 Notebook export and resource retrieval 401 | ```python 402 | result = await session.call_tool("export_notebook", {"path": "/Repos/user/demo", "format": "SOURCE"}) 403 | resource_link = next((block for block in result.content if isinstance(block, dict) and block.get("type") == "resource_link"), None) 404 | if resource_link: 405 | contents = await session.read_resource(resource_link["uri"]) 406 | ``` 407 | 408 | 409 | ## 17) Appendix A — MCP vs FastAPI 410 | 411 | - MCP Server (primary): stdio transport; tools registered in `DatabricksMCPServer`. 412 | - FastAPI stub (secondary): small HTTP facade for clusters/notebooks used by tests. 413 | - Both reuse the same async API modules. 414 | 415 | 416 | ## 18) Appendix B — File/Module Cross-Reference 417 | 418 | - Entrypoints: `databricks_mcp/__main__.py`, `databricks_mcp/main.py`, and `databricks_mcp/server/__main__.py`. 419 | - MCP server implementation: `databricks_mcp/server/databricks_mcp_server.py`, helpers in `databricks_mcp/server/tool_helpers.py`. 420 | - HTTP stub: `databricks_mcp/server/app.py` and related auth helpers in `databricks_mcp/core/auth.py`. 
421 | - Core utilities: configuration (`databricks_mcp/core/config.py`), logging (`core/logging_utils.py`), HTTP utilities (`core/utils.py`), domain models (`core/models.py`). 422 | - CLI surface: `databricks_mcp/cli/commands.py`. 423 | - Databricks API adapters: modules under `databricks_mcp/api/` for clusters, jobs, notebooks, dbfs, sql, libraries, repos, and unity catalog. 424 | 425 | 426 | ## 19) Appendix C — Troubleshooting Checklist 427 | 428 | - “Missing Databricks credentials” warnings — set `DATABRICKS_HOST` and `DATABRICKS_TOKEN` before running CLI commands or MCP clients. 429 | - “warehouse_id must be provided…” — set `DATABRICKS_WAREHOUSE_ID` in the environment or pass `warehouse_id` to `execute_sql` explicitly. 430 | - Notebook export returns no resource URIs — ensure the target path exists and the calling principal can read the workspace object. 431 | - Cached resource URIs missing data — the `resources` cache is in-memory; long-lived processes may need manual cleanup or restarts if URIs expire. 432 | 433 | 434 | ## 20) Roadmap Ideas 435 | 436 | - Add eviction/TTL controls to the resource cache to prevent unbounded growth. 437 | - Invoke Databricks cancellation endpoints when MCP cancellation notifications arrive for long-running jobs/statements. 438 | - Emit richer progress telemetry (e.g., per-chunk updates for large uploads) and expose metrics via a lightweight diagnostic tool. 439 | - Consider removing or hardening the FastAPI stub to avoid confusion with the primary MCP transport. 440 | 441 | --- 442 | 443 | Maintainer: Olivier Debeuf De Rijcker 444 | -------------------------------------------------------------------------------- /databricks_mcp/server/databricks_mcp_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Databricks MCP Server implementation. 3 | 4 | Provides MCP tools that wrap Databricks REST APIs with structured results, 5 | resource links for large payloads, and standardized error handling. 
6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import asyncio 11 | import base64 12 | import json 13 | import logging 14 | import sys 15 | import uuid 16 | from collections import Counter 17 | from dataclasses import dataclass 18 | from typing import Any, Awaitable, Callable, Dict, List, Optional 19 | 20 | from mcp.server import FastMCP 21 | from mcp.server.fastmcp.server import Context 22 | from mcp.types import CallToolResult, TextContent 23 | 24 | from databricks_mcp.api import clusters, dbfs, jobs, libraries, notebooks, repos, sql, unity_catalog 25 | from databricks_mcp.core.config import settings 26 | from databricks_mcp.core.logging_utils import configure_logging 27 | from databricks_mcp.core.models import ClusterConfig, Job 28 | from databricks_mcp.core.utils import DatabricksAPIError, request_context_id 29 | from databricks_mcp.server.tool_helpers import error_result, success_result 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | @dataclass 35 | class ResourcePayload: 36 | data: bytes 37 | mime_type: Optional[str] 38 | description: Optional[str] 39 | is_text: bool = False 40 | 41 | 42 | class DatabricksMCPServer(FastMCP): 43 | """An MCP server for Databricks APIs with structured outputs.""" 44 | 45 | def __init__(self) -> None: 46 | super().__init__( 47 | name="databricks-mcp", 48 | instructions="Use this server to manage Databricks resources", 49 | ) 50 | self.version = settings.VERSION 51 | logger.info("Initializing Databricks MCP server") 52 | 53 | self._task_semaphore = asyncio.Semaphore(settings.MAX_CONCURRENT_REQUESTS) 54 | self._resource_cache: Dict[str, ResourcePayload] = {} 55 | self._metrics: Counter[str] = Counter() 56 | 57 | self._validate_environment() 58 | self._register_resources() 59 | self._register_tools() 60 | 61 | async def _report_progress(self, ctx: Context | None, progress: float, total: float = 100.0, message: str | None = None) -> None: 62 | if ctx is None: 63 | return 64 | try: 65 | await ctx.report_progress(progress, total, message=message) 66 | except Exception: # pragma: no cover - progress failures are non-fatal 67 | logger.debug("Failed to report progress %s for %s", progress, message, exc_info=True) 68 | 69 | def _validate_environment(self) -> None: 70 | missing = [] 71 | if not settings.DATABRICKS_HOST or settings.DATABRICKS_HOST == "https://example.databricks.net": 72 | missing.append("DATABRICKS_HOST") 73 | if settings.DATABRICKS_TOKEN == "dapi_token_placeholder": 74 | missing.append("DATABRICKS_TOKEN") 75 | 76 | if missing: 77 | hint = ", ".join(missing) 78 | logger.warning("Missing Databricks credentials: %s", hint) 79 | 80 | async def call_tool(self, name: str, arguments: Dict[str, Any]) -> CallToolResult: # type: ignore[override] 81 | """Expose structured call_tool for in-process callers (CLI/tests).""" 82 | context = self.get_context() 83 | result = await self._tool_manager.call_tool(name, arguments, context=context, convert_result=False) 84 | 85 | if isinstance(result, CallToolResult): 86 | return result 87 | 88 | if isinstance(result, tuple) and len(result) == 2: 89 | unstructured, structured = result 90 | return CallToolResult(content=list(unstructured), _meta={"data": structured}, isError=False) 91 | 92 | if isinstance(result, dict): 93 | return CallToolResult( 94 | content=[TextContent(type="text", text=json.dumps(result, indent=2))], 95 | _meta={"data": result}, 96 | isError=False, 97 | ) 98 | 99 | if hasattr(result, "__iter__"): 100 | return CallToolResult(content=list(result), _meta={"data": {}}, 
isError=False) 101 | 102 | return CallToolResult( 103 | content=[TextContent(type="text", text=f"Unexpected return type: {type(result).__name__}")], 104 | _meta={"data": {"error": "unexpected_type"}}, 105 | isError=True, 106 | ) 107 | 108 | # ------------------------------------------------------------------ 109 | # Resource helpers 110 | # ------------------------------------------------------------------ 111 | def _register_resources(self) -> None: 112 | @self.resource("resource://databricks/exports/{resource_id}", description="Cached Databricks export") 113 | async def read_cached_resource(resource_id: str) -> str | bytes: 114 | payload = self._resource_cache.get(resource_id) 115 | if payload is None: 116 | raise ValueError(f"Resource {resource_id} not found") 117 | if payload.is_text: 118 | return payload.data.decode("utf-8") 119 | return payload.data 120 | 121 | def _cache_resource( 122 | self, 123 | content: bytes | str, 124 | *, 125 | mime_type: Optional[str], 126 | description: Optional[str], 127 | ) -> str: 128 | if isinstance(content, str): 129 | payload = ResourcePayload(data=content.encode("utf-8"), mime_type=mime_type, description=description, is_text=True) 130 | else: 131 | payload = ResourcePayload(data=content, mime_type=mime_type, description=description, is_text=False) 132 | 133 | resource_id = uuid.uuid4().hex 134 | self._resource_cache[resource_id] = payload 135 | return f"resource://databricks/exports/{resource_id}" 136 | 137 | # ------------------------------------------------------------------ 138 | # Execution helper 139 | # ------------------------------------------------------------------ 140 | async def _run_tool( 141 | self, 142 | name: str, 143 | action: Callable[[], Awaitable[Any]], 144 | summary_fn: Callable[[Any], str], 145 | ctx: Context | None, 146 | ) -> CallToolResult: 147 | inbound_request_id = getattr(ctx, "request_id", None) if ctx else None 148 | execution_id = inbound_request_id or uuid.uuid4().hex 149 | extra = {"request_id": execution_id} 150 | token = request_context_id.set(execution_id) 151 | 152 | try: 153 | await self._report_progress(ctx, 0, message=f"Starting {name}") 154 | async with self._task_semaphore: 155 | try: 156 | result = await asyncio.wait_for(action(), timeout=settings.TOOL_TIMEOUT_SECONDS) 157 | except asyncio.TimeoutError: 158 | message = f"{name} timed out after {settings.TOOL_TIMEOUT_SECONDS}s" 159 | logger.warning("%s", message, extra=extra) 160 | self._metrics[f"{name}.timeout"] += 1 161 | err = error_result(message, status_code=504) 162 | err.meta = {"tool": name, "_request_id": execution_id} 163 | return err 164 | except asyncio.CancelledError: 165 | message = f"{name} was cancelled" 166 | logger.info(message, extra=extra) 167 | self._metrics[f"{name}.cancelled"] += 1 168 | err = error_result(message, status_code=499) 169 | err.meta = {"tool": name, "_request_id": execution_id} 170 | return err 171 | except DatabricksAPIError as err: 172 | message = f"{name} failed: {err.message}" 173 | logger.warning(message, extra=extra) 174 | self._metrics[f"{name}.error"] += 1 175 | err_result = error_result(message, details=err.response, status_code=err.status_code) 176 | err_result.meta = {"tool": name, "_request_id": execution_id} 177 | return err_result 178 | except Exception as err: # pylint: disable=broad-except 179 | logger.exception("Unexpected error running %s", name, extra=extra) 180 | self._metrics[f"{name}.error"] += 1 181 | err_result = error_result(f"{name} failed unexpectedly", details=str(err)) 182 | err_result.meta = 
{"tool": name, "_request_id": execution_id} 183 | return err_result 184 | finally: 185 | request_context_id.reset(token) 186 | 187 | summary = summary_fn(result) 188 | response = success_result(summary, result, meta={"tool": name, "_request_id": execution_id}) 189 | self._metrics[f"{name}.success"] += 1 190 | await self._report_progress(ctx, 100, message=f"Completed {name}") 191 | logger.info("Tool %s succeeded", name, extra={"request_id": execution_id, "tool": name}) 192 | return response 193 | 194 | # ------------------------------------------------------------------ 195 | # Tool registration 196 | # ------------------------------------------------------------------ 197 | def _register_tools(self) -> None: 198 | # Cluster tools 199 | @self.tool(name="list_clusters", description="List all Databricks clusters") 200 | async def list_clusters(ctx: Context | None = None) -> CallToolResult: 201 | return await self._run_tool( 202 | "list_clusters", 203 | lambda: clusters.list_clusters(), 204 | lambda data: f"Found {len(data.get('clusters', []))} clusters", 205 | ctx, 206 | ) 207 | 208 | @self.tool( 209 | name="create_cluster", 210 | description="Create a new Databricks cluster", 211 | ) 212 | async def create_cluster(cluster: ClusterConfig, ctx: Context | None = None) -> CallToolResult: 213 | payload = cluster.model_dump(exclude_none=True) 214 | 215 | async def action() -> Any: 216 | return await clusters.create_cluster(payload) 217 | 218 | return await self._run_tool( 219 | "create_cluster", 220 | action, 221 | lambda data: f"Cluster {data.get('cluster_id', payload['cluster_name'])} creation submitted", 222 | ctx, 223 | ) 224 | 225 | @self.tool(name="terminate_cluster", description="Terminate a cluster") 226 | async def terminate_cluster(cluster_id: str, ctx: Context | None = None) -> CallToolResult: 227 | return await self._run_tool( 228 | "terminate_cluster", 229 | lambda: clusters.terminate_cluster(cluster_id), 230 | lambda _: f"Termination requested for cluster {cluster_id}", 231 | ctx, 232 | ) 233 | 234 | @self.tool(name="get_cluster", description="Get information about a cluster") 235 | async def get_cluster(cluster_id: str, ctx: Context | None = None) -> CallToolResult: 236 | return await self._run_tool( 237 | "get_cluster", 238 | lambda: clusters.get_cluster(cluster_id), 239 | lambda data: f"Cluster {data.get('cluster_id', cluster_id)} state: {data.get('state', 'unknown')}", 240 | ctx, 241 | ) 242 | 243 | @self.tool(name="start_cluster", description="Start a terminated cluster") 244 | async def start_cluster(cluster_id: str, ctx: Context | None = None) -> CallToolResult: 245 | return await self._run_tool( 246 | "start_cluster", 247 | lambda: clusters.start_cluster(cluster_id), 248 | lambda _: f"Start requested for cluster {cluster_id}", 249 | ctx, 250 | ) 251 | 252 | # Job tools 253 | @self.tool(name="list_jobs", description="List Databricks jobs") 254 | async def list_jobs(ctx: Context | None = None) -> CallToolResult: 255 | return await self._run_tool( 256 | "list_jobs", 257 | lambda: jobs.list_jobs(), 258 | lambda data: f"Discovered {len(data.get('jobs', []))} jobs", 259 | ctx, 260 | ) 261 | 262 | @self.tool(name="create_job", description="Create a Databricks job") 263 | async def create_job(job: Job, ctx: Context | None = None) -> CallToolResult: 264 | async def action() -> Any: 265 | return await jobs.create_job(job.model_dump(exclude_none=True)) 266 | 267 | return await self._run_tool( 268 | "create_job", 269 | action, 270 | lambda data: f"Created job {data.get('job_id')} for 
{job.name}", 271 | ctx, 272 | ) 273 | 274 | @self.tool(name="delete_job", description="Delete a Databricks job") 275 | async def delete_job(job_id: int, ctx: Context | None = None) -> CallToolResult: 276 | return await self._run_tool( 277 | "delete_job", 278 | lambda: jobs.delete_job(job_id), 279 | lambda _: f"Deleted job {job_id}", 280 | ctx, 281 | ) 282 | 283 | @self.tool(name="run_job", description="Trigger a job run") 284 | async def run_job(job_id: int, notebook_params: Optional[Dict[str, Any]] = None, ctx: Context | None = None) -> CallToolResult: 285 | return await self._run_tool( 286 | "run_job", 287 | lambda: jobs.run_job(job_id, notebook_params or {}), 288 | lambda data: f"Run {data.get('run_id')} started for job {job_id}", 289 | ctx, 290 | ) 291 | 292 | @self.tool( 293 | name="run_notebook", 294 | description="Submit a one-time notebook run", 295 | ) 296 | async def run_notebook( 297 | notebook_path: str, 298 | existing_cluster_id: Optional[str] = None, 299 | base_parameters: Optional[Dict[str, Any]] = None, 300 | ctx: Context | None = None, 301 | ) -> CallToolResult: 302 | return await self._run_tool( 303 | "run_notebook", 304 | lambda: jobs.run_notebook( 305 | notebook_path=notebook_path, 306 | existing_cluster_id=existing_cluster_id, 307 | base_parameters=base_parameters, 308 | ), 309 | lambda data: f"Notebook run {data.get('run_id')} started for {notebook_path}", 310 | ctx, 311 | ) 312 | 313 | @self.tool( 314 | name="sync_repo_and_run_notebook", 315 | description="Pull a repo and run a notebook", 316 | ) 317 | async def sync_repo_and_run_notebook( 318 | repo_id: int, 319 | notebook_path: str, 320 | existing_cluster_id: Optional[str] = None, 321 | base_parameters: Optional[Dict[str, Any]] = None, 322 | ctx: Context | None = None, 323 | ) -> CallToolResult: 324 | async def action() -> Any: 325 | await self._report_progress(ctx, 25, message="Pulling repo") 326 | await repos.pull_repo(repo_id) 327 | await self._report_progress(ctx, 60, message="Triggering notebook run") 328 | return await jobs.run_notebook( 329 | notebook_path=notebook_path, 330 | existing_cluster_id=existing_cluster_id, 331 | base_parameters=base_parameters, 332 | ) 333 | 334 | return await self._run_tool( 335 | "sync_repo_and_run_notebook", 336 | action, 337 | lambda data: f"Repo {repo_id} synced; notebook run {data.get('run_id')} started", 338 | ctx, 339 | ) 340 | 341 | @self.tool(name="get_run_status", description="Get status for a job run") 342 | async def get_run_status(run_id: int, ctx: Context | None = None) -> CallToolResult: 343 | return await self._run_tool( 344 | "get_run_status", 345 | lambda: jobs.get_run_status(run_id), 346 | lambda data: f"Run {run_id} state: {data.get('state')}", 347 | ctx, 348 | ) 349 | 350 | @self.tool(name="list_job_runs", description="List recent job runs") 351 | async def list_job_runs(job_id: Optional[int] = None, ctx: Context | None = None) -> CallToolResult: 352 | return await self._run_tool( 353 | "list_job_runs", 354 | lambda: jobs.list_runs(job_id=job_id), 355 | lambda data: f"Found {len(data.get('runs', []))} runs", 356 | ctx, 357 | ) 358 | 359 | @self.tool(name="cancel_run", description="Cancel a job run") 360 | async def cancel_run(run_id: int, ctx: Context | None = None) -> CallToolResult: 361 | return await self._run_tool( 362 | "cancel_run", 363 | lambda: jobs.cancel_run(run_id), 364 | lambda _: f"Cancel requested for run {run_id}", 365 | ctx, 366 | ) 367 | 368 | # Notebook workspace tools 369 | @self.tool(name="list_notebooks", description="List notebooks in a 
directory") 370 | async def list_notebooks(path: str, ctx: Context | None = None) -> CallToolResult: 371 | return await self._run_tool( 372 | "list_notebooks", 373 | lambda: notebooks.list_notebooks(path), 374 | lambda data: f"Found {len(data.get('objects', []))} objects in {path}", 375 | ctx, 376 | ) 377 | 378 | @self.tool(name="export_notebook", description="Export a notebook") 379 | async def export_notebook(path: str, format: str = "SOURCE", ctx: Context | None = None) -> CallToolResult: 380 | result = await self._run_tool( 381 | "export_notebook", 382 | lambda: notebooks.export_notebook(path, format=format), 383 | lambda data: f"Exported notebook {path} in {format} format", 384 | ctx, 385 | ) 386 | 387 | data_block = result.structuredContent or {} 388 | if result.isError or not data_block: 389 | return result 390 | 391 | content_b64 = data_block.get("content") 392 | decoded = data_block.get("decoded_content") 393 | 394 | mime = { 395 | "SOURCE": "text/plain", 396 | "HTML": "text/html", 397 | "JUPYTER": "application/json", 398 | "DBC": "application/x-databricks-notebook", 399 | }.get(format, "application/octet-stream") 400 | 401 | if decoded: 402 | resource_uri = self._cache_resource(decoded, mime_type=mime, description=f"Notebook {path} ({format})") 403 | elif content_b64: 404 | raw_bytes = base64.b64decode(content_b64) 405 | resource_uri = self._cache_resource(raw_bytes, mime_type=mime, description=f"Notebook {path} ({format})") 406 | else: 407 | return result 408 | 409 | result.content.append( 410 | { 411 | "type": "resource_link", 412 | "uri": resource_uri, 413 | "name": f"Notebook export ({format})", 414 | "description": f"Notebook {path} ({format})", 415 | "mimeType": mime, 416 | } 417 | ) 418 | structured = result.structuredContent or {} 419 | structured.setdefault("resource_uri", resource_uri) 420 | result.structuredContent = structured 421 | await self._report_progress(ctx, 90, message="Notebook export cached") 422 | return result 423 | 424 | @self.tool(name="import_notebook", description="Import a notebook") 425 | async def import_notebook( 426 | path: str, 427 | content: str, 428 | format: str = "SOURCE", 429 | language: Optional[str] = None, 430 | overwrite: bool = False, 431 | ctx: Context | None = None, 432 | ) -> CallToolResult: 433 | return await self._run_tool( 434 | "import_notebook", 435 | lambda: notebooks.import_notebook(path, content, format=format, language=language, overwrite=overwrite), 436 | lambda _: f"Imported notebook to {path}", 437 | ctx, 438 | ) 439 | 440 | @self.tool( 441 | name="delete_workspace_object", 442 | description="Delete a workspace notebook or directory", 443 | ) 444 | async def delete_workspace_object(path: str, recursive: bool = False, ctx: Context | None = None) -> CallToolResult: 445 | return await self._run_tool( 446 | "delete_workspace_object", 447 | lambda: notebooks.delete_notebook(path, recursive=recursive), 448 | lambda _: f"Deleted workspace path {path}", 449 | ctx, 450 | ) 451 | 452 | @self.tool(name="get_workspace_file_content", description="Retrieve workspace file content") 453 | async def get_workspace_file_content(path: str, format: str = "SOURCE", ctx: Context | None = None) -> CallToolResult: 454 | result = await self._run_tool( 455 | "get_workspace_file_content", 456 | lambda: notebooks.export_workspace_file(path, format=format), 457 | lambda data: f"Retrieved workspace file {path}", 458 | ctx, 459 | ) 460 | 461 | data_block = result.structuredContent or {} 462 | if result.isError or not data_block: 463 | return result 464 
| 465 | decoded = data_block.get("decoded_content") 466 | mime = "application/json" if data_block.get("content_type") == "json" else "text/plain" 467 | 468 | if decoded: 469 | resource_uri = self._cache_resource(decoded, mime_type=mime, description=f"Workspace file {path}") 470 | elif data_block.get("content"): 471 | raw_bytes = base64.b64decode(data_block["content"]) 472 | resource_uri = self._cache_resource(raw_bytes, mime_type=mime, description=f"Workspace file {path}") 473 | else: 474 | resource_uri = None 475 | 476 | if resource_uri: 477 | result.content.append( 478 | { 479 | "type": "resource_link", 480 | "uri": resource_uri, 481 | "name": "Workspace export", 482 | "description": f"Workspace file {path}", 483 | "mimeType": mime, 484 | } 485 | ) 486 | structured = result.structuredContent or {} 487 | structured.setdefault("resource_uri", resource_uri) 488 | result.structuredContent = structured 489 | 490 | return result 491 | 492 | @self.tool(name="get_workspace_file_info", description="Retrieve workspace metadata") 493 | async def get_workspace_file_info(path: str, ctx: Context | None = None) -> CallToolResult: 494 | return await self._run_tool( 495 | "get_workspace_file_info", 496 | lambda: notebooks.get_workspace_file_info(path), 497 | lambda data: f"Metadata returned for {data.get('path', path)}", 498 | ctx, 499 | ) 500 | 501 | # DBFS tools 502 | @self.tool(name="list_files", description="List DBFS files for a path") 503 | async def list_files(path: str, ctx: Context | None = None) -> CallToolResult: 504 | return await self._run_tool( 505 | "list_files", 506 | lambda: dbfs.list_files(path), 507 | lambda data: f"Found {len(data.get('paths') or data.get('files', []))} entries at {path}", 508 | ctx, 509 | ) 510 | 511 | @self.tool(name="dbfs_put", description="Upload small content to DBFS") 512 | async def dbfs_put(path: str, content: str, overwrite: bool = True, ctx: Context | None = None) -> CallToolResult: 513 | return await self._run_tool( 514 | "dbfs_put", 515 | lambda: dbfs.put_file(path, content.encode("utf-8"), overwrite=overwrite), 516 | lambda _: f"Uploaded content to {path}", 517 | ctx, 518 | ) 519 | 520 | @self.tool(name="dbfs_delete", description="Delete a DBFS path") 521 | async def dbfs_delete(path: str, recursive: bool = False, ctx: Context | None = None) -> CallToolResult: 522 | return await self._run_tool( 523 | "dbfs_delete", 524 | lambda: dbfs.delete_file(path, recursive=recursive), 525 | lambda _: f"Deleted DBFS path {path}", 526 | ctx, 527 | ) 528 | 529 | # Library tools 530 | @self.tool(name="install_library", description="Install libraries on a cluster") 531 | async def install_library(cluster_id: str, libraries_spec: List[Dict[str, Any]], ctx: Context | None = None) -> CallToolResult: 532 | return await self._run_tool( 533 | "install_library", 534 | lambda: libraries.install_library(cluster_id, libraries_spec), 535 | lambda _: f"Library install requested on cluster {cluster_id}", 536 | ctx, 537 | ) 538 | 539 | @self.tool(name="uninstall_library", description="Uninstall libraries from a cluster") 540 | async def uninstall_library(cluster_id: str, libraries_spec: List[Dict[str, Any]], ctx: Context | None = None) -> CallToolResult: 541 | return await self._run_tool( 542 | "uninstall_library", 543 | lambda: libraries.uninstall_library(cluster_id, libraries_spec), 544 | lambda _: f"Library uninstall requested on cluster {cluster_id}", 545 | ctx, 546 | ) 547 | 548 | @self.tool(name="list_cluster_libraries", description="List libraries for a cluster") 549 | async def 
list_cluster_libraries(cluster_id: str, ctx: Context | None = None) -> CallToolResult: 550 | return await self._run_tool( 551 | "list_cluster_libraries", 552 | lambda: libraries.list_cluster_libraries(cluster_id), 553 | lambda data: f"Cluster {cluster_id} has {len(data.get('library_statuses', []))} enrolled libraries", 554 | ctx, 555 | ) 556 | 557 | # Repo tools 558 | @self.tool(name="create_repo", description="Create or clone a repo") 559 | async def create_repo(url: str, provider: str, branch: Optional[str] = None, path: Optional[str] = None, ctx: Context | None = None) -> CallToolResult: 560 | return await self._run_tool( 561 | "create_repo", 562 | lambda: repos.create_repo(url, provider, branch=branch, path=path), 563 | lambda data: f"Repo {data.get('id')} created from {url}", 564 | ctx, 565 | ) 566 | 567 | @self.tool(name="update_repo", description="Update repo branch or tag") 568 | async def update_repo(repo_id: int, branch: Optional[str] = None, tag: Optional[str] = None, ctx: Context | None = None) -> CallToolResult: 569 | return await self._run_tool( 570 | "update_repo", 571 | lambda: repos.update_repo(repo_id, branch=branch, tag=tag), 572 | lambda _: f"Updated repo {repo_id}", 573 | ctx, 574 | ) 575 | 576 | @self.tool(name="list_repos", description="List repos in the workspace") 577 | async def list_repos(path_prefix: Optional[str] = None, ctx: Context | None = None) -> CallToolResult: 578 | return await self._run_tool( 579 | "list_repos", 580 | lambda: repos.list_repos(path_prefix=path_prefix), 581 | lambda data: f"Found {len(data.get('repos', []))} repos", 582 | ctx, 583 | ) 584 | 585 | @self.tool(name="pull_repo", description="Pull latest commit for a repo") 586 | async def pull_repo(repo_id: int, ctx: Context | None = None) -> CallToolResult: 587 | return await self._run_tool( 588 | "pull_repo", 589 | lambda: repos.pull_repo(repo_id), 590 | lambda _: f"Pulled latest changes for repo {repo_id}", 591 | ctx, 592 | ) 593 | 594 | # SQL tools 595 | @self.tool(name="execute_sql", description="Execute a SQL statement") 596 | async def execute_sql( 597 | statement: str, 598 | warehouse_id: Optional[str] = None, 599 | catalog: Optional[str] = None, 600 | schema_name: Optional[str] = None, 601 | ctx: Context | None = None, 602 | ) -> CallToolResult: 603 | async def action() -> Any: 604 | await self._report_progress(ctx, 10, message="Submitting SQL statement") 605 | result = await sql.execute_statement( 606 | statement=statement, 607 | warehouse_id=warehouse_id, 608 | catalog=catalog, 609 | schema=schema_name, 610 | ) 611 | await self._report_progress(ctx, 70, message="SQL statement completed") 612 | return result 613 | 614 | return await self._run_tool( 615 | "execute_sql", 616 | action, 617 | lambda data: f"SQL statement {data.get('statement_id', 'completed')} executed", 618 | ctx, 619 | ) 620 | 621 | # Unity catalog tools 622 | @self.tool(name="list_catalogs", description="List Unity Catalog catalogs") 623 | async def list_catalogs(ctx: Context | None = None) -> CallToolResult: 624 | return await self._run_tool( 625 | "list_catalogs", 626 | lambda: unity_catalog.list_catalogs(), 627 | lambda data: f"Found {len(data.get('catalogs', []))} catalogs", 628 | ctx, 629 | ) 630 | 631 | @self.tool(name="create_catalog", description="Create a Unity Catalog catalog") 632 | async def create_catalog(name: str, comment: Optional[str] = None, ctx: Context | None = None) -> CallToolResult: 633 | return await self._run_tool( 634 | "create_catalog", 635 | lambda: unity_catalog.create_catalog(name, 
comment), 636 | lambda _: f"Created catalog {name}", 637 | ctx, 638 | ) 639 | 640 | @self.tool(name="list_schemas", description="List schemas in a catalog") 641 | async def list_schemas(catalog_name: str, ctx: Context | None = None) -> CallToolResult: 642 | return await self._run_tool( 643 | "list_schemas", 644 | lambda: unity_catalog.list_schemas(catalog_name), 645 | lambda data: f"Catalog {catalog_name} has {len(data.get('schemas', []))} schemas", 646 | ctx, 647 | ) 648 | 649 | @self.tool(name="create_schema", description="Create a schema in Unity Catalog") 650 | async def create_schema(catalog_name: str, name: str, comment: Optional[str] = None, ctx: Context | None = None) -> CallToolResult: 651 | return await self._run_tool( 652 | "create_schema", 653 | lambda: unity_catalog.create_schema(catalog_name, name, comment), 654 | lambda _: f"Created schema {catalog_name}.{name}", 655 | ctx, 656 | ) 657 | 658 | @self.tool(name="list_tables", description="List tables for a schema") 659 | async def list_tables(catalog_name: str, schema_name: str, ctx: Context | None = None) -> CallToolResult: 660 | return await self._run_tool( 661 | "list_tables", 662 | lambda: unity_catalog.list_tables(catalog_name, schema_name), 663 | lambda data: f"Schema {catalog_name}.{schema_name} has {len(data.get('tables', []))} tables", 664 | ctx, 665 | ) 666 | 667 | @self.tool(name="create_table", description="Create a table via SQL") 668 | async def create_table(warehouse_id: str, statement: str, ctx: Context | None = None) -> CallToolResult: 669 | return await self._run_tool( 670 | "create_table", 671 | lambda: unity_catalog.create_table(warehouse_id, statement), 672 | lambda data: f"Table creation run {data.get('run_id', 'submitted')} for warehouse {warehouse_id}", 673 | ctx, 674 | ) 675 | 676 | @self.tool(name="get_table_lineage", description="Get Unity Catalog table lineage") 677 | async def get_table_lineage(full_name: str, ctx: Context | None = None) -> CallToolResult: 678 | return await self._run_tool( 679 | "get_table_lineage", 680 | lambda: unity_catalog.get_table_lineage(full_name), 681 | lambda data: f"Lineage contains {len(data.get('upstream_tables', []))} upstream tables", 682 | ctx, 683 | ) 684 | 685 | def get_metrics_snapshot(self) -> Dict[str, int]: 686 | """Return a copy of collected tool metrics.""" 687 | return dict(self._metrics) 688 | 689 | # ------------------------------------------------------------------ 690 | # Entrypoint 691 | # ------------------------------------------------------------------ 692 | 693 | def main() -> None: 694 | """Main entry point for the MCP server.""" 695 | configure_logging(level=settings.LOG_LEVEL, log_file="databricks_mcp.log") 696 | try: 697 | logger.info("Starting Databricks MCP server") 698 | if hasattr(sys.stdout, "reconfigure"): 699 | sys.stdout.reconfigure(line_buffering=True) 700 | server = DatabricksMCPServer() 701 | server.run() 702 | except Exception: # pylint: disable=broad-except 703 | logger.exception("Fatal error in Databricks MCP server") 704 | raise 705 | 706 | 707 | if __name__ == "__main__": 708 | main() 709 | --------------------------------------------------------------------------------