├── .github ├── scripts │ └── increment_version.py └── workflows │ ├── python-publish.yml │ ├── run-linters.yml │ └── run-tests.yml ├── .gitignore ├── CHANGELOG.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── config.example.env ├── pyproject.toml ├── pytest.ini ├── requirements-dev.txt ├── requirements.txt ├── tests ├── __init__.py ├── conftest.py ├── docker_utils.py ├── integration │ ├── conftest.py │ ├── test_authentication_integration.py │ ├── test_mcp_server_integration.py │ └── test_path_operations.py ├── mocks.py ├── test_connection.py ├── test_customjsonencoder.py ├── test_query.py └── test_server.py └── ydb_mcp ├── __init__.py ├── __main__.py ├── connection.py ├── query.py ├── server.py ├── tool_manager.py └── version.py /.github/scripts/increment_version.py: -------------------------------------------------------------------------------- 1 | #!/bin/env python 2 | import argparse 3 | from dataclasses import dataclass 4 | 5 | from packaging.version import Version 6 | 7 | PYPROJECT_PATH = "pyproject.toml" 8 | DEFAULT_CHANGELOG_PATH = "CHANGELOG.md" 9 | DEFAULT_YDB_VERSION_FILE = "ydb_mcp/version.py" 10 | MARKER = "# AUTOVERSION" 11 | 12 | 13 | @dataclass(init=False) 14 | class VersionLine: 15 | old_line: str 16 | major: int 17 | minor: int 18 | patch: int 19 | pre: int 20 | 21 | def __init__(self, old_line: str, version_str: str): 22 | self.old_line = old_line 23 | 24 | version = Version(version_str) 25 | self.major = version.major 26 | self.minor = version.minor 27 | self.micro = version.micro 28 | 29 | if version.pre is None: 30 | self.pre = 0 31 | else: 32 | self.pre = version.pre[1] 33 | 34 | def increment(self, part_name: str, with_beta: bool): 35 | if part_name == "minor": 36 | self.increment_minor(with_beta) 37 | elif part_name == "patch" or part_name == "micro": 38 | self.increment_micro(with_beta) 39 | else: 40 | raise Exception("unexpected increment type: '%s'" % part_name) 41 | 42 | def increment_minor(self, with_beta: bool): 43 | if 
with_beta: 44 | if self.pre == 0 or self.micro != 0: 45 | self.increment_minor(False) 46 | self.pre += 1 47 | return 48 | 49 | if self.micro == 0 and self.pre > 0: 50 | self.pre = 0 51 | return 52 | 53 | self.minor += 1 54 | self.micro = 0 55 | self.pre = 0 56 | 57 | def increment_micro(self, with_beta: bool): 58 | if with_beta: 59 | if self.pre == 0: 60 | self.increment_micro(False) 61 | self.pre += 1 62 | return 63 | 64 | if self.pre > 0: 65 | self.pre = 0 66 | return 67 | 68 | self.micro += 1 69 | 70 | def __str__(self): 71 | if self.pre > 0: 72 | pre = "b%s" % self.pre 73 | else: 74 | pre = "" 75 | 76 | return "%s.%s.%s%s" % (self.major, self.minor, self.micro, pre) 77 | 78 | def version_line_with_mark(self): 79 | return 'version = "%s" %s' % (str(self), MARKER) 80 | 81 | 82 | def extract_version(pyproject_content: str): 83 | version_line = "" 84 | for line in pyproject_content.splitlines(): 85 | if MARKER in line: 86 | version_line = line 87 | break 88 | 89 | if version_line == "": 90 | raise Exception("Not found version line") 91 | 92 | version_line = version_line.strip() 93 | 94 | parts = version_line.split('"') 95 | version_part = parts[1] 96 | 97 | return VersionLine(old_line=version_line, version_str=version_part) 98 | 99 | 100 | def increment_version_at_pyproject(pyproject_path: str, inc_type: str, with_beta: bool) -> str: 101 | with open(pyproject_path, "rt") as f: 102 | setup_content = f.read() 103 | 104 | version = extract_version(setup_content) 105 | version.increment(inc_type, with_beta) 106 | setup_content = setup_content.replace(version.old_line, version.version_line_with_mark()) 107 | 108 | with open(pyproject_path, "w") as f: 109 | f.write(setup_content) 110 | 111 | return str(version) 112 | 113 | 114 | def add_changelog_version(changelog_path, version: str): 115 | with open(changelog_path, "rt") as f: 116 | content = f.read() 117 | content = content.strip() 118 | 119 | if content.startswith("##"): 120 | return 121 | 122 | content = """## %s ## 
123 | %s 124 | """ % (version, content) 125 | with open(changelog_path, "w") as f: 126 | f.write(content) 127 | 128 | 129 | def set_version_in_version_file(file_path: str, version: str): 130 | with open(file_path, "w") as f: 131 | f.write('VERSION = "%s"\n' % version) 132 | 133 | 134 | def main(): 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument( 137 | "--inc-type", 138 | default="minor", 139 | help="increment version type: patch or minor", 140 | choices=["minor", "patch"], 141 | ) 142 | parser.add_argument("--beta", choices=["true", "false"], help="is beta version") 143 | parser.add_argument( 144 | "--changelog-path", 145 | default=DEFAULT_CHANGELOG_PATH, 146 | help="path to changelog", 147 | type=str, 148 | ) 149 | parser.add_argument("--pyproject-path", default=PYPROJECT_PATH) 150 | 151 | args = parser.parse_args() 152 | 153 | is_beta = args.beta == "true" 154 | 155 | new_version = increment_version_at_pyproject(args.pyproject_path, args.inc_type, is_beta) 156 | add_changelog_version(args.changelog_path, new_version) 157 | set_version_in_version_file(DEFAULT_YDB_VERSION_FILE, new_version) 158 | print(new_version) 159 | 160 | 161 | if __name__ == "__main__": 162 | main() 163 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 
9 | name: Publish package release
10 | 
11 | on:
12 | workflow_dispatch:
13 | inputs:
14 | version-change:
15 | description: Version part
16 | required: true
17 | type: choice
18 | default: patch
19 | options:
20 | - minor
21 | - patch
22 | beta:
23 | description: Is beta version
24 | required: true
25 | type: boolean
26 | default: False
27 | jobs:
28 | publish:
29 | env:
30 | VERSION_CHANGE: ${{ github.event.inputs.version-change }}
31 | WITH_BETA: ${{ github.event.inputs.beta }}
32 | GH_TOKEN: ${{ secrets.YDB_PLATFORM_BOT_TOKEN_REPO }}
33 | CHANGELOG_FILE: CHANGELOG.md
34 | PYPROJECT_PATH: pyproject.toml
35 | 
36 | permissions:
37 | contents: write
38 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
39 | 
40 | runs-on: ubuntu-latest
41 | 
42 | steps:
43 | - uses: actions/checkout@v3
44 | with:
45 | token: ${{ secrets.YDB_PLATFORM_BOT_TOKEN_REPO }}
46 | 
47 | - name: Set up Python
48 | uses: actions/setup-python@v3
49 | with:
50 | python-version: '3.9'
51 | 
52 | - name: Install dependencies
53 | run: |
54 | python -m pip install --upgrade pip
55 | pip install packaging build
56 | 
57 | - name: read changelog
58 | id: read-changelog
59 | run: |
60 | CHANGELOG=$(cat $CHANGELOG_FILE | sed -e '/^## .*$/,$d')
61 | echo "CHANGELOG<<CHANGELOGEOF_MARKER" >> $GITHUB_ENV
62 | echo "$CHANGELOG" >> $GITHUB_ENV
63 | echo "CHANGELOGEOF_MARKER" >> $GITHUB_ENV
64 | echo "# Changelog" >> $GITHUB_STEP_SUMMARY
65 | echo "$CHANGELOG" >> $GITHUB_STEP_SUMMARY
66 | 
67 | 
68 | - name: Increment version
69 | id: increment-version
70 | run: |
71 | NEW_VERSION=$(python3 ./.github/scripts/increment_version.py --inc-type=$VERSION_CHANGE --beta=$WITH_BETA)
72 | echo new version: $NEW_VERSION
73 | echo "NEW_VERSION=$NEW_VERSION" >> $GITHUB_OUTPUT
74 | echo "New version: $NEW_VERSION" >> $GITHUB_STEP_SUMMARY
75 | 
76 | - name: Build package
77 | run: python -m build
78 | 
79 | - name: Publish release on github
80 | run: |
81 | if [[ -z "$CHANGELOG" ]]
82 | then
83 | echo "CHANGELOG empty"
84 | exit 1;
85 | fi;
86 | 
87 | TAG="${{ steps.increment-version.outputs.NEW_VERSION }}"
88 | 
89 | # Get previous version from changelog
90 | # pre-incremented version not used for consistent changelog with release notes
91 | # for example changelog may be rewritten when switching from beta to release
92 | # and remove internal beta changes
93 | LAST_TAG=$(cat $CHANGELOG_FILE | grep '^## .* ##$' | head -n 2 | tail -n 1 | cut -d ' ' -f 2)
94 | 
95 | git config --global user.email "robot@umbrella";
96 | git config --global user.name "robot";
97 | git commit -am "Release: $TAG";
98 | 
99 | git tag "$TAG"
100 | git push && git push --tags
101 | 
102 | CHANGELOG="$CHANGELOG
103 | 
104 | Full Changelog: [$LAST_TAG...$TAG](https://github.com/ydb-platform/ydb-mcp/compare/$LAST_TAG...$TAG)"
105 | if [ "$WITH_BETA" = true ]
106 | then
107 | gh release create --prerelease $TAG --title "$TAG" --notes "$CHANGELOG"
108 | else
109 | gh release create $TAG --title "$TAG" --notes "$CHANGELOG"
110 | fi;
111 | 
112 | - name: Publish package
113 | uses: pypa/gh-action-pypi-publish@release/v1.12
--------------------------------------------------------------------------------
/.github/workflows/run-linters.yml:
--------------------------------------------------------------------------------
1 | name: Run Linters
2 | 
3 | on:
4 | push:
5 | branches:
6 | - '**'
7 | pull_request_target:
8 | branches:
9 | - '**'
10 | workflow_dispatch:
11 | 
12 | jobs:
13 | lint:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Checkout code
17 | uses: actions/checkout@v4
18 | - name: Set up Python
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: '3.13'
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | - name: Run linters
26 | run: make lint
--------------------------------------------------------------------------------
/.github/workflows/run-tests.yml:
--------------------------------------------------------------------------------
1 | name: Run
Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' 7 | pull_request_target: 8 | branches: 9 | - '**' 10 | workflow_dispatch: 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.13' 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | - name: Run tests 26 | run: make test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # Unit test / coverage reports 28 | htmlcov/ 29 | .tox/ 30 | .coverage 31 | .coverage.* 32 | .cache 33 | nosetests.xml 34 | coverage.xml 35 | *.cover 36 | .hypothesis/ 37 | .pytest_cache/ 38 | *.log 39 | integration_test* 40 | 41 | # Environments 42 | .env 43 | .venv 44 | env/ 45 | venv/ 46 | ENV/ 47 | env.bak/ 48 | venv.bak/ 49 | 50 | # IDEs and editors 51 | .idea/ 52 | .vscode/ 53 | *.swp 54 | *.swo 55 | *~ 56 | 57 | # macOS specific files 58 | .DS_Store 59 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.1.1 ## 2 | * Update package 3 | 4 | ## 0.1.0 ## 5 | * First version 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | # Install 
dependencies
6 | COPY requirements.txt .
7 | RUN pip install --no-cache-dir -r requirements.txt
8 | 
9 | # Copy source code
10 | COPY ydb_mcp ./ydb_mcp
11 | COPY pyproject.toml README.md ./
12 | 
13 | # Install the package
14 | RUN pip install --no-cache-dir -e .
15 | 
16 | # Expose the server port
17 | EXPOSE 8080
18 | 
19 | # Set environment variables
20 | ENV YDB_ENDPOINT=""
21 | ENV YDB_DATABASE=""
22 | 
23 | # Run the server
24 | CMD ["ydb-mcp", "--host", "0.0.0.0", "--port", "8080"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022-2025 YANDEX LLC 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all clean test lint format install dev unit-tests integration-tests run-server 2 | 3 | # Default target 4 | all: clean lint test 5 | 6 | # Clean build files 7 | clean: 8 | rm -rf build/ dist/ *.egg-info/ __pycache__/ .pytest_cache/ .coverage htmlcov/ 9 | find . -name "*.pyc" -delete 10 | find . -name "__pycache__" -delete 11 | 12 | # Run tests 13 | test: dev 14 | $(eval LOG_LEVEL ?= WARNING) 15 | PYTHONPATH=. pytest --log-cli-level=$(LOG_LEVEL) 16 | 17 | # Run unit tests only 18 | unit-tests: dev 19 | $(eval LOG_LEVEL ?= WARNING) 20 | PYTHONPATH=. 
python -m pytest -m unit -v --log-cli-level=$(LOG_LEVEL) 21 | 22 | # Run integration tests 23 | integration-tests: dev 24 | $(eval YDB_ENDPOINT ?= grpc://localhost:2136) 25 | $(eval YDB_DATABASE ?= /local) 26 | $(eval MCP_HOST ?= 127.0.0.1) 27 | $(eval MCP_PORT ?= 8989) 28 | $(eval LOG_LEVEL ?= WARNING) 29 | @echo "Running integration tests with the following configuration:" 30 | @echo "YDB Endpoint: $(YDB_ENDPOINT)" 31 | @echo "YDB Database: $(YDB_DATABASE)" 32 | @echo "MCP Host: $(MCP_HOST)" 33 | @echo "MCP Port: $(MCP_PORT)" 34 | @echo "Log Level: $(LOG_LEVEL)" 35 | @echo "Note: Tests will automatically create YDB in Docker if no YDB server is running at the endpoint" 36 | YDB_ENDPOINT=$(YDB_ENDPOINT) YDB_DATABASE=$(YDB_DATABASE) MCP_HOST=$(MCP_HOST) MCP_PORT=$(MCP_PORT) PYTHONPATH=. python -m pytest -m integration -v --log-cli-level=$(LOG_LEVEL) 37 | 38 | # Run server 39 | run-server: 40 | $(eval YDB_ENDPOINT ?= grpc://localhost:2136) 41 | $(eval YDB_DATABASE ?= /local) 42 | YDB_ENDPOINT=$(YDB_ENDPOINT) YDB_DATABASE=$(YDB_DATABASE) python -m ydb_mcp $(ARGS) 43 | 44 | # Run lint checks 45 | lint: dev 46 | ruff check ydb_mcp tests 47 | mypy ydb_mcp 48 | 49 | # Format code 50 | format: dev 51 | ruff format ydb_mcp tests 52 | 53 | # Install package 54 | install: 55 | pip install -e . 56 | 57 | # Install development dependencies 58 | dev: 59 | pip install -e ".[dev]" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YDB MCP 2 | --- 3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/ydb-platform/ydb-mcp/blob/main/LICENSE) 4 | [![PyPI version](https://badge.fury.io/py/ydb-mcp.svg)](https://badge.fury.io/py/ydb-mcp) 5 | 6 | [Model Context Protocol server](https://modelcontextprotocol.io/) for [YDB](https://ydb.tech). 
It allows you to work with YDB databases from any [LLM](https://en.wikipedia.org/wiki/Large_language_model) that supports MCP. This integration enables AI-powered database operations and natural language interactions with your YDB instances.
7 | 
8 | 
9 | YDB MCP server
10 | 
11 | 
12 | ## Usage
13 | 
14 | ### Via uvx
15 | 
16 | [uvx](https://docs.astral.sh/uv/concepts/tools/), which is an alias for `uv run tool`, allows you to run various Python applications without explicitly installing them. Below are examples of how to configure YDB MCP using `uvx`.
17 | 
18 | #### Example: Using Anonymous Authentication
19 | 
20 | ```json
21 | {
22 | "mcpServers": {
23 | "ydb": {
24 | "command": "uvx",
25 | "args": [
26 | "ydb-mcp",
27 | "--ydb-endpoint", "grpc://localhost:2136/local"
28 | ]
29 | }
30 | }
31 | }
32 | ```
33 | 
34 | #### Example: Using Login/Password Authentication
35 | 
36 | To use login/password authentication, specify the `--ydb-auth-mode`, `--ydb-login`, and `--ydb-password` arguments:
37 | 
38 | ```json
39 | {
40 | "mcpServers": {
41 | "ydb": {
42 | "command": "uvx",
43 | "args": [
44 | "ydb-mcp",
45 | "--ydb-endpoint", "grpc://localhost:2136/local",
46 | "--ydb-auth-mode", "login-password",
47 | "--ydb-login", "",
48 | "--ydb-password", ""
49 | ]
50 | }
51 | }
52 | }
53 | ```
54 | 
55 | ### Via pipx
56 | 
57 | [pipx](https://pipx.pypa.io/stable/) allows you to run various applications from PyPI without explicitly installing each one. However, it must be [installed](https://pipx.pypa.io/stable/#install-pipx) first. Below are examples of how to configure YDB MCP using `pipx`.
58 | 59 | #### Example: Using Anonymous Authentication 60 | 61 | ```json 62 | { 63 | "mcpServers": { 64 | "ydb": { 65 | "command": "pipx", 66 | "args": [ 67 | "run", "ydb-mcp", 68 | "--ydb-endpoint", "grpc://localhost:2136/local" 69 | ] 70 | } 71 | } 72 | } 73 | ``` 74 | 75 | #### Example: Using Login/Password Authentication 76 | 77 | To use login/password authentication, specify the `--ydb-auth-mode`, `--ydb-login`, and `--ydb-password` arguments: 78 | 79 | ```json 80 | { 81 | "mcpServers": { 82 | "ydb": { 83 | "command": "pipx", 84 | "args": [ 85 | "run", "ydb-mcp", 86 | "--ydb-endpoint", "grpc://localhost:2136/local", 87 | "--ydb-auth-mode", "login-password", 88 | "--ydb-login", "", 89 | "--ydb-password", "" 90 | ] 91 | } 92 | } 93 | } 94 | ``` 95 | 96 | ### Via pip 97 | 98 | YDB MCP can be installed using `pip`, [Python's package installer](https://pypi.org/project/pip/). The package is [available on PyPI](https://pypi.org/project/ydb-mcp/) and includes all necessary dependencies. 99 | 100 | ```bash 101 | pip install ydb-mcp 102 | ``` 103 | 104 | To get started with YDB MCP, you'll need to configure your MCP client to communicate with the YDB instance. Below are example configuration files that you can customize according to your setup and then put into MCP client's settings. Path to the Python interpreter might also need to be adjusted to the correct virtual environment that has the `ydb-mcp` package installed. 
105 | 106 | #### Example: Using Anonymous Authentication 107 | 108 | ```json 109 | { 110 | "mcpServers": { 111 | "ydb": { 112 | "command": "python3", 113 | "args": [ 114 | "-m", "ydb_mcp", 115 | "--ydb-endpoint", "grpc://localhost:2136/local" 116 | ] 117 | } 118 | } 119 | } 120 | ``` 121 | 122 | #### Example: Using Login/Password Authentication 123 | 124 | To use login/password authentication, specify the `--ydb-auth-mode`, `--ydb-login`, and `--ydb-password` arguments: 125 | 126 | ```json 127 | { 128 | "mcpServers": { 129 | "ydb": { 130 | "command": "python3", 131 | "args": [ 132 | "-m", "ydb_mcp", 133 | "--ydb-endpoint", "grpc://localhost:2136/local", 134 | "--ydb-auth-mode", "login-password", 135 | "--ydb-login", "", 136 | "--ydb-password", "" 137 | ] 138 | } 139 | } 140 | } 141 | ``` 142 | 143 | ## Available Tools 144 | 145 | YDB MCP provides the following tools for interacting with YDB databases: 146 | 147 | - `ydb_query`: Run a SQL query against a YDB database 148 | - Parameters: 149 | - `sql`: SQL query string to execute 150 | 151 | - `ydb_query_with_params`: Run a parameterized SQL query with JSON parameters 152 | - Parameters: 153 | - `sql`: SQL query string with parameter placeholders 154 | - `params`: JSON string containing parameter values 155 | 156 | - `ydb_list_directory`: List directory contents in YDB 157 | - Parameters: 158 | - `path`: YDB directory path to list 159 | 160 | - `ydb_describe_path`: Get detailed information about a YDB path (table, directory, etc.) 161 | - Parameters: 162 | - `path`: YDB path to describe 163 | 164 | - `ydb_status`: Get the current status of the YDB connection 165 | 166 | ## Development 167 | 168 | The project uses [Make](https://www.gnu.org/software/make/) as its primary development tool, providing a consistent interface for common development tasks. 169 | 170 | ### Available Make Commands 171 | 172 | The project includes a comprehensive Makefile with various commands for development tasks. 
Each command is designed to streamline the development workflow and ensure code quality: 173 | 174 | - `make all`: Run clean, lint, and test in sequence (default target) 175 | - `make clean`: Remove all build artifacts and temporary files 176 | - `make test`: Run all tests using pytest 177 | - Can be configured with environment variables: 178 | - `LOG_LEVEL` (default: WARNING) - Control test output verbosity (DEBUG, INFO, WARNING, ERROR) 179 | - `make unit-tests`: Run only unit tests with verbose output 180 | - Can be configured with environment variables: 181 | - `LOG_LEVEL` (default: WARNING) - Control test output verbosity (DEBUG, INFO, WARNING, ERROR) 182 | - `make integration-tests`: Run only integration tests with verbose output 183 | - Can be configured with environment variables: 184 | - `YDB_ENDPOINT` (default: grpc://localhost:2136) 185 | - `YDB_DATABASE` (default: /local) 186 | - `MCP_HOST` (default: 127.0.0.1) 187 | - `MCP_PORT` (default: 8989) 188 | - `LOG_LEVEL` (default: WARNING) - Control test output verbosity (DEBUG, INFO, WARNING, ERROR) 189 | - `make run-server`: Start the YDB MCP server 190 | - Can be configured with environment variables: 191 | - `YDB_ENDPOINT` (default: grpc://localhost:2136) 192 | - `YDB_DATABASE` (default: /local) 193 | - Additional arguments can be passed using `ARGS="your args"` 194 | - `make lint`: Run all linting checks (flake8, mypy, black, isort) 195 | - `make format`: Format code using black and isort 196 | - `make install`: Install the package in development mode 197 | - `make dev`: Install the package in development mode with all development dependencies 198 | 199 | ### Test Verbosity Control 200 | 201 | By default, tests run with minimal output (WARNING level) to keep the output clean. 
You can control the verbosity of test output using the `LOG_LEVEL` environment variable: 202 | 203 | ```bash 204 | # Run all tests with debug output 205 | make test LOG_LEVEL=DEBUG 206 | 207 | # Run integration tests with info output 208 | make integration-tests LOG_LEVEL=INFO 209 | 210 | # Run unit tests with warning output (default) 211 | make unit-tests LOG_LEVEL=WARNING 212 | ``` 213 | 214 | Available log levels: 215 | - `DEBUG`: Show all debug messages, useful for detailed test flow 216 | - `INFO`: Show informational messages and above 217 | - `WARNING`: Show only warnings and errors (default) 218 | - `ERROR`: Show only error messages -------------------------------------------------------------------------------- /config.example.env: -------------------------------------------------------------------------------- 1 | # YDB connection settings 2 | 3 | YDB_ENDPOINT=grpc://ydb.example.com:2136/local 4 | 5 | # Login/password authentication 6 | 7 | # YDB_AUTH_MODE=login-password 8 | # YDB_LOGIN= 9 | # YDB_PASSWORD= 10 | 11 | # Server settings 12 | MCP_HOST=127.0.0.1 13 | MCP_PORT=8080 14 | 15 | # Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) 16 | LOG_LEVEL=INFO -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ydb-mcp" 7 | version = "0.1.1" # AUTOVERSION 8 | description = "Model Context Protocol server for YDB DBMS" 9 | readme = "README.md" 10 | authors = [ 11 | {name = "YDB MCP Team", email = "info@ydb.tech"} 12 | ] 13 | license = {text = "Apache 2.0"} 14 | requires-python = ">=3.10" 15 | classifiers = [ 16 | "Development Status :: 3 - Alpha", 17 | "Intended Audience :: Developers", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming 
Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | ] 23 | dependencies = [ 24 | "ydb>=3.21.0", 25 | "mcp>=1.6.0", 26 | ] 27 | 28 | [project.optional-dependencies] 29 | dev = [ 30 | "pytest>=7.3.1", 31 | "pytest-asyncio>=0.21.0", 32 | "pytest-cov>=4.1.0", 33 | "pytest-assume>=2.4.3", 34 | "mypy>=1.3.0", 35 | "ruff>=0.11.0", 36 | "docker>=7.0.0", 37 | ] 38 | 39 | [project.scripts] 40 | ydb-mcp = "ydb_mcp.__main__:main" 41 | 42 | [tool.ruff] 43 | line-length = 121 44 | target-version = "py310" 45 | 46 | [tool.ruff.lint] 47 | select = [ 48 | "E", # pycodestyle 49 | "F", # pyflakes 50 | "I", # isort 51 | # TODO: extend with more rules 52 | ] 53 | 54 | [tool.mypy] 55 | python_version = "3.10" 56 | warn_return_any = true 57 | warn_unused_configs = true 58 | 59 | [[tool.mypy.overrides]] 60 | module = "ydb.*" 61 | ignore_missing_imports = true -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | python_files = test_*.py 4 | python_classes = Test* 5 | python_functions = test_* 6 | 7 | markers = 8 | unit: mark a test as a unit test 9 | integration: mark a test as an integration test 10 | 11 | # Log configuration 12 | log_cli = True 13 | log_cli_level = INFO 14 | log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s) 15 | log_cli_date_format = %Y-%m-%d %H:%M:%S 16 | 17 | # Configure asyncio mode 18 | asyncio_mode = auto 19 | 20 | # Filter warnings 21 | filterwarnings = 22 | ignore::DeprecationWarning:ydb.types: 23 | ignore::RuntimeWarning:asyncio: 24 | ignore::RuntimeWarning: 25 | ignore::RuntimeWarning:ydb_mcp.patches: 26 | ignore:Task was destroyed but it is pending:RuntimeWarning:asyncio.base_events 27 | ignore:Task was destroyed but it is pending:UserWarning 28 | ignore:Task was destroyed but it is pending 29 | ignore:Error handling discovery 
task:RuntimeWarning:tests.integration.conftest 30 | ignore:Error stopping driver:RuntimeWarning:tests.integration.conftest 31 | ignore:.*Task was destroyed but it is pending.*:RuntimeWarning 32 | ignore:.*Task was destroyed but it is pending.*:UserWarning 33 | ignore:.*Task was destroyed but it is pending.* 34 | 35 | addopts = --cov=ydb_mcp --cov-report=term-missing --cov-report=xml --cov-report=html --no-cov-on-fail -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | pytest>=7.3.1 3 | pytest-asyncio>=0.21.0 4 | pytest-cov>=4.1.0 5 | pytest-assume>=2.4.3 6 | mypy>=1.3.0 7 | ruff>=0.11.0 8 | docker>=7.0.0 # For YDB Docker container management in tests -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ydb>=3.21.0 2 | mcp>=1.6.0 -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for MCP YDB package.""" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest configuration for testing YDB MCP server.""" 2 | 3 | import os 4 | from unittest.mock import AsyncMock, patch 5 | 6 | import pytest 7 | 8 | 9 | @pytest.fixture 10 | def mock_ydb_driver(): 11 | """Mock YDB driver.""" 12 | with patch("ydb.aio.Driver") as mock_driver_class: 13 | # Setup the mock driver instance 14 | mock_driver = AsyncMock() 15 | mock_driver_class.return_value = mock_driver 16 | 17 | # Mock the wait method 18 | mock_driver.wait = AsyncMock() 19 | 20 | yield mock_driver 21 | 22 | 23 | @pytest.fixture 24 | def mock_ydb_pool(): 25 | """Mock YDB 
session pool.""" 26 | with patch("ydb.aio.QuerySessionPool") as mock_pool_class: 27 | # Setup the mock pool instance 28 | mock_pool = AsyncMock() 29 | mock_pool_class.return_value = mock_pool 30 | 31 | # Mock the execute_with_retries method 32 | mock_pool.execute_with_retries = AsyncMock() 33 | 34 | yield mock_pool 35 | 36 | 37 | @pytest.fixture 38 | def mock_env_vars(): 39 | """Mock environment variables.""" 40 | env_vars = { 41 | "YDB_ENDPOINT": "mock-endpoint", 42 | "YDB_DATABASE": "mock-database", 43 | "YDB_ANONYMOUS_CREDENTIALS": "1", 44 | } 45 | 46 | with patch.dict(os.environ, env_vars): 47 | yield env_vars 48 | -------------------------------------------------------------------------------- /tests/docker_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socket 4 | import time 5 | 6 | import docker 7 | import pytest 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def get_docker_client(): 13 | """Connect to Docker daemon via multiple methods or fail.""" 14 | connection_errors = [] 15 | # Method 1: default environment 16 | try: 17 | client = docker.from_env() 18 | client.ping() 19 | return client 20 | except Exception as e: 21 | connection_errors.append(f"default: {e}") 22 | # Method 2: DOCKER_HOST 23 | docker_host = os.getenv("DOCKER_HOST") 24 | if docker_host: 25 | try: 26 | client = docker.DockerClient(base_url=docker_host) 27 | client.ping() 28 | return client 29 | except Exception as e: 30 | connection_errors.append(f"DOCKER_HOST: {e}") 31 | # Method 3: common Unix sockets 32 | socket_paths = [ 33 | "unix:///var/run/docker.sock", 34 | "unix://" + os.path.expanduser("~/.docker/run/docker.sock"), 35 | "unix://" + os.path.expanduser("~/.colima/default/docker.sock"), 36 | ] 37 | for sp in socket_paths: 38 | try: 39 | client = docker.DockerClient(base_url=sp) 40 | client.ping() 41 | return client 42 | except Exception as e: 43 | connection_errors.append(f"{sp}: 
{e}") 44 | # All methods failed 45 | logger.error("Docker connection errors:\n%s", "\n".join(connection_errors)) 46 | pytest.fail("Could not connect to Docker. Make sure Docker daemon is running.") 47 | 48 | 49 | def start_container(image: str, **kwargs): 50 | """Pull and run a Docker container with given parameters.""" 51 | client = get_docker_client() 52 | # Pull image 53 | client.images.pull(image) 54 | # Run container 55 | container = client.containers.run(image=image, **kwargs) 56 | return container 57 | 58 | 59 | def stop_container(container): 60 | """Stop the given Docker container.""" 61 | try: 62 | container.stop(timeout=1) 63 | except Exception: 64 | logger.warning("Error stopping container %s", container) 65 | 66 | 67 | def wait_for_port(host: str, port: int, timeout: int = 30): 68 | """Wait until given host:port is accepting connections or fail.""" 69 | for _ in range(timeout): 70 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 71 | sock.settimeout(1) 72 | try: 73 | if sock.connect_ex((host, port)) == 0: 74 | return 75 | except Exception: 76 | pass 77 | time.sleep(1) 78 | pytest.fail(f"Port {host}:{port} not ready after {timeout}s") 79 | 80 | 81 | def start_ydb_container(): 82 | """Start a YDB Docker container for integration tests.""" 83 | # YDB image and ports configuration 84 | image = "ydbplatform/local-ydb:latest" 85 | env = { 86 | "GRPC_TLS_PORT": "2135", 87 | "GRPC_PORT": "2136", 88 | "MON_PORT": "8765", 89 | "YDB_KAFKA_PROXY_PORT": "9092", 90 | "YDB_USE_IN_MEMORY_PDISKS": "1", 91 | } 92 | ports = {"2135/tcp": 2135, "2136/tcp": 2136, "8765/tcp": 8765, "9092/tcp": 9092} 93 | container = start_container( 94 | image=image, 95 | detach=True, 96 | remove=True, 97 | hostname="localhost", 98 | platform="linux/amd64", 99 | environment=env, 100 | ports=ports, 101 | ) 102 | return container 103 | 104 | 105 | def start_ollama_container(): 106 | """Start an Ollama Docker container for integration tests by pulling `llama2` and serving it.""" 
107 | image = "ollama/ollama:latest" 108 | client = get_docker_client() 109 | # Pull the Ollama image for linux/amd64 110 | client.images.pull(image, platform="linux/amd64") 111 | # Combine model pull and serve in one container to ensure the model is available 112 | shell_cmd = "ollama pull llama2 && exec ollama serve --http-port 11434 --http-address 0.0.0.0" 113 | container = client.containers.run( 114 | image=image, 115 | command=["sh", "-c", shell_cmd], 116 | detach=True, 117 | remove=True, 118 | ports={"11434/tcp": 11434}, 119 | platform="linux/amd64", 120 | ) 121 | return container 122 | -------------------------------------------------------------------------------- /tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | """Common fixtures for YDB MCP integration tests. 2 | 3 | This module provides shared fixtures for all integration tests, including 4 | automatic Docker container management for YDB. 5 | """ 6 | 7 | import asyncio 8 | import gc 9 | import json 10 | import logging 11 | import os 12 | import socket 13 | import time 14 | from contextlib import suppress 15 | from urllib.parse import urlparse 16 | 17 | import pytest 18 | 19 | from tests.docker_utils import start_ydb_container, stop_container, wait_for_port 20 | from ydb_mcp.server import AUTH_MODE_ANONYMOUS, YDBMCPServer 21 | 22 | # Configuration for the tests 23 | YDB_ENDPOINT = os.environ.get("YDB_ENDPOINT", "grpc://localhost:2136/local") 24 | # Database will be extracted from the endpoint if not explicitly provided 25 | YDB_DATABASE = os.environ.get("YDB_DATABASE") 26 | 27 | # Set up logging 28 | logging.basicConfig(level=logging.WARNING) # Set default level to WARNING 29 | logger = logging.getLogger(__name__) 30 | logger.setLevel(logging.WARNING) # Set test logger to WARNING 31 | 32 | # Set specific loggers to appropriate levels 33 | ydb_logger = logging.getLogger("ydb") 34 | ydb_logger.setLevel(logging.ERROR) # Raise YDB logger 
level to ERROR 35 | 36 | # Keep server startup/shutdown and critical error logs at INFO/ERROR level 37 | server_logger = logging.getLogger("ydb_mcp.server") 38 | server_logger.setLevel(logging.ERROR) # Raise server logger level to ERROR 39 | 40 | # Set asyncio logger to ERROR to suppress task destruction messages 41 | asyncio_logger = logging.getLogger("asyncio") 42 | asyncio_logger.setLevel(logging.ERROR) 43 | 44 | 45 | async def cleanup_pending_tasks(): 46 | """Clean up any pending tasks in the current event loop.""" 47 | try: 48 | loop = asyncio.get_running_loop() 49 | except RuntimeError: 50 | # No running event loop 51 | return 52 | 53 | # Get all pending tasks except the current one 54 | current = asyncio.current_task(loop) 55 | pending = [task for task in asyncio.all_tasks(loop) if not task.done() and task is not current] 56 | 57 | # Explicitly suppress destroy pending warning for YDB Discovery.run tasks 58 | for task in pending: 59 | coro = getattr(task, "get_coro", lambda: None)() 60 | if coro and "Discovery.run" in repr(coro): 61 | task._log_destroy_pending = False 62 | 63 | if not pending: 64 | return 65 | 66 | logger.debug(f"Cleaning up {len(pending)} pending tasks") 67 | 68 | # Cancel all pending tasks 69 | for task in pending: 70 | if not task.done() and not task.cancelled(): 71 | # Disable the destroy pending warning for this task 72 | task._log_destroy_pending = False 73 | task.cancel() 74 | 75 | try: 76 | # Wait for tasks to cancel with a timeout, using shield to prevent cancellation 77 | await asyncio.shield(asyncio.wait(pending, timeout=0.1)) 78 | except Exception as e: 79 | logger.debug(f"Error waiting for tasks to cancel: {e}") 80 | 81 | # Force cancel any remaining tasks 82 | still_pending = [t for t in pending if not t.done()] 83 | if still_pending: 84 | logger.debug(f"Force cancelling {len(still_pending)} tasks that did not cancel properly") 85 | for task in still_pending: 86 | # Ensure the task won't log warnings when destroyed 87 | 
task._log_destroy_pending = False 88 | # Force cancel and suppress any errors 89 | with suppress(asyncio.CancelledError, Exception): 90 | task.cancel() 91 | try: 92 | await asyncio.shield(asyncio.wait_for(task, timeout=0.1)) 93 | except asyncio.TimeoutError: 94 | pass 95 | 96 | 97 | async def cleanup_driver(driver, timeout=1.0): 98 | """Clean up the driver and any associated tasks.""" 99 | if not driver: 100 | return 101 | 102 | try: 103 | # First handle discovery task if it exists 104 | if hasattr(driver, "_discovery") and driver._discovery: 105 | logger.debug("Handling discovery task") 106 | try: 107 | # Try to stop discovery gracefully first 108 | if hasattr(driver._discovery, "stop"): 109 | driver._discovery.stop() 110 | 111 | # Then cancel the task if it exists and is still running 112 | if hasattr(driver._discovery, "_discovery_task"): 113 | task = driver._discovery._discovery_task 114 | if task and not task.done() and not task.cancelled(): 115 | task._log_destroy_pending = False 116 | task.cancel() 117 | try: 118 | await asyncio.shield(asyncio.wait_for(task, timeout=0.1)) 119 | except (asyncio.CancelledError, asyncio.TimeoutError, Exception): 120 | pass 121 | except Exception as e: 122 | logger.debug(f"Error handling discovery task: {e}") 123 | 124 | # Stop the driver with proper error handling 125 | logger.debug("Stopping driver") 126 | try: 127 | # Use shield to prevent cancellation of the stop operation 128 | await asyncio.shield(asyncio.wait_for(driver.stop(), timeout=timeout)) 129 | except asyncio.TimeoutError: 130 | logger.debug(f"Driver stop timed out after {timeout} seconds") 131 | except asyncio.CancelledError: 132 | logger.debug("Driver stop was cancelled") 133 | except Exception as e: 134 | logger.debug(f"Error stopping driver: {e}") 135 | 136 | finally: 137 | # Clean up any remaining tasks 138 | await cleanup_pending_tasks() 139 | 140 | 141 | def ensure_event_loop(): 142 | """Ensure we have a valid event loop and return it.""" 143 | try: 144 | 
loop = asyncio.get_running_loop() 145 | if loop.is_closed(): 146 | loop = asyncio.new_event_loop() 147 | asyncio.set_event_loop(loop) 148 | except RuntimeError: 149 | loop = asyncio.new_event_loop() 150 | asyncio.set_event_loop(loop) 151 | return loop 152 | 153 | 154 | def is_port_open(host, port): 155 | """Check if a port is open on the given host.""" 156 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 157 | s.settimeout(1) 158 | try: 159 | return s.connect_ex((host, port)) == 0 160 | except (socket.gaierror, ConnectionRefusedError, OSError): 161 | return False 162 | 163 | 164 | @pytest.fixture(scope="session") 165 | def ydb_server(): 166 | """ 167 | Fixture to ensure YDB server is running. 168 | If YDB_ENDPOINT is not available, it starts a Docker container. 169 | """ 170 | # Parse the endpoint to extract host and port 171 | endpoint_url = urlparse(YDB_ENDPOINT) 172 | 173 | # Handle different endpoint formats 174 | if endpoint_url.scheme in ("grpc", "grpcs"): 175 | host_port = endpoint_url.netloc.split(":") 176 | host = host_port[0] 177 | port = int(host_port[1]) if len(host_port) > 1 else 2136 178 | else: 179 | # Default to localhost:2136 if we can't parse 180 | host = "localhost" 181 | port = 2136 182 | 183 | # Check if YDB is already running at the specified endpoint 184 | if is_port_open(host, port): 185 | logger.info(f"YDB server is already running at {host}:{port}") 186 | yield None 187 | return 188 | 189 | # If YDB is not running, start via docker_utils 190 | logger.info(f"YDB server not running at {host}:{port}, starting Docker container") 191 | container = start_ydb_container() 192 | # Wait for YDB readiness 193 | wait_for_port(host, port, timeout=30) 194 | time.sleep(5) 195 | yield container 196 | logger.info("Stopping YDB Docker container") 197 | stop_container(container) 198 | 199 | 200 | @pytest.fixture(scope="session") 201 | async def mcp_server(ydb_server): 202 | """Create a YDB MCP server instance for testing.""" 203 | # Create the 
server with anonymous credentials 204 | server = YDBMCPServer(endpoint=YDB_ENDPOINT, database=YDB_DATABASE) 205 | 206 | # Store the event loop 207 | server._loop = ensure_event_loop() 208 | 209 | try: 210 | # Initialize the server by creating the driver 211 | await server.create_driver() 212 | yield server 213 | 214 | # Clean up after tests 215 | logger.info("Cleaning up YDB server resources after tests") 216 | await cleanup_pending_tasks() 217 | if server.driver: 218 | await cleanup_driver(server.driver) 219 | 220 | except Exception as e: 221 | logger.error(f"Failed to initialize YDB MCP server: {e}") 222 | pytest.fail(f"Failed to initialize YDB MCP server: {e}") 223 | finally: 224 | # Final cleanup 225 | await cleanup_pending_tasks() 226 | 227 | 228 | # Create a global variable to cache the server instance 229 | _mcp_server_instance = None 230 | 231 | 232 | @pytest.fixture(scope="session") 233 | async def session_mcp_server(ydb_server): 234 | """Create a YDB MCP server instance once per test session and cache it.""" 235 | global _mcp_server_instance 236 | 237 | if _mcp_server_instance is None: 238 | # Create the server with anonymous credentials 239 | _mcp_server_instance = YDBMCPServer( 240 | endpoint=YDB_ENDPOINT, database=YDB_DATABASE, auth_mode=AUTH_MODE_ANONYMOUS 241 | ) 242 | 243 | try: 244 | # Ensure we have a valid event loop 245 | _mcp_server_instance._loop = ensure_event_loop() 246 | 247 | # Initialize the server by creating the driver 248 | await _mcp_server_instance.create_driver() 249 | except Exception as e: 250 | logger.error(f"Failed to initialize YDB MCP server: {e}") 251 | pytest.fail(f"Failed to initialize YDB MCP server: {e}") 252 | yield None 253 | return 254 | 255 | yield _mcp_server_instance 256 | 257 | # Clean up after all tests 258 | if _mcp_server_instance is not None: 259 | logger.info("Cleaning up YDB server resources after test session") 260 | try: 261 | # Clean up pending tasks first 262 | await cleanup_pending_tasks() 263 | 264 | # 
Clean up the driver with extended timeout 265 | if _mcp_server_instance.driver: 266 | await cleanup_driver(_mcp_server_instance.driver, timeout=10) 267 | 268 | # Clear the instance 269 | _mcp_server_instance = None 270 | 271 | # Force garbage collection to help clean up any remaining references 272 | gc.collect() 273 | except Exception as e: 274 | logger.error(f"Error during session cleanup: {e}") 275 | finally: 276 | # Final cleanup attempt for any remaining tasks 277 | await cleanup_pending_tasks() 278 | 279 | 280 | @pytest.fixture(scope="function") 281 | async def mcp_server(session_mcp_server): # noqa: F811 282 | """Provide a clean MCP server connection for each test by restarting the connection.""" 283 | if session_mcp_server is None: 284 | pytest.fail("Could not get a valid MCP server instance") 285 | return 286 | 287 | # Reset server state to default 288 | session_mcp_server.auth_mode = AUTH_MODE_ANONYMOUS 289 | session_mcp_server.login = None 290 | session_mcp_server.password = None 291 | 292 | try: 293 | # Clean up any leftover tasks before restart 294 | await cleanup_pending_tasks() 295 | 296 | # Restart the connection to ensure clean environment for the test 297 | if session_mcp_server.driver is not None: 298 | logger.info("Restarting YDB connection for clean test environment") 299 | await session_mcp_server.restart() 300 | 301 | yield session_mcp_server 302 | 303 | except Exception as e: 304 | logger.error(f"Error during test setup: {e}") 305 | pytest.fail(f"Failed to setup test environment: {e}") 306 | finally: 307 | # Reset server state after test 308 | try: 309 | session_mcp_server.auth_mode = AUTH_MODE_ANONYMOUS 310 | session_mcp_server.login = None 311 | session_mcp_server.password = None 312 | 313 | # Clean up any tasks from the test 314 | await cleanup_pending_tasks() 315 | 316 | # Restart to clean state 317 | await session_mcp_server.restart() 318 | except Exception as e: 319 | logger.error(f"Error during test cleanup: {e}") 320 | 321 | 322 | 
async def call_mcp_tool(mcp_server, tool_name, **params): 323 | """Helper function to call an MCP tool and return its result in JSON format. 324 | 325 | Args: 326 | mcp_server: The MCP server instance 327 | tool_name: Name of the tool to call 328 | **params: Parameters to pass to the tool 329 | 330 | Returns: 331 | The parsed result from the tool call 332 | """ 333 | # Call the tool 334 | result = await mcp_server.call_tool(tool_name, params) 335 | 336 | # If the result is a list of TextContent objects, convert them to a more usable format 337 | if isinstance(result, list) and len(result) > 0 and hasattr(result[0], "text"): 338 | try: 339 | # Parse the JSON text from the TextContent 340 | parsed_result = json.loads(result[0].text) 341 | 342 | # For backward compatibility with tests, if there's an error key, return it directly 343 | if "error" in parsed_result: 344 | return parsed_result 345 | 346 | # For query results, return the result_sets directly if present 347 | if "result_sets" in parsed_result: 348 | return parsed_result 349 | 350 | # For other responses (list_directory, describe_path), return the parsed JSON 351 | return parsed_result 352 | 353 | except json.JSONDecodeError as e: 354 | logger.error(f"Failed to parse JSON response: {e}") 355 | return {"error": str(e)} 356 | 357 | return result 358 | 359 | 360 | @pytest.fixture(autouse=True, scope="session") 361 | async def cleanup_after_all_tests(): 362 | """Cleanup fixture that runs after all tests to ensure proper cleanup.""" 363 | # Setup - nothing to do 364 | yield 365 | 366 | # Cleanup after all tests 367 | await cleanup_pending_tasks() 368 | 369 | # Close any remaining event loops 370 | try: 371 | loop = asyncio.get_running_loop() 372 | if not loop.is_closed(): 373 | # Cancel all tasks 374 | pending = [ 375 | task 376 | for task in asyncio.all_tasks(loop) 377 | if not task.done() and task != asyncio.current_task() 378 | ] 379 | 380 | if pending: 381 | logger.debug(f"Cleaning up {len(pending)} pending 
tasks in final cleanup") 382 | for task in pending: 383 | if not task.done() and not task.cancelled(): 384 | task.cancel() 385 | with suppress(asyncio.CancelledError, Exception): 386 | # Add a timeout to avoid hanging 387 | try: 388 | await asyncio.wait_for(task, timeout=1.0) 389 | except asyncio.TimeoutError: 390 | pass 391 | 392 | # Ensure all tasks are truly done 393 | for task in pending: 394 | if not task.done(): 395 | with suppress(asyncio.CancelledError, Exception): 396 | task._log_destroy_pending = ( 397 | False # Suppress the warning about task destruction 398 | ) 399 | 400 | # Close the loop 401 | loop.stop() 402 | loop.close() 403 | except RuntimeError: 404 | pass # No running event loop 405 | -------------------------------------------------------------------------------- /tests/integration/test_authentication_integration.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import os 5 | import random 6 | import string 7 | import time 8 | import warnings 9 | 10 | import pytest 11 | 12 | # Fixtures are automatically imported by pytest from conftest.py 13 | from tests.integration.conftest import call_mcp_tool 14 | from ydb_mcp.server import AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD 15 | 16 | # Suppress the utcfromtimestamp deprecation warning from the YDB library 17 | warnings.filterwarnings("ignore", message="datetime.datetime.utcfromtimestamp.*", category=DeprecationWarning) 18 | 19 | # Table name used for tests - using timestamp to avoid conflicts 20 | TEST_TABLE = f"mcp_integration_test_{int(time.time())}" 21 | 22 | # Set up logging 23 | logging.basicConfig(level=logging.INFO) 24 | logger = logging.getLogger(__name__) 25 | 26 | # Use loop_scope instead of scope for the asyncio marker 27 | pytestmark = [pytest.mark.integration, pytest.mark.asyncio] 28 | 29 | 30 | async def test_login_password_authentication(mcp_server): 31 | """Test authentication with login 
and password.""" 32 | # Generate random login-password pair 33 | test_login = "test" + "".join(random.choice(string.ascii_lowercase) for _ in range(10)) 34 | test_password = f"test_pwd_{os.urandom(4).hex()}" 35 | wrong_password = f"wrong_pwd_{os.urandom(4).hex()}" 36 | 37 | try: 38 | # Create test user with anonymous auth (fixture ensures we start with anonymous auth) 39 | logger.debug(f"Creating test user {test_login}") 40 | result = await call_mcp_tool( 41 | mcp_server, "ydb_query", sql=f"CREATE USER {test_login} PASSWORD '{test_password}';" 42 | ) 43 | assert "error" not in result, f"Error creating user: {result}" 44 | 45 | # Test with correct credentials 46 | logger.debug(f"Testing with correct credentials for user {test_login}") 47 | mcp_server.auth_mode = AUTH_MODE_LOGIN_PASSWORD 48 | mcp_server.login = test_login 49 | mcp_server.password = test_password 50 | 51 | # Wait a bit for user creation to propagate 52 | await asyncio.sleep(1) 53 | 54 | await mcp_server.restart() 55 | 56 | # Verify we can execute a query 57 | result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1+1 as result") 58 | # Parse the JSON from the 'text' field if present 59 | if isinstance(result, list) and len(result) > 0 and isinstance(result[0], dict) and "text" in result[0]: 60 | parsed = json.loads(result[0]["text"]) 61 | else: 62 | parsed = result 63 | assert "result_sets" in parsed, f"No result_sets in response: {result}" 64 | assert parsed["result_sets"][0]["rows"][0][0] == 2, f"Unexpected result value: {result}" 65 | 66 | # Test with incorrect password 67 | logger.debug(f"Testing with incorrect password for user {test_login}") 68 | mcp_server.password = wrong_password 69 | 70 | # Restart should succeed but queries should fail 71 | await mcp_server.restart() 72 | 73 | # Query should fail with auth error 74 | result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1+1 as result") 75 | # Parse the JSON from the 'text' field if present 76 | if isinstance(result, 
list) and len(result) > 0 and isinstance(result[0], dict) and "text" in result[0]: 77 | parsed = json.loads(result[0]["text"]) 78 | else: 79 | parsed = result 80 | assert "error" in parsed, f"Expected error with invalid password, got: {parsed}" 81 | 82 | error_msg = parsed.get("error", "").lower() 83 | logger.debug(f"Got error message: {error_msg}") 84 | 85 | # Check for both connection and auth error messages since YDB might return either 86 | auth_keywords = [ 87 | "auth", 88 | "password", 89 | "login", 90 | "credential", 91 | "permission", 92 | "unauthorized", 93 | "invalid", 94 | ] 95 | conn_keywords = ["connecting to ydb", "error connecting", "connection failed"] 96 | all_keywords = auth_keywords + conn_keywords 97 | 98 | if error_msg.strip() == "": 99 | # Allow empty error message as valid 100 | pass 101 | else: 102 | assert any(keyword in error_msg for keyword in all_keywords), ( 103 | f"Unexpected error message: {parsed.get('error')}" 104 | ) 105 | 106 | finally: 107 | # Switch back to anonymous auth to clean up (fixture will handle final state reset) 108 | logger.debug(f"Cleaning up - dropping test user {test_login}") 109 | mcp_server.auth_mode = AUTH_MODE_ANONYMOUS 110 | mcp_server.login = None 111 | mcp_server.password = None 112 | 113 | await mcp_server.restart() 114 | 115 | # Drop the test user 116 | result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP USER {test_login};") 117 | if "error" in result: 118 | logger.error(f"Error dropping user: {result}") 119 | -------------------------------------------------------------------------------- /tests/integration/test_mcp_server_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for YDB MCP server. 2 | 3 | These tests assume a YDB server is running at localhost:2136 with a database /local, 4 | or will create a Docker container with YDB if none is available. 5 | They directly test the YDBMCPServer methods without using HTTP. 
6 | """ 7 | 8 | import datetime 9 | import json 10 | import logging 11 | import time 12 | import warnings 13 | 14 | import pytest 15 | 16 | # Fixtures are automatically imported by pytest from conftest.py 17 | from tests.integration.conftest import call_mcp_tool 18 | 19 | # Suppress the utcfromtimestamp deprecation warning from the YDB library 20 | warnings.filterwarnings("ignore", message="datetime.datetime.utcfromtimestamp.*", category=DeprecationWarning) 21 | 22 | # Table name used for tests - using timestamp to avoid conflicts 23 | TEST_TABLE = f"mcp_integration_test_{int(time.time())}" 24 | 25 | # Set up logging 26 | logging.basicConfig(level=logging.INFO) 27 | logger = logging.getLogger(__name__) 28 | 29 | # Use loop_scope instead of scope for the asyncio marker 30 | pytestmark = [pytest.mark.integration, pytest.mark.asyncio] 31 | 32 | 33 | async def test_simple_query(mcp_server): 34 | """Test a basic YDB query.""" 35 | result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1+1 as result") 36 | 37 | assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( 38 | f"Result should be a list of dicts with 'text': {result}" 39 | ) 40 | parsed = json.loads(result[0]["text"]) 41 | assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" 42 | 43 | assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" 44 | 45 | first_result = parsed["result_sets"][0] 46 | assert "columns" in first_result, f"No columns in result: {first_result}" 47 | assert "rows" in first_result, f"No rows in result: {first_result}" 48 | assert len(first_result["rows"]) > 0, f"Empty result set: {first_result}" 49 | assert first_result["columns"][0] == "result", f"Unexpected column name: {first_result['columns'][0]}" 50 | assert first_result["rows"][0][0] == 2, f"Unexpected result value: {first_result['rows'][0][0]}" 51 | 52 | 53 | async def test_create_table_and_query(mcp_server): 54 | """Test creating a 
table and executing queries against it.""" 55 | # Generate a unique table name to avoid conflicts with other tests 56 | test_table_name = f"temp_test_table_{int(time.time())}" 57 | 58 | try: 59 | # Create table 60 | create_result = await call_mcp_tool( 61 | mcp_server, 62 | "ydb_query", 63 | sql=f""" 64 | CREATE TABLE {test_table_name} ( 65 | id Uint64, 66 | name Utf8, 67 | PRIMARY KEY (id) 68 | ); 69 | """, 70 | ) 71 | assert isinstance(create_result, list) and len(create_result) > 0 and "text" in create_result[0], ( 72 | f"Result should be a list of dicts with 'text': {create_result}" 73 | ) 74 | parsed = json.loads(create_result[0]["text"]) 75 | assert "error" not in parsed, f"Error creating table: {parsed}" 76 | 77 | # Insert data 78 | insert_result = await call_mcp_tool( 79 | mcp_server, 80 | "ydb_query", 81 | sql=f""" 82 | UPSERT INTO {test_table_name} (id, name) 83 | VALUES (1, 'Test 1'), (2, 'Test 2'), (3, 'Test 3'); 84 | """, 85 | ) 86 | assert isinstance(insert_result, list) and len(insert_result) > 0 and "text" in insert_result[0], ( 87 | f"Result should be a list of dicts with 'text': {insert_result}" 88 | ) 89 | parsed = json.loads(insert_result[0]["text"]) 90 | assert "error" not in parsed, f"Error inserting data: {parsed}" 91 | 92 | # Query data 93 | query_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"SELECT * FROM {test_table_name} ORDER BY id;") 94 | 95 | assert isinstance(query_result, list) and len(query_result) > 0 and "text" in query_result[0], ( 96 | f"Result should be a list of dicts with 'text': {query_result}" 97 | ) 98 | parsed = json.loads(query_result[0]["text"]) 99 | assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" 100 | 101 | assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" 102 | 103 | first_result = parsed["result_sets"][0] 104 | assert "columns" in first_result, f"No columns in result: {first_result}" 105 | assert "rows" in first_result, f"No 
rows in result: {first_result}"
        assert len(first_result["rows"]) == 3, f"Expected 3 rows, got {len(first_result['rows'])}"

        # Check if 'id' and 'name' columns are present
        assert "id" in first_result["columns"], f"Column 'id' not found in {first_result['columns']}"
        assert "name" in first_result["columns"], f"Column 'name' not found in {first_result['columns']}"

        # Get column indexes
        id_idx = first_result["columns"].index("id")
        name_idx = first_result["columns"].index("name")

        # Verify values
        assert first_result["rows"][0][id_idx] == 1, f"Expected id=1, got {first_result['rows'][0][id_idx]}"

        # YDB may return strings as bytes, so handle both cases
        name_value = first_result["rows"][0][name_idx]
        if isinstance(name_value, bytes):
            assert name_value.decode("utf-8") == "Test 1", f"Expected name='Test 1', got {name_value.decode('utf-8')}"
        else:
            assert name_value == "Test 1", f"Expected name='Test 1', got {name_value}"

    finally:
        # Cleanup - drop the table after test
        # NOTE(review): cleanup runs even when the assertions above fail, so the
        # uniquely-named temp table never leaks between test runs.
        cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};")
        logger.debug(f"Table cleanup result: {cleanup_result}")


async def test_parameterized_query(mcp_server):
    """Test a parameterized query using the parameters feature of YDB.

    Exercises the ydb_query_with_params tool: DECLAREd YQL parameters are
    supplied as a JSON-encoded dict, where a [value, "TypeName"] pair pins
    the YDB type explicitly and a bare value lets the server infer it.
    """
    # Test with a simple parameterized query
    result = await call_mcp_tool(
        mcp_server,
        "ydb_query_with_params",
        sql="""
        DECLARE $answer AS Int32;
        DECLARE $greeting AS Utf8;
        SELECT $answer as answer, $greeting as greeting
        """,
        params=json.dumps(
            {"answer": [-42, "Int32"], "greeting": "hello"}  # Explicitly specify the type as Int32
        ),
    )

    # The tool returns a list of content dicts; the JSON payload lives in the
    # first item's "text" field.
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])
    assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}"

    assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}"

    first_result = parsed["result_sets"][0]
    assert "columns" in first_result, f"No columns in result: {first_result}"
    assert "rows" in first_result, f"No rows in result: {first_result}"
    assert len(first_result["rows"]) > 0, f"Empty result set: {first_result}"

    # Check column names
    assert "answer" in first_result["columns"], f"Expected 'answer' column in result: {first_result['columns']}"
    assert "greeting" in first_result["columns"], f"Expected 'greeting' column in result: {first_result['columns']}"

    # Check values
    answer_idx = first_result["columns"].index("answer")
    greeting_idx = first_result["columns"].index("greeting")

    assert first_result["rows"][0][answer_idx] == -42, f"Expected answer=-42, got {first_result['rows'][0][answer_idx]}"

    # YDB may return strings either as bytes or as strings depending on context
    greeting_value = first_result["rows"][0][greeting_idx]
    if isinstance(greeting_value, bytes):
        # If bytes, decode to string
        assert greeting_value.decode("utf-8") == "hello", (
            f"Expected greeting to decode to 'hello', got {greeting_value.decode('utf-8')}"
        )
    else:
        # If already string
        assert greeting_value == "hello", f"Expected greeting to be 'hello', got {greeting_value}"


async def test_complex_query(mcp_server):
    """Test a more complex query with multiple result sets.

    Two SELECT statements in one request must come back as two entries in
    the "result_sets" list, in statement order.
    """
    result = await call_mcp_tool(
        mcp_server,
        "ydb_query",
        sql="""
        SELECT 1 as value;
        SELECT 'test' as text, 2.5 as number;
        """,
    )

    # Check for result sets
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])
    assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}"

    # Verify first result set
    first_result = parsed["result_sets"][0]
    assert "columns" in first_result
    assert first_result["columns"][0] == "value"
    assert first_result["rows"][0][0] == 1

    # Verify second result set
    second_result = parsed["result_sets"][1]
    assert "columns" in second_result
    assert len(second_result["columns"]) == 2
    assert second_result["columns"][0] == "text"
    assert second_result["columns"][1] == "number"
    # The text value is returned as a binary string, need to handle this
    text_value = second_result["rows"][0][0]
    if isinstance(text_value, bytes):
        assert text_value.decode("utf-8") == "test"
    else:
        assert text_value == "test"
    assert second_result["rows"][0][1] == 2.5


async def test_multiple_resultsets_with_tables(mcp_server):
    """Test multiple result sets involving table creation and queries."""
    # Create unique table names with timestamp to avoid conflicts
    test_table1 = f"temp_test_table1_{int(time.time())}"
    test_table2 = f"temp_test_table2_{int(time.time())}"

    try:
        # First, create tables - schema operations need to be separate
        setup_result = await call_mcp_tool(
            mcp_server,
            "ydb_query",
            sql=f"""
            CREATE TABLE {test_table1} (id Uint64, name Utf8, PRIMARY KEY (id));
            CREATE TABLE {test_table2} (id Uint64, value Double, PRIMARY KEY (id));
            """,
        )
        assert isinstance(setup_result, list) and len(setup_result) > 0 and "text" in setup_result[0], (
            f"Result should be a list of dicts with 'text': {setup_result}"
        )
        parsed = json.loads(setup_result[0]["text"])
        assert "error" not in parsed, f"Error creating tables: {parsed}"
        # Then insert data - separate operation
        insert_result = await call_mcp_tool(
            mcp_server,
            "ydb_query",
            sql=f"""
            UPSERT INTO {test_table1} (id, name) VALUES (1, 'First'), (2, 'Second'), (3, 'Third');
            """,
        )
        assert isinstance(insert_result, list) and len(insert_result) > 0 and "text" in insert_result[0], (
            f"Result should be a list of dicts with 'text': {insert_result}"
        )
        parsed = json.loads(insert_result[0]["text"])
        assert "error" not in parsed, f"Error inserting data into table1: {parsed}"

        insert_result2 = await call_mcp_tool(
            mcp_server,
            "ydb_query",
            sql=f"""
            UPSERT INTO {test_table2} (id, value) VALUES (1, 10.5), (2, 20.75), (3, 30.25);
            """,
        )
        assert isinstance(insert_result2, list) and len(insert_result2) > 0 and "text" in insert_result2[0], (
            f"Result should be a list of dicts with 'text': {insert_result2}"
        )
        parsed = json.loads(insert_result2[0]["text"])
        assert "error" not in parsed, f"Error inserting data into table2: {parsed}"

        # Now query both tables in a single request to test multiple result sets
        result = await call_mcp_tool(
            mcp_server,
            "ydb_query",
            sql=f"""
            SELECT * FROM {test_table1} ORDER BY id;
            SELECT * FROM {test_table2} ORDER BY id;
            """,
        )

        # Verify we have all result sets
        assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
            f"Result should be a list of dicts with 'text': {result}"
        )
        parsed = json.loads(result[0]["text"])
        assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}"

        # Check first table results
        first_result = parsed["result_sets"][0]
        assert len(first_result["rows"]) == 3, "Expected 3 rows in first table"
        assert "id" in first_result["columns"], f"Expected 'id' column in first table, got {first_result['columns']}"
        assert "name" in first_result["columns"], f"Expected 'name' column in first table, got {first_result['columns']}"

        # Check second table results
        second_result = parsed["result_sets"][1]
        assert len(second_result["rows"]) == 3, "Expected 3 rows in second table"
        assert "id" in second_result["columns"], f"Expected 'id' column in second table, got {second_result['columns']}"
        assert "value" in second_result["columns"], (
            f"Expected 'value' column in second table, got {second_result['columns']}"
        )

        # Now test a join query - should return a single result set
        join_result = await call_mcp_tool(
            mcp_server,
            "ydb_query",
            sql=f"""
            SELECT t1.id, t1.name, t2.value
            FROM {test_table1} t1
            JOIN {test_table2} t2 ON t1.id = t2.id
            ORDER BY t1.id;
            """,
        )

        # Validate join results
        assert isinstance(join_result, list) and len(join_result) > 0 and "text" in join_result[0], (
            f"Result should be a list of dicts with 'text': {join_result}"
        )
        parsed = json.loads(join_result[0]["text"])
        assert "result_sets" in parsed, "Join query should return result_sets"
        assert len(parsed["result_sets"]) == 1, f"Expected 1 result set for join, got {len(parsed['result_sets'])}"

        first_join_result = parsed["result_sets"][0]
        assert "columns" in first_join_result, "Join query should return columns"
        assert "rows" in first_join_result, "Join query should return rows"
        assert len(first_join_result["rows"]) == 3, (
            f"Expected 3 rows in join result, got {len(first_join_result['rows'])}"
        )
        assert len(first_join_result["columns"]) == 3, (
            f"Expected 3 columns in join result, got {len(first_join_result['columns'])}"
        )

    finally:
        # Cleanup - drop the tables after test
        # NOTE(review): cleanup failures are logged but do not fail the test,
        # so an assertion failure above is not masked by a DROP error.
        try:
            cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table1};")
            logger.debug(f"Table1 cleanup result: {cleanup_result}")

            cleanup_result2 = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table2};")
            logger.debug(f"Table2 cleanup result: {cleanup_result2}")
        except Exception as e:
            logger.warning(f"Failed to clean up test tables: {e}")


async def test_single_resultset_format(mcp_server):
    """Test that single result set queries use the new format with result_sets list.

    Even a one-statement query must wrap its output in a one-element
    "result_sets" list rather than returning a bare result object.
    """
    result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 42 as answer")

    # Single result set should have result_sets key with one item
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])
    assert "result_sets" in parsed, "Single result should include result_sets key"
    assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}"

    first_result = parsed["result_sets"][0]
    assert "columns" in first_result, "Should have columns in result set"
    assert "rows" in first_result, "Should have rows in result set"
    assert first_result["columns"][0] == "answer", "Expected 'answer' column"
    assert first_result["rows"][0][0] == 42, "Expected value 42"


async def test_data_types(mcp_server):
    """Test a very basic query with simple parameter types."""

    # Execute a simple query that doesn't use complex parameter types
    result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1 AS value, 'test' AS text")

    # Basic result checks
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])
    assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}"

    assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}"

    first_result = parsed["result_sets"][0]
    assert "columns" in first_result, f"No columns in result: {first_result}"
    assert "rows" in first_result, f"No rows in result: {first_result}"
    assert len(first_result["columns"]) == 2, f"Expected 2 columns, got {len(first_result['columns'])}"
    assert len(first_result["rows"]) == 1, f"Expected 1 row, got {len(first_result['rows'])}"

    # Verify column names
    assert "value" in first_result["columns"], f"Expected column 'value', got {first_result['columns']}"
    assert "text" in first_result["columns"], f"Expected column 'text', got {first_result['columns']}"

    # Verify values
    row = first_result["rows"][0]
    value_idx = first_result["columns"].index("value")
    text_idx = first_result["columns"].index("text")

    assert row[value_idx] == 1, f"Expected value=1, got {row[value_idx]}"

    # Accept both bytes and string for text columns
    if isinstance(row[text_idx], bytes):
        assert row[text_idx] == b"test", f"Expected text=b'test', got {row[text_idx]}"
    else:
        assert row[text_idx] == "test", f"Expected text='test', got {row[text_idx]}"


async def test_all_data_types(mcp_server):
    """Test all supported YDB data types to ensure proper round-trip processing.

    Builds one SELECT of literals covering booleans, the full signed/unsigned
    integer ranges, floats, strings, date/time types, Decimal, and nested
    container types (List/Struct/Dict/Tuple), then checks each value after
    JSON round-trip through the MCP response.
    """

    # Construct a query with literals of all supported data types
    result = await call_mcp_tool(
        mcp_server,
        "ydb_query",
        sql="""
        SELECT
            -- Boolean type
            true AS bool_true,
            false AS bool_false,

            -- Integer types (signed)
            -128 AS int8_min,
            127 AS int8_max,
            -32768 AS int16_min,
            32767 AS int16_max,
            -2147483648 AS int32_min,
            2147483647 AS int32_max,
            -9223372036854775808 AS int64_min,
            9223372036854775807 AS int64_max,

            -- Integer types (unsigned)
            0 AS uint8_min,
            255 AS uint8_max,
            0 AS uint16_min,
            65535 AS uint16_max,
            0 AS uint32_min,
            4294967295 AS uint32_max,
            0 AS uint64_min,
            18446744073709551615 AS uint64_max,

            -- Floating point types
            3.14 AS float_value,
            2.7182818284590452 AS double_value,

            -- String types
            "Hello, World!" AS string_value,
            "UTF8 строка" AS utf8_value,
            "00000000-0000-0000-0000-000000000000" AS uuid_value,
            '{"key": "value"}' AS json_value,

            -- Date and time types
            Date("2023-07-15") AS date_value,
            Datetime("2023-07-15T12:30:45Z") AS datetime_value,
            Timestamp("2023-07-15T12:30:45.123456Z") AS timestamp_value,
            INTERVAL("P1DT2H3M4.567S") AS interval_value,

            -- Decimal
            CAST("123.456789" AS Decimal(22,9)) AS decimal_value,

            -- Container types
            -- List containers
            AsList(1, 2, 3) AS int_list,
            AsList("a", "b", "c") AS string_list,

            -- Struct containers (similar to tuples)
            AsStruct(1 AS a, "x" AS b) AS simple_struct,

            -- Dictionary containers
            AsDict(
                AsTuple("key1", 1),
                AsTuple("key2", 2),
                AsTuple("key3", 3)
            ) AS string_to_int_dict,

            -- Nested containers - list of structs
            AsList(
                AsStruct(1 AS id, "Alice" AS name),
                AsStruct(2 AS id, "Bob" AS name),
                AsStruct(3 AS id, "Charlie" AS name)
            ) AS list_of_structs,

            -- Nested containers - struct with list
            AsStruct(
                "users" AS collection_name,
                AsList(1, 2, 3) AS ids,
                true AS active
            ) AS struct_with_list,

            -- Dict with complex values
            AsDict(
                AsTuple("person1", AsStruct(1 AS id, "Alice" AS name, AsList(25, 30, 28) AS scores)),
                AsTuple("person2", AsStruct(2 AS id, "Bob" AS name, AsList(22, 27, 29) AS scores))
            ) AS complex_dict,

            -- Triple-nested container: list of structs with lists
            AsList(
                AsStruct(
                    1 AS id,
                    "Team A" AS name,
                    AsList("Alice", "Bob") AS members
                ),
                AsStruct(
                    2 AS id,
                    "Team B" AS name,
                    AsList("Charlie", "David") AS members
                )
            ) AS nested_list_struct_list,

            -- Tuple containers
            AsTuple(1, "a", true) AS mixed_tuple
        """,
    )

    # Basic result checks
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])
    assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}"

    assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}"

    first_result = parsed["result_sets"][0]
    assert "columns" in first_result, f"No columns in result: {first_result}"
    assert "rows" in first_result, f"No rows in result: {first_result}"
    assert len(first_result["rows"]) == 1, f"Expected 1 row, got {len(first_result['rows'])}"

    # Get the row
    row = first_result["rows"][0]

    # Helper function to get column index and value
    # Returns None when the column is absent (index() raises ValueError).
    def get_value(column_name):
        try:
            idx = first_result["columns"].index(column_name)
            return row[idx]
        except ValueError:
            return None

    # Test each data type
    # Boolean values
    assert get_value("bool_true") is True, f"Expected bool_true to be True, got {get_value('bool_true')}"
    assert get_value("bool_false") is False, f"Expected bool_false to be False, got {get_value('bool_false')}"

    # Integer types (signed)
    assert get_value("int8_min") == -128, f"Expected int8_min to be -128, got {get_value('int8_min')}"
    assert get_value("int8_max") == 127, f"Expected int8_max to be 127, got {get_value('int8_max')}"
    assert get_value("int16_min") == -32768, f"Expected int16_min to be -32768, got {get_value('int16_min')}"
    assert get_value("int16_max") == 32767, f"Expected int16_max to be 32767, got {get_value('int16_max')}"
    assert get_value("int32_min") == -2147483648, f"Expected int32_min to be -2147483648, got {get_value('int32_min')}"
    assert get_value("int32_max") == 2147483647, f"Expected int32_max to be 2147483647, got {get_value('int32_max')}"
    assert get_value("int64_min") == -9223372036854775808, (
        f"Expected int64_min to be -9223372036854775808, got {get_value('int64_min')}"
    )
    assert get_value("int64_max") == 9223372036854775807, (
        f"Expected int64_max to be 9223372036854775807, got {get_value('int64_max')}"
    )

    # Integer types (unsigned)
    assert get_value("uint8_min") == 0, f"Expected uint8_min to be 0, got {get_value('uint8_min')}"
    assert get_value("uint8_max") == 255, f"Expected uint8_max to be 255, got {get_value('uint8_max')}"
    assert get_value("uint16_min") == 0, f"Expected uint16_min to be 0, got {get_value('uint16_min')}"
    assert get_value("uint16_max") == 65535, f"Expected uint16_max to be 65535, got {get_value('uint16_max')}"
    assert get_value("uint32_min") == 0, f"Expected uint32_min to be 0, got {get_value('uint32_min')}"
    assert get_value("uint32_max") == 4294967295, f"Expected uint32_max to be 4294967295, got {get_value('uint32_max')}"
    assert get_value("uint64_min") == 0, f"Expected uint64_min to be 0, got {get_value('uint64_min')}"
    assert get_value("uint64_max") == 18446744073709551615, (
        f"Expected uint64_max to be 18446744073709551615, got {get_value('uint64_max')}"
    )

    # Floating point types
    assert abs(get_value("float_value") - 3.14) < 0.0001, (
        f"Expected float_value to be close to 3.14, got {get_value('float_value')}"
    )
    assert abs(get_value("double_value") - 2.7182818284590452) < 0.0000000000000001, (
        f"Expected double_value to be close to 2.7182818284590452, got {get_value('double_value')}"
    )

    # String types - expect only str, not bytes
    string_value = get_value("string_value")
    assert string_value == "Hello, World!", f"Expected string_value to be 'Hello, World!', got {string_value}"

    utf8_value = get_value("utf8_value")
    assert utf8_value == "UTF8 строка", f"Expected utf8_value to be 'UTF8 строка', got {utf8_value}"

    uuid_value = get_value("uuid_value")
    assert uuid_value == "00000000-0000-0000-0000-000000000000", (
        f"Expected uuid_value to be '00000000-0000-0000-0000-000000000000', got {uuid_value}"
    )

    json_value = get_value("json_value")
    assert json_value == '{"key": "value"}', f"Expected json_value to be '{{'key': 'value'}}', got {json_value}"

    # Date and time types - YDB returns these as Python datetime objects
    # Each branch below accepts either an ISO string or the native object.
    date_value = get_value("date_value")
    if isinstance(date_value, str):
        # Parse string to date
        parsed_date = datetime.date.fromisoformat(date_value)
        assert parsed_date == datetime.date(2023, 7, 15), f"Expected date_value to be 2023-07-15, got {parsed_date}"
    else:
        assert isinstance(date_value, datetime.date), f"Expected date_value to be datetime.date, got {type(date_value)}"
        assert date_value == datetime.date(2023, 7, 15), f"Expected date_value to be 2023-07-15, got {date_value}"

    datetime_value = get_value("datetime_value")
    if isinstance(datetime_value, str):
        # Parse string to datetime
        datetime.datetime.fromisoformat rejects a trailing "Z", hence the replace.
        parsed_dt = datetime.datetime.fromisoformat(datetime_value.replace("Z", "+00:00"))
        expected_datetime = datetime.datetime(2023, 7, 15, 12, 30, 45, tzinfo=datetime.timezone.utc)
        if parsed_dt.tzinfo is None:
            parsed_dt = parsed_dt.replace(tzinfo=datetime.timezone.utc)
        assert parsed_dt == expected_datetime, f"Expected datetime_value to be {expected_datetime}, got {parsed_dt}"
    else:
        assert isinstance(datetime_value, datetime.datetime), (
            f"Expected datetime_value to be datetime.datetime, got {type(datetime_value)}"
        )
        expected_datetime = datetime.datetime(2023, 7, 15, 12, 30, 45, tzinfo=datetime.timezone.utc)
        if datetime_value.tzinfo is None:
            datetime_value = datetime_value.replace(tzinfo=datetime.timezone.utc)
        assert datetime_value == expected_datetime, (
            f"Expected datetime_value to be {expected_datetime}, got {datetime_value}"
        )

    timestamp_value = get_value("timestamp_value")
    if isinstance(timestamp_value, str):
        parsed_ts = datetime.datetime.fromisoformat(timestamp_value.replace("Z", "+00:00"))
        expected_timestamp = datetime.datetime(2023, 7, 15, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc)
        if parsed_ts.tzinfo is None:
            parsed_ts = parsed_ts.replace(tzinfo=datetime.timezone.utc)
        assert parsed_ts == expected_timestamp, f"Expected timestamp_value to be {expected_timestamp}, got {parsed_ts}"
    else:
        assert isinstance(timestamp_value, datetime.datetime), (
            f"Expected timestamp_value to be datetime.datetime, got {type(timestamp_value)}"
        )
        expected_timestamp = datetime.datetime(2023, 7, 15, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc)
        if timestamp_value.tzinfo is None:
            timestamp_value = timestamp_value.replace(tzinfo=datetime.timezone.utc)
        assert timestamp_value == expected_timestamp, (
            f"Expected timestamp_value to be {expected_timestamp}, got {timestamp_value}"
        )

    interval_value = get_value("interval_value")
    # Accept both string and timedelta for interval_value
    expected_interval = datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=567000)
    if isinstance(interval_value, str):
        # Parse string like '93784.567s' to seconds
        if interval_value.endswith("s"):
            seconds = float(interval_value[:-1])
            parsed_interval = datetime.timedelta(seconds=seconds)
            assert parsed_interval.total_seconds() == expected_interval.total_seconds(), (
                f"Expected interval_value to be {expected_interval}, got {parsed_interval}"
            )
        else:
            assert False, f"Unexpected interval string format: {interval_value}"
    else:
        assert isinstance(interval_value, datetime.timedelta), (
            f"Expected interval_value to be datetime.timedelta, got {type(interval_value)}"
        )
        assert interval_value.total_seconds() == expected_interval.total_seconds(), (
            f"Expected interval_value to be {expected_interval}, got {interval_value}"
        )

    # Decimal - YDB returns Decimal objects
    from decimal import Decimal

    decimal_value = get_value("decimal_value")
    if isinstance(decimal_value, str):
        parsed_decimal = Decimal(decimal_value)
        assert parsed_decimal == Decimal("123.456789"), (
            f"Expected decimal_value to be Decimal('123.456789'), got {parsed_decimal}"
        )
    else:
        assert isinstance(decimal_value, Decimal), f"Expected decimal_value to be Decimal, got {type(decimal_value)}"
        assert decimal_value == Decimal("123.456789"), (
            f"Expected decimal_value to be Decimal('123.456789'), got {decimal_value}"
        )

    # Container types
    # List containers
    int_list = get_value("int_list")
    assert isinstance(int_list, list), f"Expected int_list to be a list, got {type(int_list)}"
    assert int_list == [1, 2, 3], f"Expected int_list to be [1, 2, 3], got {int_list}"

    string_list = get_value("string_list")
    assert isinstance(string_list, list), f"Expected string_list to be a list, got {type(string_list)}"
    expected = ["a", "b", "c"]
    for actual, exp in zip(string_list, expected):
        assert actual == exp, f"Expected {exp}, got {actual} in string_list"

    # Struct containers (similar to Python dictionaries)
    simple_struct = get_value("simple_struct")
    assert isinstance(simple_struct, dict), f"Expected simple_struct to be a dict, got {type(simple_struct)}"
    assert "a" in simple_struct and "b" in simple_struct, (
        f"Expected simple_struct to have keys 'a' and 'b', got {simple_struct}"
    )
    assert simple_struct["a"] == 1, f"Expected simple_struct['a'] to be 1, got {simple_struct['a']}"
    assert simple_struct["b"] == "x", f"Expected simple_struct['b'] to be 'x', got {simple_struct['b']}"

    # Dictionary containers
    string_to_int_dict = get_value("string_to_int_dict")
    assert isinstance(string_to_int_dict, dict), (
        f"Expected string_to_int_dict to be a dict, got {type(string_to_int_dict)}"
    )
    # Accept both string keys and stringified bytes keys
    expected_dict = {"key1": 1, "key2": 2, "key3": 3}
    expected_bytes_dict = {f"b'{k}'": v for k, v in expected_dict.items()}
    assert string_to_int_dict == expected_dict or string_to_int_dict == expected_bytes_dict, (
        f"Expected dict to be {expected_dict} or {expected_bytes_dict}, got {string_to_int_dict}"
    )

    # Nested containers - list of structs
    list_of_structs = get_value("list_of_structs")
    assert isinstance(list_of_structs, list), f"Expected list_of_structs to be a list, got {type(list_of_structs)}"
    assert len(list_of_structs) == 3, f"Expected list_of_structs to have 3 items, got {len(list_of_structs)}"

    # Check first item in list of structs
    first_struct = list_of_structs[0]
    assert isinstance(first_struct, dict), f"Expected first_struct to be a dict, got {type(first_struct)}"
    assert first_struct == {
        "id": 1,
        "name": "Alice",
    }, f"Expected first_struct to be {{'id': 1, 'name': 'Alice'}}, got {first_struct}"

    # Struct with list
    struct_with_list = get_value("struct_with_list")
    assert isinstance(struct_with_list, dict), f"Expected struct_with_list to be a dict, got {type(struct_with_list)}"
    assert struct_with_list == {
        "collection_name": "users",
        "ids": [1, 2, 3],
        "active": True,
    }, (
        f"Expected struct_with_list to be {{'collection_name': 'users', 'ids': [1, 2, 3], 'active': True}}, "
        f"got {struct_with_list}"
    )

    # Complex dict
    complex_dict = get_value("complex_dict")
    assert isinstance(complex_dict, dict), f"Expected complex_dict to be a dict, got {type(complex_dict)}"
    expected_complex_dict = {
        "person1": {"id": 1, "name": "Alice", "scores": [25, 30, 28]},
        "person2": {"id": 2, "name": "Bob", "scores": [22, 27, 29]},
    }
    expected_bytes_complex_dict = {f"b'{k}'": v for k, v in expected_complex_dict.items()}
    assert complex_dict == expected_complex_dict or complex_dict == expected_bytes_complex_dict, (
        f"Expected complex_dict to be {expected_complex_dict} or {expected_bytes_complex_dict}, got {complex_dict}"
    )

    # Triple-nested list
    nested_list = get_value("nested_list_struct_list")
    assert isinstance(nested_list, list), f"Expected nested_list to be a list, got {type(nested_list)}"
    assert len(nested_list) == 2, f"Expected nested_list to have 2 items, got {len(nested_list)}"

    expected_nested_list = [
        {"id": 1, "name": "Team A", "members": ["Alice", "Bob"]},
        {"id": 2, "name": "Team B", "members": ["Charlie", "David"]},
    ]
    assert nested_list == expected_nested_list, f"Expected nested_list to be {expected_nested_list}, got {nested_list}"

    # Tuple containers
    mixed_tuple = get_value("mixed_tuple")
    assert isinstance(mixed_tuple, (list, tuple)), f"Expected mixed_tuple to be a list or tuple, got {type(mixed_tuple)}"
    assert len(mixed_tuple) == 3, f"Expected mixed_tuple to have 3 items, got {len(mixed_tuple)}"
    expected_tuple = (1, "a", True)
    # Convert to tuple if it's a list for comparison
    if isinstance(mixed_tuple, list):
        mixed_tuple = tuple(mixed_tuple)
    assert mixed_tuple == expected_tuple, f"Expected mixed_tuple to be {expected_tuple}, got {mixed_tuple}"


async def test_list_directory_integration(mcp_server):
    """Test listing directory contents in YDB."""
    # List the contents of the root directory - this should always exist
    result = await call_mcp_tool(mcp_server, "ydb_list_directory", path="/")

    # Parse the JSON result
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])

    # Verify the structure
    assert "path" in parsed, f"Missing 'path' field in dir_data: {parsed}"
    assert parsed["path"] == "/", f"Expected path to be '/', got {parsed['path']}"
    assert "items" in parsed, f"Missing 'items' field in dir_data: {parsed}"

    # Root directory should have at least some items
    assert isinstance(parsed["items"], list), f"Expected items to be a list, got {type(parsed['items'])}"
    assert len(parsed["items"]) > 0, f"Expected non-empty directory, got {parsed['items']}"

    # Verify at least one item has expected properties
    assert "name" in parsed["items"][0], f"Missing 'name' field in item: {parsed['items'][0]}"
    assert "type" in parsed["items"][0], f"Missing 'type' field in item: {parsed['items'][0]}"
    assert "owner" in parsed["items"][0], f"Missing 'owner' field in item: {parsed['items'][0]}"

    logger.debug(f"Found {len(parsed['items'])} items in root directory")
    for item in parsed["items"]:
        logger.debug(f"Item: {item['name']}, Type: {item['type']}")


async def test_list_directory_nonexistent_integration(mcp_server):
    """Test listing a nonexistent directory in YDB."""
    # Generate a random path that should not exist
    nonexistent_path = f"/nonexistent_{int(time.time())}"

    # Try to list a nonexistent directory
    result = await call_mcp_tool(mcp_server, "ydb_list_directory", path=nonexistent_path)

    # Parse the result
    assert isinstance(result, list) and len(result) > 0 and "text" in result[0], (
        f"Result should be a list of dicts with 'text': {result}"
    )
    parsed = json.loads(result[0]["text"])

    # Should contain an error message
    assert "error" in parsed, f"Expected error message, got: {parsed}"


async def test_describe_path_integration(mcp_server):
    """Test describing paths in YDB."""
    # 1. First create a test table to describe
    test_table_name = f"describe_test_table_{int(time.time())}"

    try:
        # Create a table with various column types
        create_result = await call_mcp_tool(
            mcp_server,
            "ydb_query",
            sql=f"""
            CREATE TABLE {test_table_name} (
                id Uint64,
                name Utf8,
                value Double,
                created Timestamp,
                PRIMARY KEY (id)
            );
            """,
        )
        assert isinstance(create_result, list) and len(create_result) > 0 and "text" in create_result[0], (
            f"Result should be a list of dicts with 'text': {create_result}"
        )
        parsed = json.loads(create_result[0]["text"])
        assert "error" not in parsed, f"Error creating table: {parsed}"

        # Wait a moment for the table to be fully created
        time.sleep(1)

        # 2. 
Now describe the table 840 | result = await call_mcp_tool(mcp_server, "ydb_describe_path", path=f"/{test_table_name}") 841 | 842 | # Parse the JSON result 843 | assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( 844 | f"Result should be a list of dicts with 'text': {result}" 845 | ) 846 | parsed = json.loads(result[0]["text"]) 847 | 848 | # Only check for path if not error 849 | if "error" not in parsed: 850 | assert parsed["path"] == f"/{test_table_name}", ( 851 | f"Expected path to be '/{test_table_name}', got {parsed['path']}" 852 | ) 853 | assert "type" in parsed, f"Missing 'type' field in path_data: {parsed}" 854 | assert parsed["type"] == "TABLE", f"Expected type to be 'TABLE', got {parsed['type']}" 855 | # Verify table information 856 | assert "table" in parsed, f"Missing 'table' field in path_data: {parsed}" 857 | 858 | finally: 859 | # Clean up - drop the table even if test fails 860 | cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};") 861 | logger.debug(f"Table cleanup result: {cleanup_result}") 862 | 863 | 864 | async def test_describe_nonexistent_path_integration(mcp_server): 865 | """Test describing a nonexistent path in YDB.""" 866 | # Generate a random path that should not exist 867 | nonexistent_path = f"/nonexistent_{int(time.time())}" 868 | 869 | # Try to describe a nonexistent path 870 | result = await call_mcp_tool(mcp_server, "ydb_describe_path", path=nonexistent_path) 871 | 872 | # Parse the result 873 | assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( 874 | f"Result should be a list of dicts with 'text': {result}" 875 | ) 876 | parsed = json.loads(result[0]["text"]) 877 | 878 | # Should contain an error message 879 | assert "error" in parsed, f"Expected error message, got: {parsed}" 880 | 881 | 882 | async def test_ydb_status_integration(mcp_server): 883 | """Test getting YDB connection status.""" 884 | result = await 
call_mcp_tool(mcp_server, "ydb_status") 885 | 886 | # Parse the JSON result 887 | assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( 888 | f"Result should be a list of dicts with 'text': {result}" 889 | ) 890 | parsed = json.loads(result[0]["text"]) 891 | 892 | # Verify the structure 893 | assert "status" in parsed, f"Missing 'status' field in status_data: {parsed}" 894 | assert "ydb_endpoint" in parsed, f"Missing 'ydb_endpoint' field in status_data: {parsed}" 895 | assert "ydb_database" in parsed, f"Missing 'ydb_database' field in status_data: {parsed}" 896 | assert "auth_mode" in parsed, f"Missing 'auth_mode' field in status_data: {parsed}" 897 | assert "ydb_connection" in parsed, f"Missing 'ydb_connection' field in status_data: {parsed}" 898 | 899 | # For a successful test run, we expect to be connected 900 | assert parsed["status"] == "running", f"Expected status to be 'running', got {parsed['status']}" 901 | assert parsed["ydb_connection"] == "connected", ( 902 | f"Expected ydb_connection to be 'connected', got {parsed['ydb_connection']}" 903 | ) 904 | assert parsed["error"] is None, f"Expected no error, got: {parsed.get('error')}" 905 | 906 | logger.info(f"YDB status check successful: {parsed}") 907 | -------------------------------------------------------------------------------- /tests/integration/test_path_operations.py: -------------------------------------------------------------------------------- 1 | """Integration tests for YDB directory and table operations (list_directory, describe_path, table creation, and cleanup). 2 | 3 | These tests validate the functionality of YDB directory listing, path description, and table operations 4 | including creation and cleanup. 5 | They test real YDB interactions without mocks, requiring a running YDB instance. 
6 | """ 7 | 8 | import asyncio 9 | import json 10 | import logging 11 | import os 12 | import time 13 | 14 | import pytest 15 | 16 | # Import from conftest 17 | from tests.integration.conftest import call_mcp_tool 18 | from ydb_mcp.connection import YDBConnection 19 | 20 | # Set up logging 21 | logging.basicConfig(level=logging.INFO) 22 | logger = logging.getLogger(__name__) 23 | 24 | # Mark these tests as integration and asyncio tests 25 | pytestmark = [pytest.mark.integration, pytest.mark.asyncio] 26 | 27 | 28 | def parse_text_content(response): 29 | """Parse TextContent response into a dictionary.""" 30 | if not response: 31 | return response 32 | 33 | # Handle direct dictionary response 34 | if isinstance(response, dict): 35 | return response 36 | 37 | # Handle list of TextContent 38 | if isinstance(response, list) and len(response) > 0: 39 | if isinstance(response[0], dict) and "type" in response[0] and "text" in response[0]: 40 | try: 41 | return json.loads(response[0]["text"]) 42 | except json.JSONDecodeError: 43 | return response[0]["text"] 44 | 45 | return response 46 | 47 | 48 | async def test_list_root_directory(mcp_server): 49 | """Test listing the contents of the root directory.""" 50 | result = await call_mcp_tool(mcp_server, "ydb_list_directory", path="/") 51 | result = parse_text_content(result) 52 | assert "error" not in result, f"Error listing root directory: {result}" 53 | assert "items" in result, f"Root directory listing should contain items: {result}" 54 | assert len(result["items"]) > 0, f"Root directory should not be empty: {result}" 55 | 56 | # Verify the structure of items 57 | for item in result["items"]: 58 | assert "name" in item, f"Item should have a name: {item}" 59 | assert "type" in item, f"Item should have a type: {item}" 60 | assert "owner" in item, f"Item should have an owner: {item}" 61 | 62 | 63 | async def test_list_directory_after_table_creation(mcp_server): 64 | """Test that a newly created table appears in the directory 
listing.""" 65 | # Use the same logic as the server to parse endpoint and database 66 | ydb_endpoint = os.environ.get("YDB_ENDPOINT", "grpc://localhost:2136/local") 67 | conn = YDBConnection(ydb_endpoint) 68 | _, db_path = conn._parse_endpoint_and_database() 69 | 70 | # Generate a unique table name to avoid conflicts 71 | test_table_name = f"test_table_{int(time.time())}" 72 | 73 | try: 74 | # Create a new table in the current database (not as an absolute path) 75 | create_result = await call_mcp_tool( 76 | mcp_server, 77 | "ydb_query", 78 | sql=f""" 79 | CREATE TABLE {test_table_name} ( 80 | id Uint64, 81 | name Utf8, 82 | PRIMARY KEY (id) 83 | ); 84 | """, 85 | ) 86 | assert "error" not in create_result, f"Error creating table: {create_result}" 87 | logger.debug(f"Created table {test_table_name}") 88 | 89 | # Wait a moment for the table to be fully created and visible 90 | await asyncio.sleep(1) 91 | 92 | # List the database directory 93 | path = db_path 94 | found = False 95 | items = [] 96 | for _ in range(5): 97 | dir_list = await call_mcp_tool(mcp_server, "ydb_list_directory", path=path) 98 | parsed_dir = parse_text_content(dir_list) 99 | items = parsed_dir.get("items", []) if isinstance(parsed_dir, dict) else [] 100 | if any(test_table_name == item.get("name") for item in items): 101 | found = True 102 | break 103 | await asyncio.sleep(1) 104 | assert found, f"Table {test_table_name} not found in directory listing: {items}" 105 | 106 | finally: 107 | # Clean up - drop the table 108 | cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};") 109 | logger.debug(f"Table cleanup result: {cleanup_result}") 110 | 111 | 112 | async def test_path_description(mcp_server): 113 | """Test describing each item in the root directory.""" 114 | # List the root directory 115 | result = await call_mcp_tool(mcp_server, "ydb_list_directory", path="/") 116 | parsed = parse_text_content(result) 117 | assert "items" in parsed, f"Root 
directory listing missing items: {parsed}" 118 | # Describe each item 119 | for item in parsed["items"]: 120 | item_name = item["name"] 121 | item_path = f"/{item_name}" 122 | describe_result = await call_mcp_tool(mcp_server, "ydb_describe_path", path=item_path) 123 | path_data = parse_text_content(describe_result) 124 | assert "path" in path_data, f"Missing 'path' field in path data: {path_data}" 125 | assert path_data["path"] == item_path, f"Expected path to be '{item_path}', got {path_data['path']}" 126 | assert "type" in path_data, f"Missing 'type' field in path data: {path_data}" 127 | assert "name" in path_data, f"Missing 'name' field in path data: {path_data}" 128 | assert "owner" in path_data, f"Missing 'owner' field in path data: {path_data}" 129 | if path_data["type"] == "TABLE": 130 | assert "table" in path_data, f"Missing 'table' field for TABLE: {path_data}" 131 | assert "columns" in path_data["table"], f"Missing 'columns' field in table data: {path_data}" 132 | assert len(path_data["table"]["columns"]) > 0, f"Table should have at least one column: {path_data}" 133 | -------------------------------------------------------------------------------- /tests/mocks.py: -------------------------------------------------------------------------------- 1 | """Mock classes for testing.""" 2 | 3 | from typing import Callable, Type 4 | 5 | 6 | class MockRequestHandler: 7 | """Mock for mcp.server.handler.RequestHandler class.""" 8 | 9 | def __init__(self): 10 | """Initialize the mock handler.""" 11 | self.config = None 12 | 13 | 14 | def mock_register_handler(name: str) -> Callable[[Type], Type]: 15 | """Mock for mcp.server.handler.register_handler decorator. 16 | 17 | Args: 18 | name: Name of the handler 19 | 20 | Returns: 21 | Decorator function 22 | """ 23 | 24 | def decorator(cls): 25 | """Decorator function. 
"""Tests for YDB connection module."""

import asyncio
import sys
import unittest
from unittest.mock import ANY, AsyncMock, MagicMock, patch

# Add mocks for mcp.server.handler
from tests.mocks import MockRequestHandler, mock_register_handler

# Mock the mcp server modules before importing anything that depends on them.
sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.handler"] = MagicMock()
sys.modules["mcp.server.handler"].RequestHandler = MockRequestHandler
sys.modules["mcp.server.handler"].register_handler = mock_register_handler

# Import after mocking
from ydb_mcp.connection import YDBConnection  # noqa: E402


class TestYDBConnection(unittest.TestCase):
    """Test cases for YDBConnection class."""

    def test_extract_database_path(self):
        """Test database path extraction from connection string."""
        # Test with simple format
        conn = YDBConnection("grpc://ydb.server:2136/local")
        self.assertEqual(conn._extract_database_path(), "/local")

        # Test with path containing multiple segments
        conn = YDBConnection("grpc://ydb.server:2136/my/database/path")
        self.assertEqual(conn._extract_database_path(), "/my/database/path")

        # Test with query parameters
        conn = YDBConnection("grpc://ydb.server:2136/local?ssl=true&timeout=60")
        self.assertEqual(conn._extract_database_path(), "/local")

        # Test with database:// prefix
        conn = YDBConnection("database://ydb.server:2136/local")
        self.assertEqual(conn._extract_database_path(), "/local")

    @patch("ydb.aio.Driver")
    async def test_connect(self, mock_driver_class):
        """Test connection establishment."""
        # Setup mocks
        mock_driver = AsyncMock()
        mock_driver.wait = AsyncMock(return_value=True)
        mock_driver.discovery_debug_details = MagicMock(return_value="Resolved endpoints: localhost:2136")
        mock_driver_class.return_value = mock_driver

        with patch("ydb.aio.QuerySessionPool") as mock_session_pool_class:
            # Setup session pool mock
            mock_session_pool = MagicMock()
            mock_session_pool_class.return_value = mock_session_pool

            # Create connection and connect
            conn = YDBConnection("grpc://ydb.server:2136/local")
            driver, pool = await conn.connect()

            # Verify driver was created with correct parameters
            mock_driver_class.assert_called_once()
            mock_driver.wait.assert_called_once()
            mock_driver.discovery_debug_details.assert_called()

            # Verify session pool was created
            mock_session_pool_class.assert_called_once()

            # Verify driver and session pool were stored and returned
            assert conn.driver == mock_driver
            assert conn.session_pool == mock_session_pool
            assert driver == mock_driver
            assert pool == mock_session_pool

            # Reset mocks for next test
            mock_driver_class.reset_mock()
            mock_driver.wait.reset_mock()
            mock_driver.discovery_debug_details.reset_mock()

    @patch("ydb.aio.Driver")
    async def test_connect_with_database_in_endpoint(self, mock_driver_class):
        """Test connection with database specified in endpoint."""
        # Setup mocks
        mock_driver = AsyncMock()
        mock_driver.wait = AsyncMock(return_value=True)
        mock_driver.discovery_debug_details = MagicMock(return_value="Resolved endpoints: localhost:2136")
        mock_driver_class.return_value = mock_driver

        with patch("ydb.aio.QuerySessionPool") as mock_session_pool_class:
            # Setup session pool mock
            mock_session_pool = MagicMock()
            mock_session_pool_class.return_value = mock_session_pool

            # Test cases for different endpoint formats:
            # (raw endpoint, expected endpoint passed to driver, expected database)
            test_cases = [
                ("grpc://ydb.server:2136/local", "grpc://ydb.server:2136", "/local"),
                ("grpcs://ydb.server:2136/local/test", "grpcs://ydb.server:2136", "/local/test"),
                ("ydb.server:2136/local", "grpc://ydb.server:2136", "/local"),
                ("grpc://ydb.server:2136/local", "grpc://ydb.server:2136", "/local"),
            ]

            for endpoint, expected_endpoint, expected_database in test_cases:
                # Create connection and connect
                conn = YDBConnection(endpoint)
                await conn.connect()

                # Verify driver was created with correct parameters
                mock_driver_class.assert_called_with(
                    endpoint=expected_endpoint, database=expected_database, credentials=ANY
                )
                mock_driver.wait.assert_called_once()
                mock_driver.discovery_debug_details.assert_called()

                # Reset mock call count
                mock_driver_class.reset_mock()
                mock_driver.wait.reset_mock()
                mock_driver.discovery_debug_details.reset_mock()

    @patch("ydb.aio.Driver")
    async def test_connect_with_explicit_database(self, mock_driver_class):
        """Test connection with explicitly provided database."""
        # Setup mocks
        mock_driver = AsyncMock()
        mock_driver.wait = AsyncMock(return_value=True)
        mock_driver.discovery_debug_details = MagicMock(return_value="Resolved endpoints: localhost:2136")
        mock_driver_class.return_value = mock_driver

        with patch("ydb.aio.QuerySessionPool") as mock_session_pool_class:
            # Setup session pool mock
            mock_session_pool = MagicMock()
            mock_session_pool_class.return_value = mock_session_pool

            # Test cases: explicit database overrides any path in the endpoint;
            # a missing leading slash is normalized.
            test_cases = [
                (
                    "grpc://ydb.server:2136/local",
                    "/explicit",
                    "grpc://ydb.server:2136",
                    "/explicit",
                ),
                (
                    "grpcs://ydb.server:2136/local",
                    "explicit",
                    "grpcs://ydb.server:2136",
                    "/explicit",
                ),
                ("ydb.server:2136/local", "/other", "grpc://ydb.server:2136", "/other"),
            ]

            for endpoint, database, expected_endpoint, expected_database in test_cases:
                # Create connection and connect
                conn = YDBConnection(endpoint, database=database)
                await conn.connect()

                # Verify driver was created with correct parameters
                mock_driver_class.assert_called_with(
                    endpoint=expected_endpoint, database=expected_database, credentials=ANY
                )
                mock_driver.wait.assert_called_once()
                mock_driver.discovery_debug_details.assert_called()

                # Reset mock call count
                mock_driver_class.reset_mock()
                mock_driver.wait.reset_mock()
                mock_driver.discovery_debug_details.reset_mock()


# Allow tests to run with asyncio
def run_async_test(test_case, test_func):
    """Run an async test function to completion.

    Fixed: asyncio.get_event_loop() is deprecated when no loop is running
    (since Python 3.10) and may stop working in future versions; asyncio.run()
    creates and tears down a fresh event loop per test instead.
    """
    asyncio.run(test_func(test_case))


# Patch async test methods so the synchronous unittest runner can execute them.
for method_name in dir(TestYDBConnection):
    if method_name.startswith("test_") and method_name != "test_extract_database_path":
        method = getattr(TestYDBConnection, method_name)
        if asyncio.iscoroutinefunction(method):
            # Bind the coroutine function as a default argument (m=method) to
            # avoid the late-binding-closure pitfall inside this loop.
            setattr(TestYDBConnection, method_name, lambda self, m=method: run_async_test(self, m))

if __name__ == "__main__":
    unittest.main()
def test_datetime_serialization():
    """Datetime family objects must serialize to ISO strings / second counts."""
    dt = datetime.datetime(2023, 7, 15, 12, 30, 45)
    day = datetime.date(2023, 7, 15)
    clock = datetime.time(12, 30, 45)
    delta = datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=567000)

    # Nested structure mixing all datetime flavours.
    payload = {
        "datetime": dt,
        "date": day,
        "time": clock,
        "timedelta": delta,
        "nested": {"datetime": dt},
        "list_with_dates": [day, dt],
    }

    # Round-trip through the encoder and back into plain Python objects.
    round_tripped = json.loads(json.dumps(payload, cls=CustomJSONEncoder))

    assert round_tripped["datetime"] == "2023-07-15T12:30:45"
    assert round_tripped["date"] == "2023-07-15"
    assert round_tripped["time"] == "12:30:45"
    # 1 day, 2 hours, 3 minutes, 4.567 seconds expressed in seconds
    assert round_tripped["timedelta"] == "93784.567s"
    assert round_tripped["nested"]["datetime"] == "2023-07-15T12:30:45"
    assert round_tripped["list_with_dates"][0] == "2023-07-15"
    assert round_tripped["list_with_dates"][1] == "2023-07-15T12:30:45"


def test_bytes_serialization():
    """UTF-8 bytes decode to str; arbitrary binary becomes base64 text."""
    text_bytes = "UTF8 строка".encode("utf-8")
    raw_bytes = bytes([0x00, 0x01, 0x02, 0x03, 0xFF])

    payload = {
        "utf8_bytes": text_bytes,
        "binary": raw_bytes,
        "nested": {"binary": raw_bytes},
        "list_with_bytes": [text_bytes, raw_bytes],
    }

    round_tripped = json.loads(json.dumps(payload, cls=CustomJSONEncoder))

    # Valid UTF-8 payloads come back as plain strings.
    assert round_tripped["utf8_bytes"] == "UTF8 строка"

    # Non-UTF-8 binary is base64-encoded.
    expected_b64 = base64.b64encode(raw_bytes).decode("ascii")
    assert round_tripped["binary"] == expected_b64
    assert round_tripped["nested"]["binary"] == expected_b64
    assert round_tripped["list_with_bytes"][0] == "UTF8 строка"
    assert round_tripped["list_with_bytes"][1] == expected_b64

    # The base64 text decodes back to the original bytes.
    assert base64.b64decode(round_tripped["binary"]) == raw_bytes


def test_decimal_serialization():
    """Decimal values serialize as exact decimal strings."""
    amount = decimal.Decimal("123.456789")

    payload = {
        "decimal": amount,
        "nested": {"decimal": amount},
        "list_with_decimals": [amount, decimal.Decimal("0.1")],
    }

    round_tripped = json.loads(json.dumps(payload, cls=CustomJSONEncoder))

    assert round_tripped["decimal"] == "123.456789"
    assert round_tripped["nested"]["decimal"] == "123.456789"
    assert round_tripped["list_with_decimals"][0] == "123.456789"
    assert round_tripped["list_with_decimals"][1] == "0.1"


def test_mixed_data_serialization():
    """A structure mixing plain JSON types with datetime/bytes/Decimal round-trips."""
    stamp = datetime.datetime(2023, 7, 15, 12, 30, 45)
    blob = b"Hello, World!"
    amount = decimal.Decimal("123.456789")

    payload = {
        "string": "Regular string",
        "int": 42,
        "float": 3.14,
        "bool": True,
        "none": None,
        "datetime": stamp,
        "bytes": blob,
        "decimal": amount,
        "nested": {"datetime": stamp, "bytes": blob, "decimal": amount},
        "list_mixed": [stamp, blob, amount, "string", 42],
    }

    # Serialization must succeed and produce parseable JSON.
    round_tripped = json.loads(json.dumps(payload, cls=CustomJSONEncoder))

    # Successful deserialization is the main check; spot-check key values too.
    assert round_tripped["string"] == "Regular string"
    assert round_tripped["int"] == 42
    assert round_tripped["float"] == 3.14
    assert round_tripped["bool"] is True
    assert round_tripped["none"] is None
    assert round_tripped["datetime"] == "2023-07-15T12:30:45"
    assert round_tripped["bytes"] == "Hello, World!"
"""Tests for YDB query module."""

import asyncio
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

# Add mocks for mcp.server.handler
from tests.mocks import MockRequestHandler, mock_register_handler

# Mock the mcp server modules before importing anything that depends on them.
sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.handler"] = MagicMock()
sys.modules["mcp.server.handler"].RequestHandler = MockRequestHandler
sys.modules["mcp.server.handler"].register_handler = mock_register_handler

# Import modules after mocking
from ydb_mcp.connection import YDBConnection  # noqa: E402
from ydb_mcp.query import QueryExecutor  # noqa: E402


class TestQueryExecutor(unittest.TestCase):
    """Test cases for QueryExecutor class."""

    def setUp(self):
        """Set up test fixtures: a fully mocked connection and an executor."""
        self.mock_connection = MagicMock(spec=YDBConnection)
        self.mock_connection.driver = MagicMock()
        self.mock_connection.session_pool = MagicMock()
        self.mock_connection.connect = AsyncMock()

        self.executor = QueryExecutor(self.mock_connection)

    @patch("asyncio.get_event_loop")
    async def test_execute_query(self, mock_get_event_loop):
        """Test execute_query method."""
        # Set up mocks
        mock_loop = MagicMock()
        mock_get_event_loop.return_value = mock_loop

        # Configure future for run_in_executor
        mock_future = asyncio.Future()
        mock_future.set_result([{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}])
        mock_loop.run_in_executor.return_value = mock_future

        # Execute query
        result = await self.executor.execute_query("SELECT * FROM test")

        # Verify expected interactions
        mock_get_event_loop.assert_called_once()
        mock_loop.run_in_executor.assert_called_once()

        # Verify the result
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0]["id"], 1)
        self.assertEqual(result[1]["name"], "Test2")

    @patch("asyncio.get_event_loop")
    async def test_execute_query_with_connection_init(self, mock_get_event_loop):
        """Test execute_query with connection initialization."""
        # Setup connection without driver and session_pool so connect() is needed
        self.mock_connection.driver = None
        self.mock_connection.session_pool = None

        # Set up loop mock
        mock_loop = MagicMock()
        mock_get_event_loop.return_value = mock_loop

        # Configure future for run_in_executor
        mock_future = asyncio.Future()
        mock_future.set_result([{"result": "ok"}])
        mock_loop.run_in_executor.return_value = mock_future

        # Execute query
        result = await self.executor.execute_query("SELECT 1")

        # Verify connection was initialized
        self.mock_connection.connect.assert_called_once()

        # Verify result
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["result"], "ok")

    @patch("asyncio.get_event_loop")
    async def test_execute_query_error(self, mock_get_event_loop):
        """Test error handling in execute_query."""
        # Make run_in_executor raise an exception
        mock_loop = MagicMock()
        mock_get_event_loop.return_value = mock_loop

        # Create a future that raises an exception
        mock_future = asyncio.Future()
        mock_future.set_exception(Exception("Test error"))
        mock_loop.run_in_executor.return_value = mock_future

        # Verify exception is propagated
        with self.assertRaises(Exception) as context:
            await self.executor.execute_query("SELECT * FROM test")

        self.assertIn("Test error", str(context.exception))

    def test_execute_query_sync(self):
        """Test _execute_query_sync method."""
        # Setup session pool to return mock session
        mock_session = MagicMock()
        self.mock_connection.session_pool.retry_operation_sync.side_effect = lambda callback: callback(mock_session)

        # Setup transaction mock
        mock_transaction = MagicMock()
        mock_session.transaction.return_value = mock_transaction

        # Setup mock result sets
        mock_col1 = MagicMock()
        mock_col1.name = "id"
        mock_col2 = MagicMock()
        mock_col2.name = "name"

        mock_row1 = MagicMock()
        mock_row1.__getitem__.side_effect = lambda idx: [1, "Test1"][idx]
        mock_row2 = MagicMock()
        mock_row2.__getitem__.side_effect = lambda idx: [2, "Test2"][idx]

        mock_rs1 = MagicMock()
        mock_rs1.columns = [mock_col1, mock_col2]
        mock_rs1.rows = [mock_row1, mock_row2]

        mock_transaction.execute.return_value = [mock_rs1]

        # Mock the _convert_row_to_dict method to return dictionary with expected keys
        self.executor._convert_row_to_dict = MagicMock()
        self.executor._convert_row_to_dict.side_effect = lambda row, col_names=None: {
            "id": row[0],
            "name": row[1],
        }

        # Call the method
        self.executor._session_pool = self.mock_connection.session_pool
        result = self.executor._execute_query_sync("SELECT * FROM test")

        # Verify expected interactions
        mock_session.transaction.assert_called_once()
        mock_transaction.execute.assert_called_once()

        # Verify results
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0]["id"], 1)
        self.assertEqual(result[1]["name"], "Test2")

    def test_convert_row_to_dict(self):
        """Test _convert_row_to_dict method."""
        # Create a mock YDB row
        mock_row = MagicMock()
        mock_row.items.return_value = [("id", 1), ("name", "Test"), ("active", True)]

        # Convert row to dict
        result = self.executor._convert_row_to_dict(mock_row)

        # Verify result
        self.assertEqual(result["id"], 1)
        self.assertEqual(result["name"], "Test")
        self.assertEqual(result["active"], True)

    def test_convert_ydb_value_basic_types(self):
        """Test _convert_ydb_value with basic types."""
        # Test None
        self.assertIsNone(self.executor._convert_ydb_value(None))

        # Test basic types that should be returned as is
        self.assertEqual(self.executor._convert_ydb_value(42), 42)
        self.assertEqual(self.executor._convert_ydb_value(3.14), 3.14)
        self.assertEqual(self.executor._convert_ydb_value(True), True)
        self.assertEqual(self.executor._convert_ydb_value(False), False)

    def test_convert_ydb_value_bytes(self):
        """Test _convert_ydb_value with bytes."""
        # Test UTF-8 bytes: passed through unchanged, not decoded
        utf8_bytes = "Hello, World!".encode("utf-8")
        self.assertEqual(self.executor._convert_ydb_value(utf8_bytes), utf8_bytes)

        # Test UTF-8 bytes with non-ASCII characters
        utf8_complex = "UTF8 строка".encode("utf-8")
        self.assertEqual(self.executor._convert_ydb_value(utf8_complex), utf8_complex)

        # Test non-UTF8 bytes
        binary_data = bytes([0x80, 0x81, 0x82])  # Invalid UTF-8
        self.assertEqual(self.executor._convert_ydb_value(binary_data), binary_data)

    def test_convert_ydb_value_datetime(self):
        """Test _convert_ydb_value with datetime types."""
        import datetime

        # Test datetime
        dt = datetime.datetime(2023, 1, 1, 12, 0)
        self.assertEqual(self.executor._convert_ydb_value(dt), dt)

        # Test date
        d = datetime.date(2023, 1, 1)
        self.assertEqual(self.executor._convert_ydb_value(d), d)

        # Test time
        t = datetime.time(12, 0)
        self.assertEqual(self.executor._convert_ydb_value(t), t)

        # Test timedelta
        td = datetime.timedelta(days=1, hours=2)
        self.assertEqual(self.executor._convert_ydb_value(td), td)

    def test_convert_ydb_value_decimal(self):
        """Test _convert_ydb_value with Decimal type."""
        from decimal import Decimal

        # Test decimal values
        d = Decimal("123.456")
        self.assertEqual(self.executor._convert_ydb_value(d), d)

    def test_convert_ydb_value_containers(self):
        """Test _convert_ydb_value with container types."""
        # Test list with mixed types
        test_list = [1, "test".encode("utf-8"), True]
        converted_list = self.executor._convert_ydb_value(test_list)
        self.assertEqual(converted_list, [1, b"test", True])

        # Test dict with mixed types
        test_dict = {"key1".encode("utf-8"): "value1".encode("utf-8"), "key2".encode("utf-8"): 42}
        converted_dict = self.executor._convert_ydb_value(test_dict)
        self.assertEqual(converted_dict, {b"key1": b"value1", b"key2": 42})

        # Test tuple with mixed types
        test_tuple = (1, "test".encode("utf-8"), True)
        converted_tuple = self.executor._convert_ydb_value(test_tuple)
        self.assertEqual(converted_tuple, (1, b"test", True))

    def test_convert_ydb_value_nested_structures(self):
        """Test _convert_ydb_value with nested data structures."""
        # Create a complex nested structure
        nested_data = {
            "string".encode("utf-8"): "value".encode("utf-8"),
            "list".encode("utf-8"): [
                1,
                "item".encode("utf-8"),
                {"nested_key".encode("utf-8"): "nested_value".encode("utf-8")},
            ],
            "dict".encode("utf-8"): {
                "key1".encode("utf-8"): [1, 2, "three".encode("utf-8")],
                "key2".encode("utf-8"): {"inner".encode("utf-8"): "value".encode("utf-8")},
            },
        }

        expected_result = {
            b"string": b"value",
            b"list": [1, b"item", {b"nested_key": b"nested_value"}],
            b"dict": {b"key1": [1, 2, b"three"], b"key2": {b"inner": b"value"}},
        }

        converted = self.executor._convert_ydb_value(nested_data)
        self.assertEqual(converted, expected_result)


# Allow tests to run with asyncio
def run_async_test(test_case, test_func):
    """Run an async test function to completion.

    Fixed: asyncio.get_event_loop() is deprecated when no loop is running
    (since Python 3.10); asyncio.run() creates and tears down a fresh event
    loop per test instead (kept consistent with tests/test_connection.py).
    """
    asyncio.run(test_func(test_case))


# Patch async test methods so the synchronous unittest runner can execute them.
for method_name in dir(TestQueryExecutor):
    if method_name.startswith("test_"):
        method = getattr(TestQueryExecutor, method_name)
        if asyncio.iscoroutinefunction(method):
            # Bind the coroutine function as a default argument (m=method) to
            # avoid the late-binding-closure pitfall inside this loop.
            setattr(TestQueryExecutor, method_name, lambda self, m=method: run_async_test(self, m))

if __name__ == "__main__":
    unittest.main()
# tests/test_server.py — TestYDBMCPServer test methods (class body continues
# from the header above; register_tools is patched out in every test,
# presumably to skip MCP tool registration side effects — TODO confirm).

    async def test_init_with_env_vars(self):
        """Test initialization with environment variables."""
        with patch("os.environ", {"YDB_ENDPOINT": "test-endpoint", "YDB_DATABASE": "test-database"}):
            with patch.object(YDBMCPServer, "register_tools"):
                server = YDBMCPServer()
                assert server.endpoint == "test-endpoint"
                assert server.database == "test-database"

    async def test_init_with_args(self):
        """Test initialization with arguments."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="arg-endpoint", database="arg-database")
            assert server.endpoint == "arg-endpoint"
            assert server.database == "arg-database"

    # Query tests
    async def test_query_simple(self):
        """Test simple query execution."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Create mock result set
            mock_col1 = MagicMock()
            mock_col1.name = "column1"
            mock_col2 = MagicMock()
            mock_col2.name = "column2"

            mock_row = MagicMock()
            mock_row.__getitem__.side_effect = lambda idx: ["value1", 123][idx]

            mock_result_set = MagicMock()
            mock_result_set.rows = [mock_row]
            mock_result_set.columns = [mock_col1, mock_col2]

            # Mock YDB driver and pool
            server.get_pool = AsyncMock()
            mock_pool = AsyncMock()
            mock_pool.execute_with_retries = AsyncMock(return_value=[mock_result_set])
            server.get_pool.return_value = mock_pool

            # Execute query
            result = await server.query("SELECT * FROM table")

            # Check the query was executed
            mock_pool.execute_with_retries.assert_called_once_with("SELECT * FROM table", None)

            # Check the result was processed correctly
            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], TextContent)

            # Parse the JSON response
            parsed_result = json.loads(result[0].text)
            assert "result_sets" in parsed_result
            assert len(parsed_result["result_sets"]) == 1
            assert "columns" in parsed_result["result_sets"][0]
            assert "rows" in parsed_result["result_sets"][0]
            assert parsed_result["result_sets"][0]["columns"] == ["column1", "column2"]
            assert len(parsed_result["result_sets"][0]["rows"]) == 1
            assert parsed_result["result_sets"][0]["rows"][0] == ["value1", 123]

    async def test_query_with_params(self):
        """Test query with parameters."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock server.query method
            server.query = AsyncMock(return_value={"result_sets": [{"columns": ["test"], "rows": [["data"]]}]})

            # Create params as a JSON string
            params_json = json.dumps({"$param1": 123, "param2": "value"})

            # Execute query with params
            result = await server.query_with_params("SELECT * FROM table WHERE id = $param1", params_json)

            # Check the query was executed with correct parameters.
            # Note: "param2" is expected back as "$param2" — the server
            # apparently normalizes parameter names to the $-prefixed form.
            expected_params = {"$param1": 123, "$param2": "value"}
            server.query.assert_called_once_with("SELECT * FROM table WHERE id = $param1", expected_params)

            # Check the result
            assert result == {"result_sets": [{"columns": ["test"], "rows": [["data"]]}]}

    async def test_query_with_invalid_params(self):
        """Test query with invalid parameters JSON."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Invalid JSON string
            params_json = "invalid json"

            # Execute query with invalid params
            result = await server.query_with_params("SELECT * FROM table", params_json)

            # Check the error is returned
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"
            assert "Error parsing JSON parameters" in result[0].text

    async def test_query_with_auth_error(self):
        """Test query execution when there's an authentication error."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer()
            server.auth_error = "Authentication failed: Invalid token"

            # Execute query
            result = await server.query("SELECT * FROM table")

            # Verify error is returned
            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], TextContent)

            # Parse the JSON response
            parsed_result = json.loads(result[0].text)
            assert "error" in parsed_result
            assert parsed_result["error"] == "Authentication failed: Invalid token"

    async def test_query_with_complex_params(self):
        """Test query with complex parameter types."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer()

            # Mock pool for query execution
            mock_pool = AsyncMock()
            server.get_pool = AsyncMock(return_value=mock_pool)
            mock_pool.execute_with_retries = AsyncMock(return_value=[MagicMock()])

            # Test parameters with explicit YDB types: (value, type) tuples
            # alongside a plain value.
            params = {
                "$int_param": (42, "Int32"),
                "$str_param": ("test", "Utf8"),
                "$simple_param": "simple value",
            }

            # Execute query
            await server.query("SELECT * FROM table", params)

            # Verify parameters were processed correctly
            mock_pool.execute_with_retries.assert_called_once()
            call_args = mock_pool.execute_with_retries.call_args[0]
            assert call_args[0] == "SELECT * FROM table"
            assert "$int_param" in call_args[1]
            assert "$str_param" in call_args[1]
            assert "$simple_param" in call_args[1]

    # Authentication tests
    async def test_invalid_authentication(self):
        """Test that authentication fails with invalid credentials."""

        # Creating a dummy credentials object that will cause authentication to fail
        class InvalidCredentials(ydb.credentials.AbstractCredentials):
            def get_token(self, context):
                return "invalid_token_12345"

            def _update_driver_config(self, driver_config):
                # This method is required by the YDB driver
                pass

        # Create server with invalid credentials factory
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(
                endpoint="test-endpoint",
                database="test-database",
                credentials_factory=lambda: InvalidCredentials(),
            )

            # Mock the driver creation to raise an authentication error
            with patch.object(
                server,
                "create_driver",
                side_effect=Exception("Authentication failed: Invalid credentials"),
            ):
                # Authentication should fail when creating driver
                with pytest.raises(Exception) as excinfo:
                    await server.create_driver()

                # Verify the error message indicates an authentication problem
                error_message = str(excinfo.value).lower()
                assert "authentication" in error_message or "invalid" in error_message, (
                    f"Expected authentication error, got: {error_message}"
                )

    # Directory and path tests
    async def test_list_directory(self):
        """Test the list_directory method."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock driver and scheme client
            server.driver = MagicMock()
            mock_scheme_client = MagicMock()
            server.driver.scheme_client = mock_scheme_client

            # Create mock response
            mock_entry1 = MagicMock()
            mock_entry1.name = "table1"
            mock_entry1.type = "TABLE"
            mock_entry1.owner = "root"
            mock_entry1.permissions = []

            mock_entry2 = MagicMock()
            mock_entry2.name = "directory1"
            mock_entry2.type = "DIRECTORY"
            mock_entry2.owner = "root"
            mock_entry2.permissions = []

            mock_response = MagicMock()
            mock_response.children = [mock_entry1, mock_entry2]

            # Setup mock list_directory to return our response
            mock_scheme_client.list_directory = AsyncMock(return_value=mock_response)

            # Call the method
            result = await server.list_directory("/path/to/directory")

            # Verify scheme_client.list_directory was called
            mock_scheme_client.list_directory.assert_called_once_with("/path/to/directory")

            # Verify result format
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"

            # Parse the JSON result
            data = json.loads(result[0].text)

            # Verify content (order-insensitive membership checks)
            assert data["path"] == "/path/to/directory"
            assert len(data["items"]) == 2
            assert {"name": "directory1", "type": "DIRECTORY", "owner": "root"} in data["items"]
            assert {"name": "table1", "type": "TABLE", "owner": "root"} in data["items"]

    async def test_list_directory_empty(self):
        """Test the list_directory method with empty directory."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock driver and scheme client
            server.driver = MagicMock()
            mock_scheme_client = MagicMock()
            server.driver.scheme_client = mock_scheme_client

            # Create mock response for empty directory
            mock_response = MagicMock()
            mock_response.children = []

            # Setup mock list_directory to return our response
            mock_scheme_client.list_directory = AsyncMock(return_value=mock_response)

            # Call the method
            result = await server.list_directory("/path/to/empty/directory")

            # Verify result
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"
            assert "empty" in result[0].text

    async def test_list_directory_error(self):
        """Test the list_directory method with error."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock driver and scheme client that raises an exception
            server.driver = MagicMock()
            mock_scheme_client = MagicMock()
            server.driver.scheme_client = mock_scheme_client

            # Setup mock to raise an exception
            mock_scheme_client.list_directory = AsyncMock(side_effect=Exception("Access denied"))

            # Call the method
            result = await server.list_directory("/path/to/directory")

            # Verify error result
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"
            assert "Error" in result[0].text
            assert "Access denied" in result[0].text

    async def test_describe_path(self):
        """Test the describe_path method."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock driver and scheme client
            server.driver = MagicMock()
            mock_scheme_client = MagicMock()
            server.driver.scheme_client = mock_scheme_client

            # Create mock response for a directory
            mock_response = MagicMock()
            mock_response.name = "testdir"
            mock_response.type = "DIRECTORY"
            mock_response.owner = "root"
            mock_response.permissions = []

            # Setup mock describe_path to return our response
            mock_scheme_client.describe_path = AsyncMock(return_value=mock_response)

            # Call the method
            result = await server.describe_path("/path/to/testdir")

            # Verify scheme_client.describe_path was called
            mock_scheme_client.describe_path.assert_called_once_with("/path/to/testdir")

            # Verify result format
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"

            # Parse the JSON result
            data = json.loads(result[0].text)

            # Verify content
            assert data["path"] == "/path/to/testdir"
            assert data["type"] == "DIRECTORY"
            assert data["name"] == "testdir"
            assert data["owner"] == "root"

    async def test_describe_path_table(self):
        """Test the describe_path method with a table path."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock driver and scheme client
            server.driver = MagicMock()
            mock_scheme_client = MagicMock()
            server.driver.scheme_client = mock_scheme_client

            # Create mock response for a table
            mock_response = MagicMock()
            mock_response.name = "test_table"
            mock_response.type = "TABLE"
            mock_response.owner = "root"
            mock_response.permissions = []

            # Create mock column
            mock_column = MagicMock()
            mock_column.name = "id"
            mock_column.type = "Int64"

            # Create mock table
            mock_table = MagicMock()
            mock_table.columns = [mock_column]
            mock_table.primary_key = ["id"]
            mock_table.indexes = []
            mock_table.partitioning_settings = None

            # Add table to response
            mock_response.table = mock_table

            # Setup mock describe_path to return our response
            mock_scheme_client.describe_path = AsyncMock(return_value=mock_response)

            # Call the method
            result = await server.describe_path("/path/to/test_table")

            # Verify result format
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"

            # Parse the JSON result
            data = json.loads(result[0].text)

            # Verify content
            assert data["path"] == "/path/to/test_table"
            assert data["type"] == "TABLE"
            assert data["name"] == "test_table"
            assert data["owner"] == "root"
            assert "table" in data
            assert data["table"]["columns"][0]["name"] == "id"
            assert data["table"]["columns"][0]["type"] == "Int64"
            assert data["table"]["primary_key"] == ["id"]

    async def test_describe_path_error(self):
        """Test the describe_path method with error."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer(endpoint="test-endpoint", database="test-database")

            # Mock driver and scheme client that raises an exception
            server.driver = MagicMock()
            mock_scheme_client = MagicMock()
            server.driver.scheme_client = mock_scheme_client

            # Setup mock to raise an exception
            mock_scheme_client.describe_path = AsyncMock(side_effect=Exception("Path not found"))

            # Call the method
            result = await server.describe_path("/non/existent/path")

            # Verify error result
            assert isinstance(result, list)
            assert len(result) == 1
            assert result[0].type == "text"
            assert "Error" in result[0].text
            assert "Path not found" in result[0].text

    # Server management tests
    async def test_restart_success(self):
        """Test successful server restart."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer()

            # Create mock pool and driver with proper async mocks
            mock_pool = AsyncMock()
            mock_pool.stop = AsyncMock()

            # Create a real asyncio task for discovery
            discovery_coro = asyncio.sleep(0)  # A coroutine that completes immediately
            discovery_task = asyncio.create_task(discovery_coro)

            # Create a mock discovery with a synchronous stop method
            mock_discovery = MagicMock()
            mock_discovery.stop = MagicMock()  # Make stop a sync method as it is in the real code
            mock_discovery._discovery_task = discovery_task

            mock_driver = AsyncMock()
            mock_driver.stop = AsyncMock()
            mock_driver.discovery = mock_discovery

            # Set up the mocks before the restart
            server.pool = mock_pool
            server.driver = mock_driver
            server.create_driver = AsyncMock(return_value=MagicMock())

            # Perform restart
            success = await server.restart()

            # Clean up the task
            if not discovery_task.done():
                discovery_task.cancel()
                try:
                    await discovery_task
                except asyncio.CancelledError:
                    pass

            # Verify all cleanup and initialization was done
            assert mock_pool.stop.called
            assert mock_driver.stop.called
            assert mock_discovery.stop.called
            assert server.create_driver.called
            assert success is True

    async def test_restart_failure(self):
        """Test server restart when driver creation fails."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer()

            # Create mock pool and driver with proper async mocks
            mock_pool = AsyncMock()
            mock_pool.stop = AsyncMock()

            # Create a real asyncio task for discovery
            discovery_coro = asyncio.sleep(0)  # A coroutine that completes immediately
            discovery_task = asyncio.create_task(discovery_coro)

            # Create a mock discovery with a synchronous stop method
            mock_discovery = MagicMock()
            mock_discovery.stop = MagicMock()  # Make stop a sync method as it is in the real code
            mock_discovery._discovery_task = discovery_task

            mock_driver = AsyncMock()
            mock_driver.stop = AsyncMock()
            mock_driver.discovery = mock_discovery

            # Set up the mocks before the restart
            server.pool = mock_pool
            server.driver = mock_driver
            # create_driver returning None signals a failed restart below.
            server.create_driver = AsyncMock(return_value=None)

            # Perform restart
            success = await server.restart()

            # Clean up the task
            if not discovery_task.done():
                discovery_task.cancel()
                try:
                    await discovery_task
                except asyncio.CancelledError:
                    pass

            # Verify cleanup was attempted but restart failed
            assert mock_pool.stop.called
            assert mock_driver.stop.called
            assert mock_discovery.stop.called
            assert server.create_driver.called
            assert success is False

    # Utility tests
    async def test_custom_json_encoder(self):
        """Test CustomJSONEncoder handles all special types correctly."""
        test_data = {
            "datetime": datetime.datetime(2024, 1, 1, 12, 0),
            "date": datetime.date(2024, 1, 1),
            "time": datetime.time(12, 0),
            "timedelta": datetime.timedelta(seconds=3600),
            "decimal": decimal.Decimal("123.45"),
            "bytes": b"test bytes",
            "regular": "string",
            "number": 42,
        }

        # Encode the test data
        encoded = json.dumps(test_data, cls=CustomJSONEncoder)
        decoded = json.loads(encoded)

        # Verify each type was encoded correctly
        assert decoded["datetime"] == "2024-01-01T12:00:00"
        assert decoded["date"] == "2024-01-01"
        assert decoded["time"] == "12:00:00"
        assert decoded["timedelta"] == "3600.0s"
        assert decoded["decimal"] == "123.45"
        assert decoded["bytes"] == "test bytes"
        assert decoded["regular"] == "string"
        assert decoded["number"] == 42

    def test_process_result_set_error(self):
        """Test result set processing when an error occurs."""
        with patch.object(YDBMCPServer, "register_tools"):
            server = YDBMCPServer()

            # Create a mock result set that will raise an exception
            mock_result_set = MagicMock()
            type(mock_result_set).columns = PropertyMock(side_effect=Exception("Test error"))
            type(mock_result_set).rows = PropertyMock(
                side_effect=Exception("Test error")
            )  # Also make rows raise an exception

            # Process the result set
            result = server._process_result_set(mock_result_set)

            # Verify error handling: error is reported but the result keeps
            # the regular shape with empty columns/rows.
            assert "error" in result
            assert "Test error" in result["error"]
            assert "columns" in result
            assert "rows" in result
            assert len(result["columns"]) == 0
            assert len(result["rows"]) == 0


# --- ydb_mcp/__init__.py ---
"""YDB MCP - Model Context Protocol server for YDB."""

from .version import VERSION

__version__ = VERSION

# Import order matters to avoid circular imports
from ydb_mcp.connection import YDBConnection
from ydb_mcp.query import QueryExecutor

__all__ = ["YDBConnection", "QueryExecutor"]


# --- ydb_mcp/__main__.py ---
"""Main entry point for running the YDB MCP server."""

import argparse
import logging
import os
import sys

from ydb_mcp.server import AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD, YDBMCPServer
authentication mode (overrides YDB_AUTH_MODE env var)", 45 | ) 46 | 47 | parser.add_argument( 48 | "--log-level", 49 | type=str, 50 | default="INFO", 51 | choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], 52 | help="Logging level", 53 | ) 54 | 55 | return parser.parse_args() 56 | 57 | 58 | def main(): 59 | """Run the YDB MCP server.""" 60 | args = parse_args() 61 | 62 | # Configure logging 63 | logging.basicConfig( 64 | level=getattr(logging, args.log_level), 65 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 66 | ) 67 | 68 | # Validate auth mode and required credentials 69 | supported_auth_modes = {AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD} 70 | auth_mode = args.ydb_auth_mode or AUTH_MODE_ANONYMOUS 71 | if auth_mode not in supported_auth_modes: 72 | print( 73 | f"Error: Unsupported auth mode: {auth_mode}. Supported modes: {', '.join(supported_auth_modes)}", 74 | file=sys.stderr, 75 | ) 76 | exit(1) 77 | if auth_mode == AUTH_MODE_LOGIN_PASSWORD: 78 | if not args.ydb_login or not args.ydb_password: 79 | print( 80 | "Error: --ydb-login and --ydb-password are required for login-password authentication mode.", 81 | file=sys.stderr, 82 | ) 83 | exit(1) 84 | 85 | # Set environment variables for YDB if provided via arguments 86 | if args.ydb_endpoint: 87 | os.environ["YDB_ENDPOINT"] = args.ydb_endpoint 88 | if args.ydb_database: 89 | os.environ["YDB_DATABASE"] = args.ydb_database 90 | if args.ydb_login: 91 | os.environ["YDB_LOGIN"] = args.ydb_login 92 | if args.ydb_password: 93 | os.environ["YDB_PASSWORD"] = args.ydb_password 94 | if args.ydb_auth_mode: 95 | os.environ["YDB_AUTH_MODE"] = args.ydb_auth_mode 96 | 97 | # Create and run the server 98 | server = YDBMCPServer( 99 | endpoint=args.ydb_endpoint, 100 | database=args.ydb_database, 101 | login=args.ydb_login, 102 | password=args.ydb_password, 103 | auth_mode=auth_mode, 104 | ) 105 | 106 | print("Starting YDB MCP server with stdio transport") 107 | print(f"YDB endpoint: {args.ydb_endpoint 
# --- ydb_mcp/connection.py ---
import asyncio
import logging
import re
from typing import Optional, Tuple
from urllib.parse import urlparse

import ydb

logger = logging.getLogger(__name__)


class YDBConnection:
    """Manages YDB connection with async support."""

    def __init__(self, connection_string: str, database: str | None = None):
        """Initialize YDB connection.

        Args:
            connection_string: YDB connection string
            database: Optional database path. If not provided, will be extracted
                from connection_string if present
        """
        self.connection_string = connection_string
        self.driver: Optional[ydb.Driver] = None
        self.session_pool: Optional[ydb.aio.QuerySessionPool] = None
        self._database = database
        # Human-readable description of the last failure, for diagnostics.
        self.last_error: str | None = None

    def _parse_endpoint_and_database(self) -> Tuple[str, str]:
        """Parse endpoint and database from connection string.

        Returns:
            Tuple of (endpoint, database)

        Raises:
            RuntimeError: If no database is specified either in connection
                string or explicitly
        """
        # Parse the URL
        connection_string = self.connection_string
        if not connection_string.startswith(("grpc://", "grpcs://")):
            # If no scheme, assume grpc:// and parse as host:port
            if "/" in connection_string:
                host_port, path = connection_string.split("/", 1)
                connection_string = f"grpc://{host_port}/{path}"
            else:
                connection_string = f"grpc://{connection_string}"

        parsed = urlparse(connection_string)

        # Extract endpoint (scheme + netloc)
        endpoint = f"{parsed.scheme}://{parsed.netloc}"

        # Extract database path; an explicitly passed database wins over the
        # path embedded in the connection string.
        database = self._database
        if not database:
            if parsed.path:
                database = parsed.path
                # Remove query parameters if present
                if "?" in database:
                    database = database.split("?")[0]

        # Ensure database starts with /
        if database and not database.startswith("/"):
            database = f"/{database}"

        # Raise error if no database specified
        if not database:
            raise RuntimeError("Database not specified in connection string or explicitly")

        return endpoint, database

    async def connect(self) -> Tuple[ydb.Driver, ydb.aio.QuerySessionPool]:
        """Connect to YDB and setup session pool asynchronously.

        Returns:
            Tuple of (driver, session_pool)

        Raises:
            RuntimeError: If connection fails
        """
        try:
            endpoint, database = self._parse_endpoint_and_database()
            logger.info(f"Connecting to YDB endpoint: {endpoint}, database: {database}")

            # Create driver with direct parameters instead of config
            self.driver = ydb.aio.Driver(
                endpoint=endpoint,
                database=database,
                credentials=ydb.credentials.AnonymousCredentials(),
            )

            # Wait for driver to be ready with timeout
            try:
                await asyncio.wait_for(self.driver.wait(), timeout=10.0)
            except asyncio.TimeoutError:
                self.last_error = "Connection timeout"
                raise RuntimeError("YDB driver connection timeout after 10 seconds")

            # Check if we connected successfully; the discovery debug string
            # starts with "Resolved endpoints" only on success.
            if not self.driver.discovery_debug_details().startswith("Resolved endpoints"):
                debug_details = self.driver.discovery_debug_details()
                self.last_error = f"Driver not ready: {debug_details}"
                raise RuntimeError(f"YDB driver failed to connect: {debug_details}")

            logger.info("Connected to YDB successfully")

            # Create session pool
            self.session_pool = ydb.aio.QuerySessionPool(self.driver)

            return self.driver, self.session_pool

        except Exception as e:
            self.last_error = str(e)
            logger.error(f"Failed to connect to YDB: {e}")
            raise RuntimeError(f"Failed to connect to YDB: {e}")

    async def close(self) -> None:
        """Close YDB connection, stopping the session pool and driver."""
        logger.info("Closing YDB connection")

        # stop() on the pool/driver is blocking, so run it off the event loop.
        # get_running_loop() is the supported way to obtain the loop inside a
        # coroutine; get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()

        if self.session_pool:
            await loop.run_in_executor(None, self.session_pool.stop)
            self.session_pool = None

        if self.driver:
            await loop.run_in_executor(None, self.driver.stop)
            self.driver = None

        logger.info("YDB connection closed")

    def _extract_database_path(self, connection_string: Optional[str] = None) -> str:
        """Extract database path from connection string.

        Args:
            connection_string: YDB connection string, or None to use the
                instance's connection string

        Returns:
            Database path (defaults to "/" when none is present)
        """
        # Use instance connection string if none provided
        if connection_string is None:
            connection_string = self.connection_string

        # Handle connection string with query parameters
        if "?" in connection_string:
            connection_string = connection_string.split("?")[0]

        # Extract path using regex: optional scheme://host prefix, then path
        match = re.match(r"^(?:[^:]+://[^/]+)?(/.*)?$", connection_string)
        return match.group(1) if match and match.group(1) else "/"


# --- ydb_mcp/query.py ---
import asyncio
import datetime
import decimal
import logging
import sys
from typing import Any, Dict, List

import ydb

from ydb_mcp.connection import YDBConnection

logger = logging.getLogger(__name__)


class QueryExecutor:
    """Executor for SQL queries against YDB."""

    def __init__(self, connection: YDBConnection):
        """Initialize the query executor.

        Args:
            connection: YDBConnection instance to use for executing queries
        """
        self.connection = connection
        # Cached reference to the connection's session pool, set on first query.
        self._session_pool = None
29 | 30 | Args: 31 | query: SQL query string to execute 32 | 33 | Returns: 34 | List of dictionaries representing rows of the query result 35 | 36 | Raises: 37 | Exception: If the query execution fails 38 | """ 39 | if not self.connection.driver or not self.connection.session_pool: 40 | await self.connection.connect() 41 | 42 | self._session_pool = self.connection.session_pool 43 | 44 | try: 45 | loop = asyncio.get_event_loop() 46 | result = await loop.run_in_executor(None, self._execute_query_sync, query) 47 | return result 48 | except Exception as e: 49 | # Only log real errors, not test errors 50 | if "Test error" not in str(e) or "pytest" not in sys.modules: 51 | logger.error(f"Error executing query: {e}") 52 | raise 53 | 54 | def _execute_query_sync(self, query: str) -> List[Dict[str, Any]]: 55 | """Execute a query synchronously. 56 | 57 | This method is intended to be called by execute_query via run_in_executor. 58 | 59 | Args: 60 | query: SQL query string to execute 61 | 62 | Returns: 63 | List of dictionaries representing rows of the query result 64 | """ 65 | 66 | def _execute_query(session): 67 | # Execute query and get result sets 68 | result_sets = session.transaction().execute( 69 | query, 70 | commit_tx=True, 71 | settings=ydb.BaseRequestSettings().with_timeout(3).with_operation_timeout(2), 72 | ) 73 | 74 | # Convert result sets to list of dictionaries 75 | result = [] 76 | for rs in result_sets: 77 | for row in rs.rows: 78 | result.append(self._convert_row_to_dict(row)) 79 | return result 80 | 81 | if self._session_pool is None: 82 | raise RuntimeError("SessionPool is not provided.") 83 | 84 | return self._session_pool.retry_operation_sync(_execute_query) 85 | 86 | def _convert_row_to_dict(self, row: Any, col_names: List[str] | None = None) -> Dict[str, Any]: 87 | """Convert a YDB result row to a dictionary. 
88 | 89 | Args: 90 | row: YDB result row 91 | col_names: Optional list of column names 92 | 93 | Returns: 94 | Dictionary representing the row data 95 | """ 96 | result = {} 97 | for key, value in row.items(): 98 | result[key] = self._convert_ydb_value(value) 99 | return result 100 | 101 | def _convert_ydb_value(self, value: Any) -> Any: 102 | """Convert YDB-specific types to Python types. 103 | 104 | Args: 105 | value: YDB value to convert 106 | 107 | Returns: 108 | Converted Python value 109 | """ 110 | # Handle None/null values 111 | if value is None: 112 | return None 113 | 114 | # Handle bytes (strings in YDB are returned as bytes) 115 | if isinstance(value, bytes): 116 | # For now, keep all strings as bytes since we don't have type info 117 | return value 118 | 119 | # Handle date/time types 120 | if isinstance(value, (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)): 121 | return value 122 | 123 | # Handle Decimal type 124 | if isinstance(value, decimal.Decimal): 125 | return value 126 | 127 | # Handle container types 128 | if isinstance(value, list): 129 | return [self._convert_ydb_value(item) for item in value] 130 | if isinstance(value, dict): 131 | return {self._convert_ydb_value(k): self._convert_ydb_value(v) for k, v in value.items()} 132 | if isinstance(value, tuple): 133 | return tuple(self._convert_ydb_value(item) for item in value) 134 | 135 | # For all other types (int, float, bool), return as is 136 | return value 137 | -------------------------------------------------------------------------------- /ydb_mcp/server.py: -------------------------------------------------------------------------------- 1 | """Model Context Protocol server for YDB DBMS proxy.""" 2 | 3 | import asyncio 4 | import base64 5 | import datetime 6 | import decimal 7 | import json 8 | import logging 9 | import os 10 | from typing import Any, Callable, Dict, List, Optional 11 | 12 | import ydb 13 | from mcp.server.fastmcp import FastMCP 14 | from mcp.types 
import TextContent 15 | from ydb.aio import QuerySessionPool 16 | 17 | from ydb_mcp.connection import YDBConnection 18 | from ydb_mcp.tool_manager import ToolManager 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | # Authentication mode constants 23 | AUTH_MODE_ANONYMOUS = "anonymous" 24 | AUTH_MODE_LOGIN_PASSWORD = "login-password" 25 | 26 | 27 | class CustomJSONEncoder(json.JSONEncoder): 28 | """Custom JSON encoder that handles non-serializable types properly.""" 29 | 30 | def default(self, obj): 31 | # Handle datetime objects 32 | if isinstance(obj, datetime.datetime): 33 | # Convert to UTC if timezone-aware 34 | if obj.tzinfo is not None: 35 | obj = obj.astimezone(datetime.UTC) 36 | return obj.isoformat() 37 | 38 | # Handle date objects 39 | if isinstance(obj, datetime.date): 40 | return obj.isoformat() 41 | 42 | # Handle time objects 43 | if isinstance(obj, datetime.time): 44 | return obj.isoformat() 45 | 46 | # Handle timedelta objects 47 | if isinstance(obj, datetime.timedelta): 48 | # Convert to total seconds and format as string 49 | return f"{obj.total_seconds()}s" 50 | 51 | # Handle decimal objects 52 | if isinstance(obj, decimal.Decimal): 53 | return str(obj) 54 | 55 | # Handle bytes objects - try UTF-8 first, fall back to base64 56 | if isinstance(obj, bytes): 57 | try: 58 | return obj.decode("utf-8") 59 | except UnicodeDecodeError: 60 | # If it's not valid UTF-8, base64 encode it 61 | return base64.b64encode(obj).decode("ascii") 62 | 63 | # Use the parent class's default method for other types 64 | return super().default(obj) 65 | 66 | 67 | class YDBMCPServer(FastMCP): 68 | """Model Context Protocol server for YDB DBMS. 

    Features:
    - Execute SQL queries against YDB database
    - Support for multiple SQL statements in a single query
    - Support for anonymous and login-password authentication modes
    """

    # YDB entry type mapping (numeric scheme-entry codes -> readable names).
    # NOTE(review): assumes these match the ydb SDK SchemeEntry type values —
    # confirm against the installed SDK version.
    ENTRY_TYPE_MAP = {
        1: "DIRECTORY",
        2: "TABLE",
        3: "PERS_QUEUE",
        4: "DATABASE",
        5: "RTMR_VOLUME",
        6: "BLOCK_STORE_VOLUME",
        7: "COORDINATION",
        8: "SEQUENCE",
        9: "REPLICATION",
        10: "TOPIC",
        11: "EXTERNAL_DATA_SOURCE",
        12: "EXTERNAL_TABLE",
    }

    def __init__(
        self,
        endpoint: str | None = None,
        database: str | None = None,
        credentials_factory: Callable[[], ydb.Credentials] | None = None,
        ydb_connection_string: str = "",
        tool_manager: ToolManager | None = None,
        auth_mode: str | None = None,
        login: str | None = None,
        password: str | None = None,
        root_certificates: str | None = None,
        *args,
        **kwargs,
    ):
        """Initialize YDB MCP server.

        Args:
            endpoint: YDB endpoint
            database: YDB database
            credentials_factory: YDB credentials factory
            ydb_connection_string: YDB connection string (alternative to endpoint+database)
            tool_manager: External tool manager (optional)
            auth_mode: Authentication mode ("anonymous" or "login-password")
            login: Login for authentication
            password: Password for authentication
            root_certificates: Root certificates for YDB

        Raises:
            ValueError: If auth_mode is not one of the supported modes.
        """
        super().__init__(*args, **kwargs)

        # Initialize YDB-specific attributes
        self.driver = None
        # Environment variables supply defaults when arguments are omitted.
        self.endpoint = endpoint or os.environ.get("YDB_ENDPOINT", "grpc://localhost:2136")
        self.database = database or os.environ.get("YDB_DATABASE", "/local")
        self.credentials_factory = credentials_factory
        self.ydb_connection_string = ydb_connection_string
        self.auth_error: str | None = None
        self._loop = None
        self.pool = None
        self.tool_manager = tool_manager or ToolManager()
        self._driver_lock = asyncio.Lock()
        self._pool_lock = asyncio.Lock()
        self.root_certificates = root_certificates
        # Stash of monkey-patched YDB methods so they can be restored later.
        self._original_methods: Dict = {}

        # Authentication settings
        supported_auth_modes = {AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD}
        self.auth_mode = auth_mode or AUTH_MODE_ANONYMOUS
        if self.auth_mode not in supported_auth_modes:
            raise ValueError(
                f"Unsupported auth mode: {self.auth_mode}. Supported modes: {', '.join(supported_auth_modes)}"
            )
        self.login = login
        self.password = password

        # Initialize logging
        logging.basicConfig(level=logging.INFO)

        # Register YDB tools
        self.register_tools()

    def _restore_ydb_patches(self):
        """Restore original YDB methods that were patched."""
        # Restore topic client __del__ method
        if "topic_client_del" in self._original_methods and hasattr(ydb, "topic") and hasattr(ydb.topic, "TopicClient"):
            if self._original_methods["topic_client_del"] is not None:
                ydb.topic.TopicClient.__del__ = self._original_methods["topic_client_del"]
            else:
                # If there was no original method, try to remove our patched one
                if hasattr(ydb.topic.TopicClient, "__del__"):
                    delattr(ydb.topic.TopicClient, "__del__")
            logger.info("Restored original YDB TopicClient __del__ method")

    def _anonymous_credentials(self) -> ydb.Credentials:
        """Create anonymous credentials."""
        logger.info("Using anonymous authentication")
        return ydb.credentials.AnonymousCredentials()

    def _login_password_credentials(self) -> ydb.Credentials:
        """Create login-password credentials."""
        logger.info(f"Using login-password authentication with login: {self.login}")
        return ydb.credentials.StaticCredentials.from_user_password(self.login, self.password)

    async def create_driver(self):
        """Create a YDB driver with the current settings.

        On any failure the error text is recorded in self.auth_error and
        None is returned instead of raising.

        Returns:
            ydb.aio.Driver or None: The created driver instance if successful, None if failed
        """
        try:
            # Get credentials
            credentials_factory = self.get_credentials_factory()
            if not credentials_factory:
                return None

            # Ensure we use the current event loop
            self._loop = asyncio.get_event_loop()

            # Determine endpoint and database
            endpoint = self.endpoint
            database = self.database

            # If we have a connection string, parse it (it overrides both).
            if self.ydb_connection_string:
                conn = YDBConnection(self.ydb_connection_string)
                endpoint, database = conn._parse_endpoint_and_database()

            # Validate we have required parameters
            if not endpoint:
                self.auth_error = "YDB endpoint not specified"
                logger.error(self.auth_error)
                return None

            if not database:
                self.auth_error = "YDB database not specified"
                logger.error(self.auth_error)
                return None

            logger.info(f"Connecting to YDB at {endpoint}, database: {database}")

            # Create the driver config
            driver_config = ydb.DriverConfig(
                endpoint=endpoint,
                database=database,
                credentials=credentials_factory(),
                root_certificates=self.root_certificates,
            )

            # Create and initialize the driver
            self.driver = ydb.aio.Driver(driver_config)

            # Initialize driver with latest API
            await self.driver.wait(timeout=5.0)
            # Check if we connected successfully; discovery_debug_details is a
            # blocking call, so run it in the executor.
            debug_details = await self._loop.run_in_executor(None, lambda: self.driver.discovery_debug_details())
            if not debug_details.startswith("Resolved endpoints"):
                self.auth_error = f"Failed to connect to YDB server: {debug_details}"
                logger.error(self.auth_error)
                return None

            logger.info(f"Successfully connected to YDB at {endpoint}")
            return self.driver

        except Exception as e:
            self.auth_error = str(e)
            logger.error(f"Error creating YDB driver: {e}")
            return None

    async def _close_topic_client(self, topic_client):
        """Properly close a topic client.

        Args:
            topic_client: Topic client instance (may be None).

        Returns:
            True when close() completed successfully, False otherwise.
        """
        if topic_client is not None and hasattr(topic_client, "close"):
            try:
                logger.info("Closing YDB topic client")
                # Ensure we wait for the close operation to complete
                await topic_client.close()
                return True
            except Exception as e:
                logger.warning(f"Error closing topic client: {e}")
        return False

    async def _terminate_discovery(self, discovery):
        """Properly terminate a discovery process and wait for tasks to complete.

        Args:
            discovery: The driver's discovery object (may be None).

        Returns:
            True when terminate() was awaited successfully, False otherwise.
        """
        if discovery is not None:
            try:
                # First check for the discovery task
                if hasattr(discovery, "_discovery_task") and discovery._discovery_task is not None:
                    task = discovery._discovery_task
                    if not task.done() and not task.cancelled():
                        logger.info("Cancelling discovery task")
                        task.cancel()
                        try:
                            # Wait for task cancellation to complete
                            await asyncio.wait_for(asyncio.shield(task), timeout=0.5)
                        except (asyncio.CancelledError, asyncio.TimeoutError, Exception) as e:
                            logger.warning(f"Error waiting for discovery task cancellation: {e}")

                # Handle any streaming response generators that might be running
                if hasattr(discovery, "_fetch_stream_responses") and callable(discovery._fetch_stream_responses):
                    # This is a generator method that might be active
                    # Nothing to do directly - the generator will be GC'ed when the driver is destroyed
                    pass

                # Then call terminate if available, but be careful of recursion
                if hasattr(discovery, "terminate"):
                    logger.info("Terminating YDB discovery process")
                    # Don't call our own terminate method to avoid recursion
                    original_terminate = discovery.terminate
                    if original_terminate.__name__ != "_terminate_discovery":
                        await original_terminate()
                        return True
            except Exception as e:
                logger.warning(f"Error terminating discovery: {e}")
        return False

    async def _cancel_ydb_related_tasks(self):
        """Find and cancel YDB-related tasks to prevent conflicts during shutdown."""
        discovery_tasks = []

        # Find YDB discovery-related tasks (identified by their repr).
        for task in asyncio.all_tasks(self._loop):
            task_str = str(task)
            if "Discovery.run" in task_str and not task.done() and not task.cancelled():
                discovery_tasks.append(task)

        if discovery_tasks:
            logger.info(f"Cancelling {len(discovery_tasks)} discovery tasks before restart")

            # Cancel all discovery tasks
            for task in discovery_tasks:
                task.cancel()

        # Wait briefly for tasks to cancel
        if discovery_tasks:
            try:
                await asyncio.wait_for(asyncio.gather(*discovery_tasks, return_exceptions=True), timeout=0.5)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                pass

        # Wait a moment to allow task cancellation to complete
        await asyncio.sleep(0.1)

    async def get_pool(self) -> QuerySessionPool:
        """Get or create YDB session pool.

        Returns:
            The shared QuerySessionPool, creating the driver/pool on demand.

        Raises:
            ValueError: If a previous authentication error is pending.
        """
        # Check for authentication errors first
        if self.auth_error:
            # Raise an exception with the auth error message which query() will catch
            raise ValueError(self.auth_error)

        async with self._pool_lock:
            if self.driver is None:
                await self.create_driver()

            if self.pool is None:
                self.pool = QuerySessionPool(self.driver)

            return self.pool

    def _stringify_dict_keys(self, obj):
        """Recursively convert all dict keys to strings for JSON serialization."""
        if isinstance(obj, dict):
            return {str(k): self._stringify_dict_keys(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._stringify_dict_keys(i) for i in obj]
        else:
            return obj

    async def query(self, sql: str, params: Optional[Dict[str, Any]] = None) -> List[TextContent]:
        """Run a SQL query against YDB.

        Errors (including auth errors) are returned as a JSON {"error": ...}
        payload rather than raised.

        Args:
            sql: SQL query to execute
            params: Optional query parameters

        Returns:
            List of TextContent objects with JSON-formatted results
        """
        # Check if there's an authentication error
        if self.auth_error:
            return [TextContent(type="text", text=json.dumps({"error": self.auth_error}, indent=2))]

        try:
            pool = await self.get_pool()
            ydb_params = None
            if params:
                ydb_params = {}
                # YDB parameter names must start with "$"; add it when missing.
                for key, value in params.items():
                    param_key = key if key.startswith("$") else f"${key}"
                    ydb_params[param_key] = value
            result_sets = await pool.execute_with_retries(sql, ydb_params)
            all_results = []
            for result_set in result_sets:
                processed = self._process_result_set(result_set)
                all_results.append(processed)
            # Convert all dict keys to strings for JSON serialization
            safe_result = self._stringify_dict_keys({"result_sets": all_results})
            return [TextContent(type="text", text=json.dumps(safe_result, indent=2, cls=CustomJSONEncoder))]
        except Exception as e:
            error_message = str(e)
            safe_error = self._stringify_dict_keys({"error": error_message})
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

    def _process_result_set(self, result_set):
        """Process YDB result set into a dictionary format.

        Args:
            result_set: YDB result set object

        Returns:
            Processed result set as a dictionary with "columns" and "rows"
            keys; on failure an "error" key is added and the rest is
            best-effort.
        """
        try:
            # Extract columns
            columns = []
            try:
                # Get column names from the columns attribute
                columns_attr = getattr(result_set, "columns")
                columns = [col.name for col in columns_attr]
            except Exception as e:
                logger.exception(f"Error getting columns: {e}")
                return {"error": str(e), "columns": [], "rows": []}

            # Extract rows
            rows = []
            try:
                rows_attr = getattr(result_set, "rows")
                for row in rows_attr:
                    row_values = []
                    for i in range(len(columns)):
                        row_values.append(row[i])
                    rows.append(row_values)
            except Exception as e:
                logger.exception(f"Error getting rows: {e}")
                return {"error": str(e), "columns": columns, "rows": []}

            return {"columns": columns, "rows": rows}
        except Exception as e:
            logger.exception(f"Error processing result set: {e}")
            return {"error": str(e), "columns": [], "rows": []}

    async def query_with_params(self, sql: str, params: str) -> List[TextContent]:
        """Run a parameterized SQL query with JSON parameters.

        Args:
            sql: SQL query to execute
            params: Parameters as a JSON string; a value of the form
                [value, "TypeName"] is converted to the matching
                ydb.PrimitiveType.

        Returns:
            Query results as a list of TextContent objects or a dictionary
        """
        # Handle authentication errors
        if self.auth_error:
            logger.error(f"Authentication error: {self.auth_error}")
            safe_error = self._stringify_dict_keys({"error": f"Authentication error: {self.auth_error}"})
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]
        parsed_params = {}
        try:
            if params and params.strip():
                parsed_params = json.loads(params)
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSON parameters: {str(e)}")
            safe_error = self._stringify_dict_keys({"error": f"Error parsing JSON parameters: {str(e)}"})
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]
        # Convert [value, type] to YDB type if needed
        ydb_params = {}
        for key, value in parsed_params.items():
            param_key = key if key.startswith("$") else f"${key}"
            if isinstance(value, (list, tuple)) and len(value) == 2:
                param_value, type_name = value
                if isinstance(type_name, str) and hasattr(ydb.PrimitiveType, type_name):
                    ydb_type = getattr(ydb.PrimitiveType, type_name)
                    ydb_params[param_key] = (param_value, ydb_type)
                else:
                    # Not a recognized type name: treat the pair as a plain value.
                    ydb_params[param_key] = param_value
            else:
                ydb_params[param_key] = value
        try:
            return await self.query(sql, ydb_params)
        except Exception as e:
            error_message = f"Error executing parameterized query: {str(e)}"
            logger.error(error_message)
            safe_error = self._stringify_dict_keys({"error": error_message})
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

    def register_tools(self):
        """Register YDB query tools.

        Note: Tools are registered with both the FastMCP framework and our tool_manager.
        The FastMCP.add_tool method doesn't support parameters, so we only provide
        the handler, name, and description to it. The complete tool specification
        including parameters is registered with our tool_manager.
        """
        # Define tool specifications
        tool_specs = [
            {
                "name": "ydb_query",
                "description": "Run a SQL query against YDB database",
                "handler": self.query,  # Use real handler
                "parameters": {
                    "properties": {"sql": {"type": "string", "title": "Sql"}},
                    "required": ["sql"],
                    "type": "object",
                },
            },
            {
                "name": "ydb_query_with_params",
                "description": "Run a parameterized SQL query with JSON parameters",
                "handler": self.query_with_params,  # Use real handler
                "parameters": {
                    "properties": {
                        "sql": {"type": "string", "title": "Sql"},
                        "params": {"type": "string", "title": "Params"},
                    },
                    "required": ["sql", "params"],
                    "type": "object",
                },
            },
            {
                "name": "ydb_status",
                "description": "Get the current status of the YDB connection",
                "handler": self.get_connection_status,  # Use real handler
                "parameters": {"type": "object", "properties": {}, "required": []},
            },
            {
                "name": "ydb_list_directory",
                "description": "List directory contents in YDB",
                "handler": self.list_directory,
                "parameters": {
                    "properties": {"path": {"type": "string", "title": "Path"}},
                    "required": ["path"],
                    "type": "object",
                },
            },
            {
                "name": "ydb_describe_path",
                "description": "Get detailed information about a YDB path (table, directory, etc.)",
                "handler": self.describe_path,
                "parameters": {
                    "properties": {"path": {"type": "string", "title": "Path"}},
                    "required": ["path"],
                    "type": "object",
                },
            },
        ]

        # Register all tools with FastMCP framework
        for spec in tool_specs:
            self.add_tool(spec["handler"], name=spec["name"], description=spec["description"])

            # Also register with our tool manager
            self.tool_manager.register_tool(
                name=spec["name"],
                handler=spec["handler"],
                description=spec["description"],
                parameters=spec.get("parameters"),
            )

    async def get_connection_status(self) -> List[TextContent]:
        """Get the current status of the YDB connection.

        Returns:
            List of TextContent objects with a JSON status payload
            (endpoint, database, auth mode, connection state, last error).
        """
        connection_status = "disconnected"
        error_message = None

        try:
            # Force create driver to ensure up-to-date status
            if self.driver is None:
                logger.info("Creating new driver for connection status check")
                await self.create_driver()

            if self.driver:
                try:
                    discovery = self.driver.discovery_debug_details()
                    if discovery.startswith("Resolved endpoints"):
                        connection_status = "connected"
                    else:
                        error_message = f"Discovery error: {discovery}"
                except Exception as conn_error:
                    error_message = f"Error checking connection via discovery: {conn_error}"
            else:
                error_message = "No driver available for connection status check"
        except Exception as e:
            error_message = str(e)

        status_info = {
            "status": "running",
            "ydb_endpoint": self.endpoint,
            "ydb_database": self.database,
            "auth_mode": self.auth_mode,
            "ydb_connection": connection_status,
            "error": error_message,
        }

        # Format the result as a TextContent object
        safe_status = self._stringify_dict_keys(status_info)
        formatted_result = json.dumps(safe_status, indent=2, cls=CustomJSONEncoder)
        logger.info(f"Connection status: {formatted_result}")
        return [TextContent(type="text", text=formatted_result)]

    async def list_directory(self, path: str) -> List[TextContent]:
        """List the contents of a YDB directory.

        Args:
            path: Path to the directory to list

        Returns:
            List of TextContent objects with JSON-formatted directory contents
        """
        # Check for authentication errors
        if self.auth_error:
            return [TextContent(type="text", text=json.dumps({"error": self.auth_error}, indent=2))]

        try:
            # Create driver if needed
            if self.driver is None:
                await self.create_driver()

            if self.driver is None:
                return [TextContent(type="text", text=json.dumps({"error": "Failed to create driver"}, indent=2))]

            # Access the scheme client
            scheme_client = self.driver.scheme_client

            # List the directory
            logger.info(f"Listing directory contents for path: {path}")
            dir_response = await scheme_client.list_directory(path)

            # Process the response
            result = {"path": path, "items": []}

            if dir_response.children:
                for entry in dir_response.children:
                    item = {
                        "name": entry.name,
                        "type": self.ENTRY_TYPE_MAP.get(entry.type, str(entry.type)),
                        "owner": entry.owner,
                    }

                    # Add permissions if available
                    if hasattr(entry, "permissions") and entry.permissions:
                        item["permissions"] = []
                        for perm in entry.permissions:
                            item["permissions"].append(
                                {
                                    "subject": perm.subject,
                                    "permission_names": list(perm.permission_names),
                                }
                            )

                    result["items"].append(item)

            # Sort items by name for consistency
            result["items"].sort(key=lambda x: x["name"])

            # Convert all dict keys to strings for JSON serialization
            safe_result = self._stringify_dict_keys(result)
            return [TextContent(type="text", text=json.dumps(safe_result, indent=2, cls=CustomJSONEncoder))]

        except Exception as e:
            logger.exception(f"Error listing directory {path}: {e}")
            safe_error = self._stringify_dict_keys({"error": f"Error listing directory {path}: {str(e)}"})
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

    async def describe_path(self, path: str) -> List[TextContent]:
        """Describe a path in YDB.

        For tables, a detailed description (columns, indexes, column
        families, storage/partitioning settings) is fetched via the table
        client; if that fails, basic info from the scheme response is used
        as a fallback.

        Args:
            path: Path to describe

        Returns:
            List of TextContent objects with path description
        """
        # Check for authentication errors
        if self.auth_error:
            safe_error = {"error": self.auth_error}
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

        try:
            # Create driver if needed
            if self.driver is None:
                await self.create_driver()

            if self.driver is None:
                safe_error = {"error": "Failed to create driver"}
                return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

            # Access the scheme client
            scheme_client = self.driver.scheme_client

            # Describe the path
            logger.info(f"Describing path: {path}")
            path_response = await scheme_client.describe_path(path)

            # Process the response
            if path_response is None:
                safe_error = {"error": f"Path '{path}' not found"}
                return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

            # Format the result
            result = {
                "path": path,
                "type": str(path_response.type),
                "name": path_response.name,
                "owner": path_response.owner,
            }

            # Add permissions if available
            if hasattr(path_response, "permissions") and path_response.permissions:
                result["permissions"] = []
                for perm in path_response.permissions:
                    result["permissions"].append(
                        {"subject": perm.subject, "permission_names": list(perm.permission_names)}
                    )

            # Add table specific information if it's a table
            # (2 is the TABLE code in ENTRY_TYPE_MAP).
            if str(path_response.type) == "TABLE" or path_response.type == 2:
                try:
                    # Get table client for more detailed table info
                    table_client = self.driver.table_client
                    session = await table_client.session().create()
                    try:
                        # Get detailed table description
                        table_desc = await session.describe_table(path)
                        result["table"] = {
                            "columns": [],
                            "primary_key": table_desc.primary_key,
                            "indexes": [],
                            "partitioning_settings": {},
                            "storage_settings": {},
                            "key_bloom_filter": table_desc.key_bloom_filter,
                            "read_replicas_settings": table_desc.read_replicas_settings,
                            "column_families": [],
                        }

                        # Add columns with more details
                        for column in table_desc.columns:
                            col_info = {
                                "name": column.name,
                                "type": str(column.type),
                                "family": column.family,
                            }
                            result["table"]["columns"].append(col_info)

                        # Add indexes with more details
                        for index in table_desc.indexes:
                            index_info = {
                                "name": index.name,
                                "index_columns": list(index.index_columns),
                                "cover_columns": (list(index.cover_columns) if hasattr(index, "cover_columns") else []),
                                "index_type": (str(index.index_type) if hasattr(index, "index_type") else None),
                            }
                            result["table"]["indexes"].append(index_info)

                        # Add column families if present
                        if hasattr(table_desc, "column_families"):
                            for family in table_desc.column_families:
                                family_info = {
                                    "name": family.name,
                                    "data": family.data,
                                    "compression": (str(family.compression) if hasattr(family, "compression") else None),
                                }
                                result["table"]["column_families"].append(family_info)

                        # Add storage settings if present
                        if hasattr(table_desc, "storage_settings"):
                            ss = table_desc.storage_settings
                            if ss:
                                result["table"]["storage_settings"] = {
                                    "tablet_commit_log0": ss.tablet_commit_log0,
                                    "tablet_commit_log1": ss.tablet_commit_log1,
                                    "external": ss.external,
                                    "store_external": ss.store_external,
                                }

                        # Add partitioning settings if present
                        if hasattr(table_desc, "partitioning_settings"):
                            ps = table_desc.partitioning_settings
                            if ps:
                                if hasattr(ps, "partition_at_keys"):
                                    result["table"]["partitioning_settings"]["partition_at_keys"] = ps.partition_at_keys
                                if hasattr(ps, "partition_by_size"):
                                    result["table"]["partitioning_settings"]["partition_by_size"] = ps.partition_by_size
                                if hasattr(ps, "min_partitions_count"):
                                    result["table"]["partitioning_settings"]["min_partitions_count"] = (
                                        ps.min_partitions_count
                                    )
                                if hasattr(ps, "max_partitions_count"):
                                    result["table"]["partitioning_settings"]["max_partitions_count"] = (
                                        ps.max_partitions_count
                                    )

                    finally:
                        # Always release the session
                        await session.close()

                except Exception as table_error:
                    logger.warning(f"Error getting detailed table info: {table_error}")
                    # Fallback to basic table info from path_response
                    if hasattr(path_response, "table") and path_response.table:
                        result["table"] = {
                            "columns": [],
                            "primary_key": (
                                path_response.table.primary_key if hasattr(path_response.table, "primary_key") else []
                            ),
                            "indexes": [],
                            "partitioning_settings": {},
                        }

                        # Add basic columns
                        if hasattr(path_response.table, "columns"):
                            for column in path_response.table.columns:
                                result["table"]["columns"].append({"name": column.name, "type": str(column.type)})

                        # Add basic indexes
                        if hasattr(path_response.table, "indexes"):
                            for index in path_response.table.indexes:
                                result["table"]["indexes"].append(
                                    {
                                        "name": index.name,
                                        "index_columns": (
                                            list(index.index_columns) if hasattr(index, "index_columns") else []
                                        ),
                                    }
                                )

                        # Add basic partitioning settings
                        if hasattr(path_response.table, "partitioning_settings"):
                            ps = path_response.table.partitioning_settings
                            if ps:
                                if hasattr(ps, "partition_at_keys"):
                                    result["table"]["partitioning_settings"]["partition_at_keys"] = ps.partition_at_keys
                                if hasattr(ps, "partition_by_size"):
                                    result["table"]["partitioning_settings"]["partition_by_size"] = ps.partition_by_size
                                if hasattr(ps, "min_partitions_count"):
                                    result["table"]["partitioning_settings"]["min_partitions_count"] = (
                                        ps.min_partitions_count
                                    )
                                if hasattr(ps, "max_partitions_count"):
                                    result["table"]["partitioning_settings"]["max_partitions_count"] = (
                                        ps.max_partitions_count
                                    )

            # Convert to JSON string and return as TextContent
            formatted_result = json.dumps(result, indent=2, cls=CustomJSONEncoder)
            return [TextContent(type="text", text=formatted_result)]

        except Exception as e:
            logger.exception(f"Error describing path {path}: {e}")
            safe_error = {"error": f"Error describing path {path}: {str(e)}"}
            return [TextContent(type="text", text=json.dumps(safe_error, indent=2))]

    async def restart(self):
        """Restart the YDB connection by closing and recreating the driver."""
        logger.info("Restarting YDB connection")

        # Close session pool first
        if self.pool is not None:
            logger.info("Closing YDB session pool")
            try:
                await asyncio.shield(self.pool.stop())
            except Exception as e:
                logger.warning(f"Error closing session pool: {e}")
            self.pool = None

        # Stop the driver
        if self.driver is not None:
            logger.info("Stopping YDB driver")
            try:
                # Cancel any pending discovery tasks first
                if hasattr(self.driver, "discovery") and self.driver.discovery is not None:
                    try:
                        # Stop discovery process
                        if hasattr(self.driver.discovery, "stop"):
                            self.driver.discovery.stop()

                        # Cancel discovery task if it exists
                        if hasattr(self.driver.discovery, "_discovery_task"):
                            task = self.driver.discovery._discovery_task
                            if
task and not task.done() and not task.cancelled(): 851 | task.cancel() 852 | try: 853 | await asyncio.shield(asyncio.wait_for(task, timeout=1)) 854 | except (asyncio.CancelledError, asyncio.TimeoutError): 855 | pass 856 | 857 | except Exception as e: 858 | logger.warning(f"Error handling discovery task: {e}") 859 | 860 | # Stop the driver with proper error handling 861 | try: 862 | # Use shield to prevent cancellation of the stop operation 863 | await asyncio.shield(asyncio.wait_for(self.driver.stop(), timeout=5)) 864 | except asyncio.TimeoutError: 865 | logger.warning("Driver stop timed out") 866 | except asyncio.CancelledError: 867 | logger.warning("Driver stop was cancelled") 868 | except Exception as e: 869 | logger.warning(f"Error stopping driver: {e}") 870 | 871 | except Exception as e: 872 | logger.warning(f"Error during driver cleanup: {e}") 873 | finally: 874 | self.driver = None 875 | 876 | # Create new driver 877 | logger.info("Creating new YDB driver") 878 | try: 879 | new_driver = await self.create_driver() 880 | if new_driver is None: 881 | logger.error("Failed to create new driver during restart") 882 | return False 883 | return True 884 | except Exception as e: 885 | logger.error(f"Failed to create new driver during restart: {e}") 886 | return False 887 | 888 | def _text_content_to_dict(self, text_content_list): 889 | """Convert TextContent objects to serializable dictionaries. 890 | 891 | Args: 892 | text_content_list: List of TextContent objects 893 | 894 | Returns: 895 | List of dictionaries 896 | """ 897 | result = [] 898 | for item in text_content_list: 899 | if isinstance(item, TextContent): 900 | result.append({"type": item.type, "text": item.text}) 901 | else: 902 | result.append(item) 903 | return result 904 | 905 | async def call_tool(self, tool_name: str, params: Dict[str, Any]) -> List[TextContent]: 906 | """Call a registered tool. 
907 | 908 | Args: 909 | tool_name: Name of the tool to call 910 | params: Parameters to pass to the tool 911 | 912 | Returns: 913 | List of TextContent objects or serializable dicts 914 | 915 | Raises: 916 | ValueError: If the tool is not found 917 | """ 918 | tool = self.tool_manager.get(tool_name) 919 | if not tool: 920 | raise ValueError(f"Tool not found: {tool_name}") 921 | 922 | logger.info(f"Calling tool: {tool_name} with params: {params}") 923 | try: 924 | result = None 925 | 926 | # Special handling for YDB tools to directly call methods with correct parameters 927 | if tool_name == "ydb_query" and "sql" in params: 928 | result = await self.query(sql=params["sql"]) 929 | elif tool_name == "ydb_query_with_params" and "sql" in params and "params" in params: 930 | result = await self.query_with_params(sql=params["sql"], params=params["params"]) 931 | elif tool_name == "ydb_status": 932 | result = await self.get_connection_status() 933 | elif tool_name == "ydb_list_directory" and "path" in params: 934 | result = await self.list_directory(path=params["path"]) 935 | elif tool_name == "ydb_describe_path" and "path" in params: 936 | result = await self.describe_path(path=params["path"]) 937 | else: 938 | # For other tools, use the standard handler 939 | result = await tool.handler(**params) 940 | 941 | # Convert TextContent objects to dictionaries if needed 942 | if isinstance(result, list) and any(isinstance(item, TextContent) for item in result): 943 | serializable_result = self._text_content_to_dict(result) 944 | return serializable_result # type: ignore 945 | 946 | # Handle any other result type 947 | if result is None: 948 | return [TextContent(type="text", text="Operation completed successfully but returned no data")] 949 | 950 | return result 951 | 952 | except Exception as e: 953 | logger.exception(f"Error calling tool {tool_name}: {e}") 954 | error_msg = f"Error executing {tool_name}: {str(e)}" 955 | return [TextContent(type="text", text=error_msg)] 956 | 
    def get_tool_schema(self) -> List[Dict[str, Any]]:
        """Get JSON schema for all registered tools.

        Returns:
            List of tool schema definitions
        """
        # Delegates to the ToolManager, which emits one entry per registered
        # tool (name, description, and parameters when present).
        return self.tool_manager.get_schema()

    def run(self):
        """Run the YDB MCP server using the FastMCP server implementation."""
        # Startup banner goes to stdout as well as the logger so it is visible
        # even when logging is not configured.
        print("Starting YDB MCP server")
        print(f"YDB endpoint: {self.endpoint or 'Not set'}")
        print(f"YDB database: {self.database or 'Not set'}")
        logger.info("Starting YDB MCP server")

        # Use FastMCP's built-in run method with stdio transport
        super().run(transport="stdio")

    def get_credentials_factory(self) -> Optional[Callable[[], ydb.Credentials]]:
        """Get YDB credentials factory based on authentication mode.

        Returns:
            Callable that creates YDB credentials, or None if authentication fails.
            On failure the reason is recorded in ``self.auth_error``.
        """
        # Clear any previous auth errors
        self.auth_error = None

        supported_auth_modes = {AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD}
        if self.auth_mode not in supported_auth_modes:
            self.auth_error = (
                f"Unsupported auth mode: {self.auth_mode}. Supported modes: {', '.join(supported_auth_modes)}"
            )
            return None

        # If auth_mode is login_password and we have both login and password, use them
        if self.auth_mode == AUTH_MODE_LOGIN_PASSWORD:
            if not self.login or not self.password:
                # Both credentials are required; record why we failed.
                self.auth_error = "Login and password must be provided for login-password authentication mode."
                return None
            logger.info(f"Using login/password authentication with user '{self.login}'")
            return self._login_password_credentials
        else:
            # Default to anonymous auth
            logger.info("Using anonymous authentication")
            return self._anonymous_credentials
--------------------------------------------------------------------------------
/ydb_mcp/tool_manager.py:
--------------------------------------------------------------------------------
from typing import Any, Callable, Dict, List, Optional


class ToolDefinition:
    """Defines a tool that can be called by the MCP."""

    def __init__(self, name: str, handler: Callable, description: str = "", parameters: Optional[Dict] = None):
        """Initialize a tool definition.

        Args:
            name: Name of the tool
            handler: Async callable that handles the tool execution
            description: Tool description
            parameters: JSON schema for the tool parameters; ``None`` (or any
                falsy value) is normalized to an empty dict
        """
        self.name = name
        self.handler = handler
        self.description = description
        self.parameters = parameters or {}


class ToolManager:
    """Manages MCP tools for YDB interactions."""

    def __init__(self):
        """Initialize the tool manager."""
        # Registry mapping tool name -> ToolDefinition.
        self._tools: Dict[str, ToolDefinition] = {}

    def register_tool(
        self, name: str, handler: Callable, description: str = "", parameters: Optional[Dict] = None
    ) -> None:
        """Register a tool with the manager.

        Registering a name that already exists silently replaces the
        previous definition.

        Args:
            name: Name of the tool
            handler: Async callable that handles the tool execution
            description: Tool description
            parameters: JSON schema for tool parameters
        """
        self._tools[name] = ToolDefinition(name=name, handler=handler, description=description, parameters=parameters)

    def get(self, name: str) -> Optional[ToolDefinition]:
        """Get a tool by name.

        Args:
            name: Name of the tool to retrieve

        Returns:
            Tool definition if found, None otherwise
        """
        return self._tools.get(name)

    def get_all_tools(self) -> Dict[str, ToolDefinition]:
        """Get all registered tools.

        Returns:
            Dictionary of tool name to tool definition (the live internal
            registry, not a copy)
        """
        return self._tools

    def get_schema(self) -> List[Dict[str, Any]]:
        """Get JSON schema for all registered tools.

        Returns:
            List of tool schema definitions
        """
        result = []
        for name, tool in self._tools.items():
            tool_schema: Dict[str, Any] = {
                "name": name,
                "description": tool.description,
            }

            # "parameters" is omitted entirely for tools registered
            # without a (non-empty) parameter schema.
            if tool.parameters:
                tool_schema["parameters"] = tool.parameters

            result.append(tool_schema)

        return result
--------------------------------------------------------------------------------
/ydb_mcp/version.py:
--------------------------------------------------------------------------------
# Package version; .github/scripts/increment_version.py targets this file.
VERSION = "0.1.1"
--------------------------------------------------------------------------------