├── .editorconfig ├── .github └── workflows │ ├── publish.yml │ └── tests.yml ├── .gitignore ├── .pylintrc ├── LICENSE ├── Makefile ├── README.md ├── examples ├── auth_clear_password.py ├── auth_native_password.py ├── dbapi_proxy.py ├── simple.py └── streaming_results.py ├── integration ├── mysql-connector-j │ ├── Makefile │ ├── pom.xml │ └── src │ │ ├── main │ │ └── java │ │ │ └── .gitkeep │ │ └── test │ │ └── java │ │ └── com │ │ └── mysql_mimic │ │ └── integration │ │ └── IntegrationTest.java └── run.py ├── mysql_mimic ├── __init__.py ├── auth.py ├── charset.py ├── connection.py ├── constants.py ├── context.py ├── control.py ├── errors.py ├── functions.py ├── intercept.py ├── packets.py ├── prepared.py ├── results.py ├── schema.py ├── server.py ├── session.py ├── stream.py ├── types.py ├── utils.py ├── variable_processor.py ├── variables.py └── version.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── fixtures ├── __init__.py ├── certificate.pem ├── key.pem └── queries.py ├── test_auth.py ├── test_control.py ├── test_query.py ├── test_result.py ├── test_schema.py ├── test_ssl.py ├── test_stream.py ├── test_types.py ├── test_variables.py └── test_version.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [Makefile] 12 | indent_style = tab 13 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 3.9 14 | uses: actions/setup-python@v5 15 | with: 16 | 
python-version: '3.9' 17 | cache: pip 18 | cache-dependency-path: setup.py 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | make deps 23 | - name: Build 24 | run: | 25 | make build 26 | - name: Publish 27 | env: 28 | TWINE_USERNAME: __token__ 29 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 30 | run: | 31 | make publish 32 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | strategy: 12 | matrix: 13 | version: [ 14 | { python: "3.7", ubuntu: "ubuntu-22.04" }, 15 | { python: "3.8", ubuntu: "ubuntu-22.04" }, 16 | { python: "3.9", ubuntu: "ubuntu-latest" }, 17 | { python: "3.10", ubuntu: "ubuntu-latest" }, 18 | { python: "3.11", ubuntu: "ubuntu-latest" }, 19 | { python: "3.12", ubuntu: "ubuntu-latest" } ] 20 | runs-on: ${{ matrix.version.ubuntu }} 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Kerberos 24 | run: sudo apt-get update && sudo apt-get install -y libkrb5-dev krb5-kdc krb5-admin-server 25 | - name: Set up Python ${{ matrix.version.python }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.version.python }} 29 | cache: pip 30 | cache-dependency-path: setup.py 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | make deps 35 | - name: Test 36 | run: | 37 | make test 38 | - name: Lint 39 | if: matrix.version.python != '3.7' 40 | run: | 41 | make lint 42 | - name: Format 43 | if: matrix.version.python != '3.7' 44 | run: | 45 | make format-check 46 | - name: Type annotations 47 | if: matrix.version.python != '3.7' 48 | run: | 49 | make types 50 | mysql-connector-j: 51 | name: Integration (mysql-connector-j) 52 | runs-on: ubuntu-latest 53 | steps: 54 | - uses: actions/checkout@v4 55 | - 
name: Set up Kerberos 56 | run: sudo apt-get update && sudo apt-get install -y libkrb5-dev krb5-kdc krb5-admin-server 57 | - name: Set up Python 58 | uses: actions/setup-python@v5 59 | with: 60 | python-version: '3.10' 61 | cache: pip 62 | cache-dependency-path: setup.py 63 | - name: Install Python dependencies 64 | run: | 65 | python -m pip install --upgrade pip 66 | make deps 67 | - name: Set up Java 68 | uses: actions/setup-java@v4 69 | with: 70 | distribution: 'temurin' 71 | java-version: '17' 72 | cache: 'maven' 73 | - name: Test mysql-connector-j 74 | run: | 75 | python integration/run.py integration/mysql-connector-j/ 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | venv*/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # PyCharm 133 | .idea/ 134 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | 3 | disable= 4 | import-outside-toplevel, 5 | invalid-name, 6 | line-too-long, 7 | missing-class-docstring, 8 | missing-function-docstring, 9 | missing-module-docstring, 10 | too-few-public-methods, 11 | too-many-instance-attributes, 12 | too-many-public-methods, 13 | too-many-branches, 14 | too-many-return-statements, 15 | unused-argument, 16 | redefined-outer-name, 17 | too-many-statements, 18 | multiple-statements, 19 | 20 | [FORMAT] 21 | 22 | max-args=10 23 | max-line-length=130 24 | max-positional-arguments=15 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Christopher Giroir 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software 
is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := deps 2 | 3 | .PHONY: deps format format-check run lint types test check build publish clean 4 | 5 | deps: 6 | pip install --progress-bar off -e .[dev] 7 | 8 | format: 9 | python -m black . 10 | 11 | format-check: 12 | python -m black --check . 13 | 14 | run: 15 | python -m mysql_mimic.server 16 | 17 | lint: 18 | python -m pylint mysql_mimic/ tests/ 19 | 20 | types: 21 | python -m mypy -p mysql_mimic -p tests 22 | 23 | test: 24 | coverage run --source=mysql_mimic -m pytest 25 | coverage report 26 | coverage html 27 | 28 | check: format-check lint types test 29 | 30 | build: clean 31 | python setup.py sdist bdist_wheel 32 | 33 | publish: build 34 | twine upload dist/* 35 | 36 | clean: 37 | rm -rf build dist 38 | find .
-type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MySQL-Mimic 2 | 3 | [![Tests](https://github.com/kelsin/mysql-mimic/actions/workflows/tests.yml/badge.svg)](https://github.com/kelsin/mysql-mimic/actions/workflows/tests.yml) 4 | 5 | Pure-Python implementation of the MySQL server [wire protocol](https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html). 6 | 7 | This can be used to create applications that act as a MySQL server. 8 | 9 | ## Installation 10 | 11 | ```shell 12 | pip install mysql-mimic 13 | ``` 14 | 15 | ## Usage 16 | 17 | A minimal use case might look like this: 18 | 19 | ```python 20 | import asyncio 21 | 22 | from mysql_mimic import MysqlServer, Session 23 | 24 | 25 | class MySession(Session): 26 | async def query(self, expression, sql, attrs): 27 | print(f"Parsed abstract syntax tree: {expression}") 28 | print(f"Original SQL string: {sql}") 29 | print(f"Query attributes: {attrs}") 30 | print(f"Currently authenticated user: {self.username}") 31 | print(f"Currently selected database: {self.database}") 32 | return [("a", 1), ("b", 2)], ["col1", "col2"] 33 | 34 | async def schema(self): 35 | # Optionally provide the database schema. 36 | # This is used to serve INFORMATION_SCHEMA and SHOW queries. 37 | return { 38 | "table": { 39 | "col1": "TEXT", 40 | "col2": "INT", 41 | } 42 | } 43 | 44 | if __name__ == "__main__": 45 | server = MysqlServer(session_factory=MySession) 46 | asyncio.run(server.serve_forever()) 47 | ``` 48 | 49 | Using [sqlglot](https://github.com/tobymao/sqlglot), the abstract `Session` class handles queries to metadata, variables, etc. that many MySQL clients expect. 50 | 51 | To bypass this default behavior, you can implement the [`mysql_mimic.session.BaseSession`](mysql_mimic/session.py) interface.
52 | 53 | See [examples](./examples) for more examples. 54 | 55 | ## Authentication 56 | 57 | MySQL-mimic has built-in support for several standard MySQL authentication plugins: 58 | - [mysql_native_password](https://dev.mysql.com/doc/refman/8.0/en/native-pluggable-authentication.html) 59 | - The client sends hashed passwords to the server, and the server stores hashed passwords. See the documentation for more details on how this works. 60 | - [example](examples/auth_native_password.py) 61 | - [mysql_clear_password](https://dev.mysql.com/doc/refman/8.0/en/cleartext-pluggable-authentication.html) 62 | - The client sends passwords to the server as clear text, without hashing or encryption, so it should only be used over encrypted (SSL/TLS) connections. 63 | - This is typically used as the client plugin for a custom server plugin. As such, MySQL-mimic provides an abstract class, [`mysql_mimic.auth.AbstractClearPasswordAuthPlugin`](mysql_mimic/auth.py), which can be extended. 64 | - [example](examples/auth_clear_password.py) 65 | - [mysql_no_login](https://dev.mysql.com/doc/refman/8.0/en/no-login-pluggable-authentication.html) 66 | - The server prevents clients from directly authenticating as an account. See the documentation for relevant use cases. 67 | - [authentication_kerberos](https://dev.mysql.com/doc/mysql-security-excerpt/8.0/en/kerberos-pluggable-authentication.html) 68 | - Kerberos uses tickets together with symmetric-key cryptography, enabling authentication without sending passwords over the network. Kerberos authentication supports userless and passwordless scenarios. 69 | 70 | 71 | By default, a session naively accepts whatever username the client provides. 72 | 73 | Plugins are provided to the server by implementing [`mysql_mimic.IdentityProvider`](mysql_mimic/auth.py), which configures all available plugins and a callback for fetching users. 74 | 75 | Custom plugins can be created by extending [`mysql_mimic.auth.AuthPlugin`](mysql_mimic/auth.py).
76 | 77 | ## Development 78 | 79 | You can install dependencies with `make deps`. 80 | 81 | You can format your code with `make format`. 82 | 83 | You can lint with `make lint`. 84 | 85 | You can check type annotations with `make types`. 86 | 87 | You can run tests with `make test`. This will build a coverage report in `./htmlcov/index.html`. 88 | 89 | You can run all the checks with `make check`. 90 | 91 | You can build a pip package with `make build`. 92 | -------------------------------------------------------------------------------- /examples/auth_clear_password.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import asyncio 3 | 4 | from mysql_mimic import ( 5 | MysqlServer, 6 | IdentityProvider, 7 | User, 8 | ) 9 | from mysql_mimic.auth import AbstractClearPasswordAuthPlugin 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | # Storing plain text passwords is not safe. 15 | # This is just done for demonstration purposes. 16 | USERS = { 17 | "user1": "password1", 18 | "user2": "password2", 19 | } 20 | 21 | 22 | class CustomAuthPlugin(AbstractClearPasswordAuthPlugin): 23 | name = "custom_plugin" 24 | 25 | async def check(self, username, password): 26 | return username if USERS.get(username) == password else None 27 | 28 | 29 | class CustomIdentityProvider(IdentityProvider): 30 | def get_plugins(self): 31 | return [CustomAuthPlugin()] 32 | 33 | async def get_user(self, username): 34 | # Because we're storing users/passwords in an external system (the USERS dictionary, in this case), 35 | # we just assume all users exist. 
36 | return User(name=username, auth_plugin=CustomAuthPlugin.name) 37 | 38 | 39 | async def main(): 40 | logging.basicConfig(level=logging.DEBUG) 41 | identity_provider = CustomIdentityProvider() 42 | server = MysqlServer(identity_provider=identity_provider) 43 | await server.serve_forever() 44 | 45 | 46 | if __name__ == "__main__": 47 | asyncio.run(main()) 48 | # To connect with the MySQL command line interface: 49 | # mysql -h127.0.0.1 -P3306 -uuser1 -ppassword1 --enable-cleartext-plugin 50 | -------------------------------------------------------------------------------- /examples/auth_native_password.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import asyncio 3 | 4 | from mysql_mimic import ( 5 | MysqlServer, 6 | IdentityProvider, 7 | NativePasswordAuthPlugin, 8 | User, 9 | ) 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class CustomIdentityProvider(IdentityProvider): 15 | def __init__(self, passwords): 16 | # Storing passwords in plain text isn't safe. 17 | # This is done for demonstration purposes. 
18 | # It's better to store the password hash, as returned by `NativePasswordAuthPlugin.create_auth_string` 19 | self.passwords = passwords 20 | 21 | def get_plugins(self): 22 | return [NativePasswordAuthPlugin()] 23 | 24 | async def get_user(self, username): 25 | password = self.passwords.get(username) 26 | if password: 27 | return User( 28 | name=username, 29 | auth_string=NativePasswordAuthPlugin.create_auth_string(password), 30 | auth_plugin=NativePasswordAuthPlugin.name, 31 | ) 32 | return None 33 | 34 | 35 | async def main(): 36 | logging.basicConfig(level=logging.DEBUG) 37 | identity_provider = CustomIdentityProvider(passwords={"user": "password"}) 38 | server = MysqlServer(identity_provider=identity_provider) 39 | await server.serve_forever() 40 | 41 | 42 | if __name__ == "__main__": 43 | asyncio.run(main()) 44 | -------------------------------------------------------------------------------- /examples/dbapi_proxy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import asyncio 3 | import sqlite3 4 | 5 | from mysql_mimic import MysqlServer, Session 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class DbapiProxySession(Session): 11 | def __init__(self): 12 | super().__init__() 13 | self.conn = sqlite3.connect(":memory:") 14 | 15 | async def query(self, expression, sql, attrs): 16 | cursor = self.conn.cursor() 17 | cursor.execute(expression.sql(dialect="sqlite")) 18 | try: 19 | rows = cursor.fetchall() 20 | if cursor.description: 21 | columns = [c[0] for c in cursor.description] 22 | return rows, columns 23 | return None 24 | finally: 25 | cursor.close() 26 | 27 | 28 | async def main(): 29 | logging.basicConfig(level=logging.DEBUG) 30 | server = MysqlServer(session_factory=DbapiProxySession) 31 | await server.serve_forever() 32 | 33 | 34 | if __name__ == "__main__": 35 | asyncio.run(main()) 36 | -------------------------------------------------------------------------------- 
/examples/simple.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import asyncio 3 | from sqlglot.executor import execute 4 | 5 | from mysql_mimic import MysqlServer, Session 6 | 7 | 8 | SCHEMA = { 9 | "test": { 10 | "x": { 11 | "a": "INT", 12 | } 13 | } 14 | } 15 | 16 | TABLES = { 17 | "test": { 18 | "x": [ 19 | {"a": 1}, 20 | {"a": 2}, 21 | {"a": 3}, 22 | ] 23 | } 24 | } 25 | 26 | 27 | class MySession(Session): 28 | async def query(self, expression, sql, attrs): 29 | result = execute(expression, schema=SCHEMA, tables=TABLES) 30 | return result.rows, result.columns 31 | 32 | async def schema(self): 33 | return SCHEMA 34 | 35 | 36 | async def main(): 37 | logging.basicConfig(level=logging.DEBUG) 38 | server = MysqlServer(session_factory=MySession) 39 | await server.serve_forever() 40 | 41 | 42 | if __name__ == "__main__": 43 | asyncio.run(main()) 44 | -------------------------------------------------------------------------------- /examples/streaming_results.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import asyncio 3 | 4 | from mysql_mimic import MysqlServer, Session 5 | 6 | 7 | class MySession(Session): 8 | async def generate_rows(self, n): 9 | for i in range(n): 10 | if i % 100 == 0: 11 | logging.info("Pretending to fetch another batch of results...") 12 | await asyncio.sleep(1) 13 | yield i, 14 | 15 | async def query(self, expression, sql, attrs): 16 | return self.generate_rows(1000), ["a"] 17 | 18 | 19 | async def main(): 20 | logging.basicConfig(level=logging.DEBUG) 21 | server = MysqlServer(session_factory=MySession) 22 | await server.serve_forever() 23 | 24 | 25 | if __name__ == "__main__": 26 | asyncio.run(main()) 27 | -------------------------------------------------------------------------------- /integration/mysql-connector-j/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test test-debug
2 | 3 | test: 4 | mvn test 5 | 6 | test-debug: 7 | mvn -X test 8 | -------------------------------------------------------------------------------- /integration/mysql-connector-j/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | com.mysql_mimic.integration 8 | integration 9 | 1.0-SNAPSHOT 10 | 11 | integration 12 | 13 | http://www.example.com 14 | 15 | 16 | UTF-8 17 | 1.7 18 | 1.7 19 | 20 | 21 | 22 | 23 | junit 24 | junit 25 | 4.11 26 | test 27 | 28 | 29 | com.mysql 30 | mysql-connector-j 31 | 8.0.31 32 | 33 | 34 | com.google.protobuf 35 | protobuf-java 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | maven-clean-plugin 47 | 3.1.0 48 | 49 | 50 | 51 | maven-resources-plugin 52 | 3.0.2 53 | 54 | 55 | maven-compiler-plugin 56 | 3.8.0 57 | 58 | 59 | maven-surefire-plugin 60 | 2.22.1 61 | 62 | 63 | maven-jar-plugin 64 | 3.0.2 65 | 66 | 67 | maven-install-plugin 68 | 2.5.2 69 | 70 | 71 | maven-deploy-plugin 72 | 2.8.2 73 | 74 | 75 | 76 | maven-site-plugin 77 | 3.7.1 78 | 79 | 80 | maven-project-info-reports-plugin 81 | 3.0.0 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /integration/mysql-connector-j/src/main/java/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barakalon/mysql-mimic/570cd6cb862a52fabbbdf15d801f4855d530c11d/integration/mysql-connector-j/src/main/java/.gitkeep -------------------------------------------------------------------------------- /integration/mysql-connector-j/src/test/java/com/mysql_mimic/integration/IntegrationTest.java: -------------------------------------------------------------------------------- 1 | package com.mysql_mimic.integration; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | 5 | import org.junit.Test; 6 | import java.sql.Connection; 7 | import java.sql.DriverManager; 8 | import java.sql.Statement; 9 | import 
java.sql.SQLException; 10 | import java.sql.ResultSet; 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | 14 | public class IntegrationTest 15 | { 16 | @Test 17 | public void test() throws Exception { 18 | Class.forName("com.mysql.cj.jdbc.Driver"); 19 | 20 | String port = System.getenv("PORT"); 21 | String url = String.format("jdbc:mysql://user:password@127.0.0.1:%s", port); 22 | 23 | List result = new ArrayList<>(); 24 | // try-with-resources releases the result set, statement and connection even when a call throws. 25 | try (Connection conn = DriverManager.getConnection(url); 26 | Statement stmt = conn.createStatement(); 27 | ResultSet rs = stmt.executeQuery("SELECT a FROM x ORDER BY a")) { 28 | while (rs.next()) { 29 | result.add(rs.getInt("a")); 30 | } 31 | } 32 | 33 | List expected = new ArrayList<>(); 34 | expected.add(1); 35 | expected.add(2); 36 | expected.add(3); 37 | 38 | assertEquals(expected, result); 39 | } 40 | 41 | @Test 42 | public void testKrb5() throws Exception { 43 | // System.setProperty("sun.security.krb5.debug", "True"); 44 | // System.setProperty("sun.security.jgss.debug", "True"); 45 | System.setProperty("java.security.krb5.conf", System.getenv("KRB5_CONFIG")); 46 | System.setProperty("java.security.auth.login.config", System.getenv("JAAS_CONFIG")); 47 | 48 | Class.forName("com.mysql.cj.jdbc.Driver"); 49 | 50 | String port = System.getenv("PORT"); 51 | String url = String.format("jdbc:mysql://127.0.0.1:%s/?defaultAuthenticationPlugin=authentication_kerberos_client", port); 52 | 53 | try (Connection conn = DriverManager.getConnection(url); 54 | Statement stmt = conn.createStatement(); 55 | ResultSet rs = stmt.executeQuery("SELECT CURRENT_USER AS CURRENT_USER")) { 56 | // Fail with a clear message if the query returns no row. 57 | org.junit.Assert.assertTrue("expected a row from SELECT CURRENT_USER", rs.next()); 58 | assertEquals("krb5_user", rs.getString("CURRENT_USER")); 59 | } 60 | } 61 | 62 | @Test(expected = SQLException.class) 63 | public void testKrb5IncorrectUser() throws Exception { 64 | System.setProperty("java.security.krb5.conf", System.getenv("KRB5_CONFIG")); 65 | System.setProperty("java.security.auth.login.config", 
System.getenv("JAAS_CONFIG")); 66 | 67 | Class.forName("com.mysql.cj.jdbc.Driver"); 68 | 69 | String user = "incorrect_user"; 70 | String port = System.getenv("PORT"); 71 | String url = String.format("jdbc:mysql://%s@127.0.0.1:%s/?defaultAuthenticationPlugin=authentication_kerberos_client", user, port); 72 | 73 | // The Kerberos ticket belongs to krb5_user, so authenticating as another user is expected to raise SQLException. 74 | try (Connection conn = DriverManager.getConnection(url); 75 | Statement stmt = conn.createStatement()) { 76 | stmt.executeQuery("SELECT a FROM x ORDER BY a"); 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /integration/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import asyncio 4 | import os 5 | import sys 6 | import tempfile 7 | import time 8 | 9 | import k5test 10 | from sqlglot.executor import execute 11 | 12 | from mysql_mimic import ( 13 | MysqlServer, 14 | IdentityProvider, 15 | NativePasswordAuthPlugin, 16 | User, 17 | Session, 18 | ) 19 | from mysql_mimic.auth import KerberosAuthPlugin 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | SCHEMA = { 24 | "test": { 25 | "x": { 26 | "a": "INT", 27 | } 28 | } 29 | } 30 | 31 | TABLES = { 32 | "test": { 33 | "x": [ 34 | {"a": 1}, 35 | {"a": 2}, 36 | {"a": 3}, 37 | ] 38 | } 39 | } 40 | 41 | 42 | class MySession(Session): 43 | async def query(self, expression, sql, attrs): 44 | result = execute(expression, schema=SCHEMA, tables=TABLES) 45 | return result.rows, result.columns 46 | 47 | async def schema(self): 48 | return SCHEMA 49 | 50 | 51 | class CustomIdentityProvider(IdentityProvider): 52 | def __init__(self, krb5_service, krb5_realm): 53 | self.users = { 54 | "user": {"auth_plugin": "mysql_native_password", "password": "password"}, 55 | "krb5_user": {"auth_plugin": "authentication_kerberos"}, 56 | } 57 | self.krb5_service = krb5_service 58 | self.krb5_realm = krb5_realm 59 | 60 | def 
get_plugins(self): 61 | return [ 62 | NativePasswordAuthPlugin(),
63 | KerberosAuthPlugin(service=self.krb5_service, realm=self.krb5_realm), 64 | ] 65 | 66 | async def get_user(self, username): 67 | user = self.users.get(username) 68 | if user is not None: 69 | auth_plugin = user["auth_plugin"] 70 | 71 | if auth_plugin == "mysql_native_password": 72 | password = user.get("password") 73 | return User( 74 | name=username, 75 | auth_string=( 76 | NativePasswordAuthPlugin.create_auth_string(password) 77 | if password 78 | else None 79 | ), 80 | auth_plugin=NativePasswordAuthPlugin.name, 81 | ) 82 | elif auth_plugin == "authentication_kerberos": 83 | return User(name=username, auth_plugin=KerberosAuthPlugin.name) 84 | return None 85 | 86 | 87 | async def wait_for_port(port, host="localhost", timeout=5.0): 88 | start_time = time.time() 89 | while True: 90 | try: 91 | _ = await asyncio.open_connection(host, port) 92 | break 93 | except OSError: 94 | await asyncio.sleep(0.01) 95 | if time.time() - start_time >= timeout: 96 | raise TimeoutError() 97 | 98 | 99 | def setup_krb5(krb5_user): 100 | realm = k5test.K5Realm() 101 | krb5_user_princ = f"{krb5_user}@{realm.realm}" 102 | realm.addprinc(krb5_user_princ, realm.password(krb5_user)) 103 | realm.kinit(krb5_user_princ, realm.password(krb5_user)) 104 | return realm 105 | 106 | 107 | def write_jaas_conf(realm, debug=False): 108 | conf = f""" 109 | MySQLConnectorJ {{ 110 | com.sun.security.auth.module.Krb5LoginModule 111 | required 112 | debug={str(debug).lower()} 113 | useTicketCache=true 114 | ticketCache="{realm.env["KRB5CCNAME"]}"; 115 | }}; 116 | """ 117 | path = tempfile.mktemp() 118 | with open(path, "w") as f: 119 | f.write(conf) 120 | 121 | return path 122 | 123 | 124 | async def main(): 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument("test_dir") 127 | parser.add_argument("-p", "--port", type=int, default=3308) 128 | args = parser.parse_args() 129 | 130 | logging.basicConfig(level=logging.DEBUG) 131 | realm = None 132 | jaas_conf = None 133 | 134 | try: 135 | realm 
= setup_krb5(krb5_user="krb5_user") 136 | os.environ.update(realm.env) 137 | 138 | jaas_conf = write_jaas_conf(realm) 139 | os.environ["JAAS_CONFIG"] = jaas_conf 140 | 141 | krb5_service = realm.host_princ[: realm.host_princ.index("@")] 142 | identity_provider = CustomIdentityProvider( 143 | krb5_service=krb5_service, krb5_realm=realm.realm 144 | ) 145 | server = MysqlServer( 146 | identity_provider=identity_provider, session_factory=MySession 147 | ) 148 | 149 | await server.start_server(port=args.port) 150 | await wait_for_port(port=args.port) 151 | process = await asyncio.create_subprocess_shell( 152 | "make test", 153 | env={**os.environ, "PORT": str(args.port)}, 154 | cwd=args.test_dir, 155 | ) 156 | return_code = await process.wait() 157 | 158 | server.close() 159 | await server.wait_closed() 160 | return return_code 161 | finally: 162 | if realm: 163 | realm.stop() 164 | if jaas_conf: 165 | os.remove(jaas_conf) 166 | 167 | 168 | if __name__ == "__main__": 169 | sys.exit(asyncio.run(main())) 170 | -------------------------------------------------------------------------------- /mysql_mimic/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of the mysql server wire protocol""" 2 | 3 | from mysql_mimic.auth import ( 4 | User, 5 | IdentityProvider, 6 | NativePasswordAuthPlugin, 7 | NoLoginAuthPlugin, 8 | AuthPlugin, 9 | ) 10 | from mysql_mimic.results import AllowedResult, ResultColumn, ResultSet 11 | from mysql_mimic.session import Session 12 | from mysql_mimic.server import MysqlServer 13 | from mysql_mimic.types import ColumnType 14 | -------------------------------------------------------------------------------- /mysql_mimic/auth.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | from copy import copy 5 | from hashlib import sha1 6 | import logging 7 | from dataclasses import dataclass 8 | from 
typing import Optional, Dict, AsyncGenerator, Union, Tuple, Sequence 9 | 10 | from mysql_mimic.types import read_str_null 11 | from mysql_mimic import utils 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | # Many authentication plugins don't need to send any sort of challenge/nonce. 17 | FILLER = b"0" * 20 + b"\x00" 18 | 19 | 20 | @dataclass 21 | class Forbidden: 22 | msg: Optional[str] = None 23 | 24 | 25 | @dataclass 26 | class Success: 27 | authenticated_as: str 28 | 29 | 30 | @dataclass 31 | class User: 32 | name: str 33 | auth_string: Optional[str] = None 34 | auth_plugin: Optional[str] = None 35 | 36 | # For plugins that support a primary and secondary password 37 | # This is helpful for zero-downtime password rotation 38 | old_auth_string: Optional[str] = None 39 | 40 | 41 | @dataclass 42 | class AuthInfo: 43 | username: str 44 | data: bytes 45 | user: User 46 | connect_attrs: Dict[str, str] 47 | client_plugin_name: Optional[str] 48 | handshake_auth_data: Optional[bytes] 49 | handshake_plugin_name: str 50 | 51 | def copy(self, data: bytes) -> AuthInfo: 52 | new = copy(self) 53 | new.data = data 54 | return new 55 | 56 | 57 | Decision = Union[Success, Forbidden, bytes] 58 | AuthState = AsyncGenerator[Decision, AuthInfo] 59 | 60 | 61 | class AuthPlugin: 62 | """ 63 | Abstract base class for authentication plugins. 64 | """ 65 | 66 | name = "" 67 | client_plugin_name: Optional[str] = None # None means any 68 | 69 | async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState: 70 | """ 71 | Create an async generator that drives the authentication lifecycle. 72 | 73 | This should either yield `bytes`, in which case an AuthMoreData packet is sent to the client, 74 | or a `Success` or `Forbidden` instance, in which case authentication is complete. 75 | 76 | Args: 77 | auth_info: This is None if authentication is starting from the optimistic handshake. 
78 | """ 79 | yield Forbidden() 80 | 81 | async def start( 82 | self, auth_info: Optional[AuthInfo] = None 83 | ) -> Tuple[Decision, AuthState]: 84 | state = self.auth(auth_info) 85 | data = await state.__anext__() 86 | return data, state 87 | 88 | 89 | class AbstractClearPasswordAuthPlugin(AuthPlugin): 90 | """ 91 | Abstract class for implementing the server-side of the standard client plugin "mysql_clear_password". 92 | """ 93 | 94 | name = "abstract_mysql_clear_password" 95 | client_plugin_name = "mysql_clear_password" 96 | 97 | async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState: 98 | if not auth_info: 99 | auth_info = yield FILLER 100 | 101 | r = io.BytesIO(auth_info.data) 102 | password = read_str_null(r).decode() 103 | authenticated_as = await self.check(auth_info.username, password) 104 | if authenticated_as is not None: 105 | yield Success(authenticated_as) 106 | else: 107 | yield Forbidden() 108 | 109 | async def check(self, username: str, password: str) -> Optional[str]: 110 | return username 111 | 112 | 113 | class NativePasswordAuthPlugin(AuthPlugin): 114 | """ 115 | Standard plugin that uses a password hashing method. 116 | 117 | The client hashed the password using a nonce provided by the server, so the 118 | password can't be snooped on the network. 119 | 120 | Furthermore, thanks to some clever hashing techniques, knowing the hash stored in the 121 | user database isn't enough to authenticate as that user. 
122 | """ 123 | 124 | name = "mysql_native_password" 125 | client_plugin_name = "mysql_native_password" 126 | 127 | async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState: 128 | if ( 129 | auth_info 130 | and auth_info.handshake_plugin_name == self.name 131 | and auth_info.handshake_auth_data 132 | ): 133 | # mysql_native_password can reuse the nonce from the initial handshake 134 | nonce = auth_info.handshake_auth_data.rstrip(b"\x00") 135 | else: 136 | nonce = utils.nonce(20) 137 | # Some clients expect a null terminating byte 138 | auth_info = yield nonce + b"\x00" 139 | 140 | user = auth_info.user 141 | if self.password_matches(user=user, scramble=auth_info.data, nonce=nonce): 142 | yield Success(user.name) 143 | else: 144 | yield Forbidden() 145 | 146 | def empty_password_quickpath(self, user: User, scramble: bytes) -> bool: 147 | return not scramble and not user.auth_string 148 | 149 | def password_matches(self, user: User, scramble: bytes, nonce: bytes) -> bool: 150 | return ( 151 | self.empty_password_quickpath(user, scramble) 152 | or self.verify_scramble(user.auth_string, scramble, nonce) 153 | or self.verify_scramble(user.old_auth_string, scramble, nonce) 154 | ) 155 | 156 | def verify_scramble( 157 | self, auth_string: Optional[str], scramble: bytes, nonce: bytes 158 | ) -> bool: 159 | # From docs, 160 | # response.data should be: 161 | # SHA1(password) XOR SHA1("20-bytes random data from server" SHA1(SHA1(password))) 162 | # auth_string should be: 163 | # SHA1(SHA1(password)) 164 | try: 165 | sha1_sha1_password = bytes.fromhex(auth_string or "") 166 | sha1_sha1_with_nonce = sha1(nonce + sha1_sha1_password).digest() 167 | rcvd_sha1_password = utils.xor(scramble, sha1_sha1_with_nonce) 168 | return sha1(rcvd_sha1_password).digest() == sha1_sha1_password 169 | except Exception: # pylint: disable=broad-except 170 | logger.info("Invalid scramble") 171 | return False 172 | 173 | @classmethod 174 | def create_auth_string(cls, password: str) -> 
str: 175 | return sha1(sha1(password.encode("utf-8")).digest()).hexdigest() 176 | 177 | 178 | class KerberosAuthPlugin(AuthPlugin): 179 | """ 180 | This plugin implements the Generic Security Service Application Program Interface (GSS-API) by way of the Kerberos 181 | mechanism as described in RFC1964(https://www.rfc-editor.org/rfc/rfc1964.html). 182 | """ 183 | 184 | name = "authentication_kerberos" 185 | client_plugin_name = "authentication_kerberos_client" 186 | 187 | def __init__(self, service: str, realm: str) -> None: 188 | self.service = service 189 | self.realm = realm 190 | 191 | async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState: 192 | import gssapi 193 | from gssapi.exceptions import GSSError 194 | 195 | # Fast authentication not supported 196 | if not auth_info: 197 | yield b"" 198 | 199 | auth_info = ( 200 | yield len(self.service).to_bytes(2, "little") 201 | + self.service.encode("utf-8") 202 | + len(self.realm).to_bytes(2, "little") 203 | + self.realm.encode("utf-8") 204 | ) 205 | 206 | server_creds = gssapi.Credentials( 207 | usage="accept", name=gssapi.Name(f"{self.service}@{self.realm}") 208 | ) 209 | server_ctx = gssapi.SecurityContext(usage="accept", creds=server_creds) 210 | 211 | try: 212 | server_ctx.step(auth_info.data) 213 | except GSSError as e: 214 | yield Forbidden(str(e)) 215 | 216 | username = str(server_ctx.initiator_name).split("@", 1)[0] 217 | if auth_info.username and auth_info.username != username: 218 | yield Forbidden("Given username different than kerberos client") 219 | yield Success(username) 220 | 221 | 222 | class NoLoginAuthPlugin(AuthPlugin): 223 | """ 224 | Standard plugin that prevents all clients from direct login. 225 | 226 | This is useful for user accounts that can only be accessed by proxy authentication. 
227 | """ 228 | 229 | name = "mysql_no_login" 230 | client_plugin_name = None 231 | 232 | async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState: 233 | if not auth_info: 234 | _ = yield FILLER 235 | yield Forbidden() 236 | 237 | 238 | class IdentityProvider: 239 | """ 240 | Abstract base class for an identity provider. 241 | 242 | An identity provider tells the server which authentication plugins to make 243 | available to clients and how to retrieve users. 244 | """ 245 | 246 | def get_plugins(self) -> Sequence[AuthPlugin]: 247 | return [NativePasswordAuthPlugin(), NoLoginAuthPlugin()] 248 | 249 | async def get_user(self, username: str) -> Optional[User]: 250 | return None 251 | 252 | def get_default_plugin(self) -> AuthPlugin: 253 | return self.get_plugins()[0] 254 | 255 | def get_plugin(self, name: str) -> Optional[AuthPlugin]: 256 | try: 257 | return next(p for p in self.get_plugins() if p.name == name) 258 | except StopIteration: 259 | return None 260 | 261 | 262 | class SimpleIdentityProvider(IdentityProvider): 263 | """ 264 | Simple identity provider implementation that naively accepts whatever username a client provides. 
265 | """ 266 | 267 | async def get_user(self, username: str) -> Optional[User]: 268 | return User(name=username, auth_plugin=NativePasswordAuthPlugin.name) 269 | -------------------------------------------------------------------------------- /mysql_mimic/constants.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import auto, Enum 4 | 5 | from mysql_mimic.types import Capabilities 6 | 7 | DEFAULT_SERVER_CAPABILITIES = ( 8 | Capabilities.CLIENT_PROTOCOL_41 9 | | Capabilities.CLIENT_DEPRECATE_EOF 10 | | Capabilities.CLIENT_CONNECT_WITH_DB 11 | | Capabilities.CLIENT_QUERY_ATTRIBUTES 12 | | Capabilities.CLIENT_CONNECT_ATTRS 13 | | Capabilities.CLIENT_PLUGIN_AUTH 14 | | Capabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 15 | | Capabilities.CLIENT_SECURE_CONNECTION 16 | | Capabilities.CLIENT_LONG_PASSWORD 17 | | Capabilities.CLIENT_ODBC 18 | | Capabilities.CLIENT_INTERACTIVE 19 | | Capabilities.CLIENT_IGNORE_SPACE 20 | ) 21 | 22 | 23 | class KillKind(Enum): 24 | # Terminate the statement the connection is currently executing, but leave the connection itself intact 25 | QUERY = auto() 26 | # Terminate the connection, after terminating any statement the connection is executing 27 | CONNECTION = auto() 28 | 29 | 30 | INFO_SCHEMA = { 31 | "information_schema": { 32 | "character_sets": { 33 | "character_set_name": "TEXT", 34 | "default_collate_name": "TEXT", 35 | "description": "TEXT", 36 | "maxlen": "INT", 37 | }, 38 | "collations": { 39 | "collation_name": "TEXT", 40 | "character_set_name": "TEXT", 41 | "id": "TEXT", 42 | "is_default": "TEXT", 43 | "is_compiled": "TEXT", 44 | "sortlen": "INT", 45 | }, 46 | "columns": { 47 | "table_catalog": "TEXT", 48 | "table_schema": "TEXT", 49 | "table_name": "TEXT", 50 | "column_name": "TEXT", 51 | "ordinal_position": "INT", 52 | "column_default": "TEXT", 53 | "is_nullable": "TEXT", 54 | "data_type": "TEXT", 55 | "character_maximum_length": 
"INT", 56 | "character_octet_length": "INT", 57 | "numeric_precision": "INT", 58 | "numeric_scale": "INT", 59 | "datetime_precision": "INT", 60 | "character_set_name": "TEXT", 61 | "collation_name": "TEXT", 62 | "column_type": "TEXT", 63 | "column_key": "TEXT", 64 | "extra": "TEXT", 65 | "privileges": "TEXT", 66 | "column_comment": "TEXT", 67 | "generation_expression": "TEXT", 68 | "srs_id": "TEXT", 69 | }, 70 | "column_privileges": { 71 | "grantee": "TEXT", 72 | "table_catalog": "TEXT", 73 | "table_schema": "TEXT", 74 | "table_name": "TEXT", 75 | "column_name": "TEXT", 76 | "privilege_type": "TEXT", 77 | "is_grantable": "TEXT", 78 | }, 79 | "events": { 80 | "event_catalog": "TEXT", 81 | "event_schema": "TEXT", 82 | "event_name": "TEXT", 83 | "definer": "TEXT", 84 | "time_zone": "TEXT", 85 | "event_body": "TEXT", 86 | "event_definition": "TEXT", 87 | "event_type": "TEXT", 88 | "execute_at": "TEXT", 89 | "interval_value": "INT", 90 | "interval_field": "TEXT", 91 | "sql_mode": "TEXT", 92 | "starts": "TEXT", 93 | "ends": "TEXT", 94 | "status": "TEXT", 95 | "on_completion": "TEXT", 96 | "created": "TEXT", 97 | "last_altered": "TEXT", 98 | "last_executed": "TEXT", 99 | "event_comment": "TEXT", 100 | "originator": "INT", 101 | "character_set_client": "TEXT", 102 | "collation_connection": "TEXT", 103 | "database_collation": "TEXT", 104 | }, 105 | "key_column_usage": { 106 | "constraint_catalog": "TEXT", 107 | "constraint_schema": "TEXT", 108 | "constraint_name": "TEXT", 109 | "table_catalog": "TEXT", 110 | "table_schema": "TEXT", 111 | "table_name": "TEXT", 112 | "column_name": "TEXT", 113 | "ordinal_position": "INT", 114 | "position_in_unique_constraint": "INT", 115 | "referenced_table_schema": "TEXT", 116 | "referenced_table_name": "TEXT", 117 | "referenced_column_name": "TEXT", 118 | }, 119 | "parameters": { 120 | "specific_catalog": "TEXT", 121 | "specific_schema": "TEXT", 122 | "specific_name": "TEXT", 123 | "ordinal_position": "INT", 124 | "parameter_mode": "TEXT", 
125 | "parameter_name": "TEXT", 126 | "data_type": "TEXT", 127 | "character_maximum_length": "INT", 128 | "character_octet_length": "INT", 129 | "numeric_precision": "INT", 130 | "numeric_scale": "INT", 131 | "datetime_precision": "INT", 132 | "character_set_name": "TEXT", 133 | "collation_name": "TEXT", 134 | "dtd_identifier": "TEXT", 135 | "routine_type": "TEXT", 136 | }, 137 | "partitions": { 138 | "table_catalog": "TEXT", 139 | "table_schema": "TEXT", 140 | "table_name": "TEXT", 141 | "partition_name": "TEXT", 142 | "subpartition_name": "TEXT", 143 | "partition_ordinal_position": "INT", 144 | "subpartition_ordinal_position": "INT", 145 | "partition_method": "TEXT", 146 | "subpartition_method": "TEXT", 147 | "partition_expression": "TEXT", 148 | "subpartition_expression": "TEXT", 149 | "partition_description": "TEXT", 150 | "table_rows": "INT", 151 | "avg_row_length": "INT", 152 | "data_length": "INT", 153 | "max_data_length": "INT", 154 | "index_length": "INT", 155 | "data_free": "INT", 156 | "create_time": "TEXT", 157 | "update_time": "TEXT", 158 | "check_time": "TEXT", 159 | "checksum": "TEXT", 160 | "partition_comment": "TEXT", 161 | "nodegroup": "TEXT", 162 | "tablespace_name": "TEXT", 163 | }, 164 | "referential_constraints": { 165 | "constraint_catalog": "TEXT", 166 | "constraint_schema": "TEXT", 167 | "constraint_name": "TEXT", 168 | "unique_constraint_catalog": "TEXT", 169 | "unique_constraint_schema": "TEXT", 170 | "unique_constraint_name": "TEXT", 171 | "match_option": "TEXT", 172 | "update_rule": "TEXT", 173 | "delete_rule": "TEXT", 174 | "table_name": "TEXT", 175 | "referenced_table_name": "TEXT", 176 | }, 177 | "routines": { 178 | "specific_name": "TEXT", 179 | "routine_catalog": "TEXT", 180 | "routine_schema": "TEXT", 181 | "routine_name": "TEXT", 182 | "routine_type": "TEXT", 183 | "data_type": "TEXT", 184 | "character_maximum_length": "INT", 185 | "character_octet_length": "INT", 186 | "numeric_precision": "INT", 187 | "numeric_scale": "INT", 
188 | "datetime_precision": "INT", 189 | "character_set_name": "TEXT", 190 | "collation_name": "TEXT", 191 | "dtd_identifier": "TEXT", 192 | "routine_body": "TEXT", 193 | "routine_definition": "TEXT", 194 | "external_name": "TEXT", 195 | "external_language": "TEXT", 196 | "parameter_style": "TEXT", 197 | "is_deterministic": "TEXT", 198 | "sql_data_access": "TEXT", 199 | "sql_path": "TEXT", 200 | "security_type": "TEXT", 201 | "created": "TEXT", 202 | "last_altered": "TEXT", 203 | "sql_mode": "TEXT", 204 | "routine_comment": "TEXT", 205 | "definer": "TEXT", 206 | "character_set_client": "TEXT", 207 | "collation_connection": "TEXT", 208 | "database_collation": "TEXT", 209 | }, 210 | "schema_privileges": { 211 | "grantee": "TEXT", 212 | "table_catalog": "TEXT", 213 | "table_schema": "TEXT", 214 | "privilege_type": "TEXT", 215 | "is_grantable": "TEXT", 216 | }, 217 | "schemata": { 218 | "catalog_name": "TEXT", 219 | "schema_name": "TEXT", 220 | "default_character_set_name": "TEXT", 221 | "default_collation_name": "TEXT", 222 | "sql_path": "TEXT", 223 | }, 224 | "statistics": { 225 | "table_catalog": "TEXT", 226 | "table_schema": "TEXT", 227 | "table_name": "TEXT", 228 | "non_unique": "INT", 229 | "index_schema": "TEXT", 230 | "index_name": "TEXT", 231 | "seq_in_index": "INT", 232 | "column_name": "TEXT", 233 | "collation": "TEXT", 234 | "cardinality": "INT", 235 | "sub_part": "TEXT", 236 | "packed": "TEXT", 237 | "nullable": "TEXT", 238 | "index_type": "TEXT", 239 | "comment": "TEXT", 240 | "index_comment": "TEXT", 241 | "is_visible": "TEXT", 242 | "expression": "TEXT", 243 | }, 244 | "tables": { 245 | "table_catalog": "TEXT", 246 | "table_schema": "TEXT", 247 | "table_name": "TEXT", 248 | "table_type": "TEXT", 249 | "engine": "TEXT", 250 | "version": "TEXT", 251 | "row_format": "TEXT", 252 | "table_rows": "INT", 253 | "avg_row_length": "INT", 254 | "data_length": "INT", 255 | "max_data_length": "INT", 256 | "index_length": "INT", 257 | "data_free": "INT", 258 | 
"auto_increment": "INT", 259 | "create_time": "TEXT", 260 | "update_time": "TEXT", 261 | "check_time": "TEXT", 262 | "table_collation": "TEXT", 263 | "checksum": "TEXT", 264 | "create_options": "TEXT", 265 | "table_comment": "TEXT", 266 | }, 267 | "table_constraints": { 268 | "constraint_name": "TEXT", 269 | "constraint_catalog": "TEXT", 270 | "constraint_schema": "TEXT", 271 | "table_schema": "TEXT", 272 | "table_name": "TEXT", 273 | "constraint_type": "TEXT", 274 | "enforced": "TEXT", 275 | }, 276 | "table_privileges": { 277 | "grantee": "TEXT", 278 | "table_catalog": "TEXT", 279 | "table_schema": "TEXT", 280 | "table_name": "TEXT", 281 | "privilege_type": "TEXT", 282 | "is_grantable": "TEXT", 283 | }, 284 | "triggers": { 285 | "trigger_catalog": "TEXT", 286 | "trigger_schema": "TEXT", 287 | "trigger_name": "TEXT", 288 | "event_manipulation": "TEXT", 289 | "event_object_catalog": "TEXT", 290 | "event_object_schema": "TEXT", 291 | "event_object_table": "TEXT", 292 | "action_order": "INT", 293 | "action_condition": "TEXT", 294 | "action_statement": "TEXT", 295 | "action_orientation": "TEXT", 296 | "action_timing": "TEXT", 297 | "action_reference_old_table": "TEXT", 298 | "action_reference_new_table": "TEXT", 299 | "action_reference_old_row": "TEXT", 300 | "action_reference_new_row": "TEXT", 301 | "created": "TEXT", 302 | "sql_mode": "TEXT", 303 | "definer": "TEXT", 304 | "character_set_client": "TEXT", 305 | "collation_connection": "TEXT", 306 | "database_collation": "TEXT", 307 | }, 308 | "user_privileges": { 309 | "grantee": "TEXT", 310 | "table_catalog": "TEXT", 311 | "privilege_type": "TEXT", 312 | "is_grantable": "TEXT", 313 | }, 314 | "views": { 315 | "table_catalog": "TEXT", 316 | "table_schema": "TEXT", 317 | "table_name": "TEXT", 318 | "view_definition": "TEXT", 319 | "check_option": "TEXT", 320 | "is_updatable": "TEXT", 321 | "definer": "TEXT", 322 | "security_type": "TEXT", 323 | "character_set_client": "TEXT", 324 | "collation_connection": "TEXT", 325 | 
}, 326 | }, 327 | "mysql": { 328 | "procs_priv": { 329 | "host": "TEXT", 330 | "db": "TEXT", 331 | "user": "TEXT", 332 | "routine_name": "TEXT", 333 | "routine_type": "TEXT", 334 | "proc_priv": "TEXT", 335 | "timestamp": "TEXT", 336 | "grantor": "TEXT", 337 | }, 338 | "role_edges": { 339 | "from_host": "TEXT", 340 | "from_user": "TEXT", 341 | "to_host": "TEXT", 342 | "to_user": "TEXT", 343 | "with_admin_option": "TEXT", 344 | }, 345 | "user": { 346 | "host": "TEXT", 347 | "user": "TEXT", 348 | "select_priv": "TEXT", 349 | "insert_priv": "TEXT", 350 | "update_priv": "TEXT", 351 | "delete_priv": "TEXT", 352 | "index_priv": "TEXT", 353 | "alter_priv": "TEXT", 354 | "create_priv": "TEXT", 355 | "drop_priv": "TEXT", 356 | "grant_priv": "TEXT", 357 | "create_view_priv": "TEXT", 358 | "show_view_priv": "TEXT", 359 | "create_routine_priv": "TEXT", 360 | "alter_routine_priv": "TEXT", 361 | "execute_priv": "TEXT", 362 | "trigger_priv": "TEXT", 363 | "event_priv": "TEXT", 364 | "create_tmp_table_priv": "TEXT", 365 | "lock_tables_priv": "TEXT", 366 | "references_priv": "TEXT", 367 | "reload_priv": "TEXT", 368 | "shutdown_priv": "TEXT", 369 | "process_priv": "TEXT", 370 | "file_priv": "TEXT", 371 | "show_db_priv": "TEXT", 372 | "super_priv": "TEXT", 373 | "repl_slave_priv": "TEXT", 374 | "repl_client_priv": "TEXT", 375 | "create_user_priv": "TEXT", 376 | "create_tablespace_priv": "TEXT", 377 | "create_role_priv": "TEXT", 378 | "drop_role_priv": "TEXT", 379 | "ssl_type": "TEXT", 380 | "ssl_cipher": "TEXT", 381 | "x509_issuer": "TEXT", 382 | "x509_subject": "TEXT", 383 | "plugin": "TEXT", 384 | "authentication_string": "TEXT", 385 | "password_expired": "TEXT", 386 | "password_last_changed": "TEXT", 387 | "password_lifetime": "TEXT", 388 | "account_locked": "TEXT", 389 | "password_reuse_history": "TEXT", 390 | "password_reuse_time": "TEXT", 391 | "password_require_current": "TEXT", 392 | "user_attributes": "TEXT", 393 | "max_questions": "INT", 394 | "max_updates": "INT", 395 | 
"max_connections": "INT", 396 | "max_user_connections": "INT", 397 | }, 398 | }, 399 | } 400 | -------------------------------------------------------------------------------- /mysql_mimic/context.py: -------------------------------------------------------------------------------- 1 | from contextvars import ContextVar 2 | 3 | connection_id: ContextVar[int] = ContextVar("connection_id") 4 | -------------------------------------------------------------------------------- /mysql_mimic/control.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import random 4 | from typing import Dict, Optional, TYPE_CHECKING 5 | 6 | from mysql_mimic.constants import KillKind 7 | from mysql_mimic.utils import seq 8 | 9 | if TYPE_CHECKING: 10 | from mysql_mimic.connection import Connection 11 | 12 | 13 | class TooManyConnections(Exception): 14 | pass 15 | 16 | 17 | class Control: 18 | """ 19 | Base class for controlling server state. 20 | 21 | This is intended to encapsulate operations that might span deployment of multiple 22 | MySQL mimic server instances. 23 | """ 24 | 25 | async def add(self, connection: Connection) -> int: 26 | """ 27 | Add a new connection 28 | 29 | Returns: 30 | connection id 31 | """ 32 | raise NotImplementedError() 33 | 34 | async def remove(self, connection_id: int) -> None: 35 | """Remove an existing connection""" 36 | raise NotImplementedError() 37 | 38 | async def kill( 39 | self, connection_id: int, kind: KillKind = KillKind.CONNECTION 40 | ) -> None: 41 | """Request termination of an existing connection""" 42 | raise NotImplementedError() 43 | 44 | 45 | class LocalControl(Control): 46 | """ 47 | Simple Control implementation that handles everything in-memory. 48 | 49 | Args: 50 | server_id: set a unique server ID. This is used to generate globally unique 51 | connection IDs. This should be an integer between 0 and 65535. 
52 | If left as None, a random server ID will be generated. 53 | """ 54 | 55 | _CONNECTION_ID_BITS = 16 56 | _MAX_CONNECTION_SEQ = 2**_CONNECTION_ID_BITS 57 | _MAX_SERVER_ID = 2**16 58 | 59 | def __init__(self, server_id: Optional[int] = None): 60 | self._connection_seq = seq(self._MAX_CONNECTION_SEQ) 61 | self._connections: Dict[int, Connection] = {} 62 | self.server_id = server_id or random.randint(0, self._MAX_SERVER_ID - 1) 63 | 64 | async def add(self, connection: Connection) -> int: 65 | connection_id = self._new_connection_id() 66 | self._connections[connection_id] = connection 67 | return connection_id 68 | 69 | async def remove(self, connection_id: int) -> None: 70 | self._connections.pop(connection_id, None) 71 | 72 | async def kill( 73 | self, connection_id: int, kind: KillKind = KillKind.CONNECTION 74 | ) -> None: 75 | conn = self._connections.get(connection_id) 76 | if conn: 77 | conn.kill(kind) 78 | 79 | def _new_connection_id(self) -> int: 80 | """ 81 | Generate a connection ID. 82 | 83 | MySQL connection IDs are 4 bytes. 84 | 85 | This is tricky for us, as there may be multiple MySQL-Mimic server instances. 86 | 87 | We use a concatenation of the server ID (first two bytes) and connection 88 | sequence (second two bytes): 89 | 90 | |<---server ID--->|<-conn sequence->| 91 | 00000000 00000000 00000000 00000000 92 | 93 | If incremental connection IDs aren't manually provided, this isn't guaranteed 94 | to be unique. But collisions should be highly unlikely. 95 | """ 96 | if len(self._connections) >= self._MAX_CONNECTION_SEQ: 97 | raise TooManyConnections() 98 | 99 | server_id_prefix = ( 100 | self.server_id % self._MAX_SERVER_ID 101 | ) << self._CONNECTION_ID_BITS 102 | 103 | connection_id = server_id_prefix + next(self._connection_seq) 104 | 105 | # Avoid connection ID collisions in the unlikely chance that a connection is 106 | # alive longer than it takes for the sequence to reset. 
107 | while connection_id in self._connections: 108 | connection_id = server_id_prefix + next(self._connection_seq) 109 | 110 | return connection_id 111 | -------------------------------------------------------------------------------- /mysql_mimic/errors.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class ErrorCode(IntEnum): 5 | """https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html""" 6 | 7 | CON_COUNT_ERROR = 1040 8 | HANDSHAKE_ERROR = 1043 9 | ACCESS_DENIED_ERROR = 1045 10 | NO_DB_ERROR = 1046 11 | PARSE_ERROR = 1064 12 | EMPTY_QUERY = 1065 13 | UNKNOWN_PROCEDURE = 1106 14 | UNKNOWN_SYSTEM_VARIABLE = 1193 15 | UNKNOWN_COM_ERROR = 1047 16 | UNKNOWN_ERROR = 1105 17 | WRONG_VALUE_FOR_VAR = 1231 18 | NOT_SUPPORTED_YET = 1235 19 | MALFORMED_PACKET = 1835 20 | USER_DOES_NOT_EXIST = 3162 21 | SESSION_WAS_KILLED = 3169 22 | PLUGIN_REQUIRES_REGISTRATION = 4055 23 | 24 | 25 | # For more info, see https://dev.mysql.com/doc/refman/8.0/en/error-message-elements.html 26 | SQLSTATES = { 27 | ErrorCode.CON_COUNT_ERROR: b"08004", 28 | ErrorCode.ACCESS_DENIED_ERROR: b"28000", 29 | ErrorCode.HANDSHAKE_ERROR: b"08S01", 30 | ErrorCode.NO_DB_ERROR: b"3D000", 31 | ErrorCode.PARSE_ERROR: b"42000", 32 | ErrorCode.EMPTY_QUERY: b"42000", 33 | ErrorCode.UNKNOWN_PROCEDURE: b"42000", 34 | ErrorCode.UNKNOWN_COM_ERROR: b"08S01", 35 | ErrorCode.WRONG_VALUE_FOR_VAR: b"42000", 36 | ErrorCode.NOT_SUPPORTED_YET: b"42000", 37 | } 38 | 39 | 40 | def get_sqlstate(code: ErrorCode) -> bytes: 41 | return SQLSTATES.get(code, b"HY000") 42 | 43 | 44 | class MysqlError(Exception): 45 | def __init__(self, msg: str, code: ErrorCode = ErrorCode.UNKNOWN_ERROR): 46 | super().__init__(f"{code}: {msg}") 47 | self.msg = msg 48 | self.code = code 49 | -------------------------------------------------------------------------------- /mysql_mimic/functions.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Callable 2 | from datetime import datetime 3 | 4 | Functions = Mapping[str, Callable[[], Any]] 5 | 6 | 7 | def mysql_datetime_function_mapping(timestamp: datetime) -> Functions: 8 | functions = { 9 | "NOW": lambda: timestamp.strftime("%Y-%m-%d %H:%M:%S"), 10 | "CURDATE": lambda: timestamp.strftime("%Y-%m-%d"), 11 | "CURTIME": lambda: timestamp.strftime("%H:%M:%S"), 12 | } 13 | functions.update( 14 | { 15 | "CURRENT_TIMESTAMP": functions["NOW"], 16 | "LOCALTIME": functions["NOW"], 17 | "LOCALTIMESTAMP": functions["NOW"], 18 | "CURRENT_DATE": functions["CURDATE"], 19 | "CURRENT_TIME": functions["CURTIME"], 20 | } 21 | ) 22 | return functions 23 | -------------------------------------------------------------------------------- /mysql_mimic/intercept.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from sqlglot import expressions as exp 6 | 7 | from mysql_mimic.errors import MysqlError, ErrorCode 8 | from mysql_mimic.variables import DEFAULT 9 | 10 | 11 | # Mapping of transaction characteristic from SET TRANSACTION statements to their corresponding system variable 12 | TRANSACTION_CHARACTERISTICS = { 13 | "ISOLATION LEVEL REPEATABLE READ": ("transaction_isolation", "REPEATABLE-READ"), 14 | "ISOLATION LEVEL READ COMMITTED": ("transaction_isolation", "READ-COMMITTED"), 15 | "ISOLATION LEVEL READ UNCOMMITTED": ("transaction_isolation", "READ-UNCOMMITTED"), 16 | "ISOLATION LEVEL SERIALIZABLE": ("transaction_isolation", "SERIALIZABLE"), 17 | "READ WRITE": ("transaction_read_only", False), 18 | "READ ONLY": ("transaction_read_only", True), 19 | } 20 | 21 | 22 | def setitem_kind(setitem: exp.SetItem) -> str: 23 | kind = setitem.text("kind") 24 | if not kind: 25 | return "VARIABLE" 26 | 27 | if kind in {"GLOBAL", "PERSIST", "PERSIST_ONLY", 
"SESSION", "LOCAL"}: 28 | return "VARIABLE" 29 | 30 | return kind 31 | 32 | 33 | def value_to_expression(value: Any) -> exp.Expression: 34 | if value is True: 35 | return exp.true() 36 | if value is False: 37 | return exp.false() 38 | if value is None: 39 | return exp.null() 40 | if isinstance(value, (int, float)): 41 | return exp.Literal.number(value) 42 | return exp.Literal.string(str(value)) 43 | 44 | 45 | def expression_to_value(expression: exp.Expression) -> Any: 46 | if expression == exp.true(): 47 | return True 48 | if expression == exp.false(): 49 | return False 50 | if expression == exp.null(): 51 | return None 52 | if isinstance(expression, exp.Literal) and not expression.args.get("is_string"): 53 | try: 54 | return int(expression.this) 55 | except ValueError: 56 | return float(expression.this) 57 | if isinstance(expression, exp.Literal): 58 | return expression.name 59 | if expression.name == "DEFAULT": 60 | return DEFAULT 61 | if expression.name == "ON": 62 | return True 63 | if expression.name == "OFF": 64 | return False 65 | raise MysqlError( 66 | "Complex expressions in variables not supported yet", 67 | code=ErrorCode.NOT_SUPPORTED_YET, 68 | ) 69 | -------------------------------------------------------------------------------- /mysql_mimic/packets.py: -------------------------------------------------------------------------------- 1 | import io 2 | from dataclasses import dataclass, field 3 | from typing import Optional, Dict, Any, Sequence, Callable, Tuple, List, Union 4 | 5 | from mysql_mimic.charset import Collation, CharacterSet 6 | from mysql_mimic.constants import DEFAULT_SERVER_CAPABILITIES 7 | from mysql_mimic.errors import ErrorCode, get_sqlstate, MysqlError 8 | from mysql_mimic.prepared import PreparedStatement, REGEX_PARAM 9 | from mysql_mimic.results import NullBitmap, ResultColumn 10 | from mysql_mimic.types import ( 11 | Capabilities, 12 | uint_2, 13 | uint_1, 14 | str_null, 15 | uint_4, 16 | str_fixed, 17 | read_uint_4, 18 | 
read_str_null, 19 | read_str_fixed, 20 | read_uint_1, 21 | read_str_len, 22 | read_uint_len, 23 | str_len, 24 | str_rest, 25 | uint_len, 26 | ColumnType, 27 | read_int_1, 28 | read_int_2, 29 | read_uint_2, 30 | read_int_4, 31 | read_uint_8, 32 | read_int_8, 33 | read_float, 34 | read_double, 35 | ResultsetMetadata, 36 | ColumnDefinition, 37 | ComStmtExecuteFlags, 38 | peek, 39 | ServerStatus, 40 | read_str_rest, 41 | ) 42 | 43 | 44 | @dataclass 45 | class SSLRequest: 46 | max_packet_size: int 47 | capabilities: Capabilities 48 | client_charset: CharacterSet 49 | 50 | 51 | @dataclass 52 | class HandshakeResponse41: 53 | max_packet_size: int 54 | capabilities: Capabilities 55 | client_charset: CharacterSet 56 | username: str 57 | auth_response: bytes 58 | connect_attrs: Dict[str, str] = field(default_factory=dict) 59 | database: Optional[str] = None 60 | client_plugin: Optional[str] = None 61 | zstd_compression_level: int = 0 62 | 63 | 64 | @dataclass 65 | class ComChangeUser: 66 | username: str 67 | auth_response: bytes 68 | database: str 69 | client_charset: Optional[CharacterSet] = None 70 | client_plugin: Optional[str] = None 71 | connect_attrs: Dict[str, str] = field(default_factory=dict) 72 | 73 | 74 | @dataclass 75 | class ComQuery: 76 | sql: str 77 | query_attrs: Dict[str, str] 78 | 79 | 80 | @dataclass 81 | class ComStmtSendLongData: 82 | stmt_id: int 83 | param_id: int 84 | data: bytes 85 | 86 | 87 | @dataclass 88 | class ComStmtExecute: 89 | sql: str 90 | query_attrs: Dict[str, str] 91 | stmt: PreparedStatement 92 | use_cursor: bool 93 | 94 | 95 | @dataclass 96 | class ComStmtFetch: 97 | stmt_id: int 98 | num_rows: int 99 | 100 | 101 | @dataclass 102 | class ComStmtReset: 103 | stmt_id: int 104 | 105 | 106 | @dataclass 107 | class ComStmtClose: 108 | stmt_id: int 109 | 110 | 111 | @dataclass 112 | class ComFieldList: 113 | table: str 114 | wildcard: str 115 | 116 | 117 | def make_ok( 118 | capabilities: Capabilities, 119 | status_flags: ServerStatus, 120 | 
eof: bool = False, 121 | affected_rows: int = 0, 122 | last_insert_id: int = 0, 123 | warnings: int = 0, 124 | flags: int = 0, 125 | ) -> bytes: 126 | parts = [ 127 | uint_1(0) if not eof else uint_1(0xFE), 128 | uint_len(affected_rows), 129 | uint_len(last_insert_id), 130 | ] 131 | 132 | if Capabilities.CLIENT_PROTOCOL_41 in capabilities: 133 | parts.append(uint_2(status_flags | flags)) 134 | parts.append(uint_2(warnings)) 135 | elif Capabilities.CLIENT_TRANSACTIONS in capabilities: 136 | parts.append(uint_2(status_flags | flags)) 137 | 138 | return _concat(*parts) 139 | 140 | 141 | def make_eof( 142 | capabilities: Capabilities, 143 | status_flags: ServerStatus, 144 | warnings: int = 0, 145 | flags: int = 0, 146 | ) -> bytes: 147 | parts = [uint_1(0xFE)] 148 | 149 | if Capabilities.CLIENT_PROTOCOL_41 in capabilities: 150 | parts.append(uint_2(warnings)) 151 | parts.append(uint_2(status_flags | flags)) 152 | 153 | return _concat(*parts) 154 | 155 | 156 | def make_error( 157 | capabilities: Capabilities = DEFAULT_SERVER_CAPABILITIES, 158 | server_charset: CharacterSet = CharacterSet.utf8mb4, 159 | msg: Any = "", 160 | code: ErrorCode = ErrorCode.UNKNOWN_ERROR, 161 | ) -> bytes: 162 | parts = [uint_1(0xFF), uint_2(code)] 163 | 164 | if Capabilities.CLIENT_PROTOCOL_41 in capabilities: 165 | parts.append(str_fixed(1, b"#")) 166 | parts.append(str_fixed(5, get_sqlstate(code))) 167 | 168 | parts.append(str_rest(server_charset.encode(str(msg)))) 169 | 170 | return _concat(*parts) 171 | 172 | 173 | def make_handshake_v10( 174 | capabilities: Capabilities, 175 | server_charset: CharacterSet, 176 | server_version: str, 177 | connection_id: int, 178 | auth_data: bytes, 179 | status_flags: ServerStatus, 180 | auth_plugin_name: str, 181 | ) -> bytes: 182 | auth_plugin_data_len = ( 183 | len(auth_data) if Capabilities.CLIENT_PLUGIN_AUTH in capabilities else 0 184 | ) 185 | 186 | parts = [ 187 | uint_1(10), # protocol version 188 | 
str_null(server_charset.encode(server_version)), # server version 189 | uint_4(connection_id), # connection ID 190 | str_null(auth_data[:8]), # plugin data 191 | uint_2(capabilities & 0xFFFF), # lower capabilities flag 192 | uint_1(server_charset), # lower character set 193 | uint_2(status_flags), # server status flag 194 | uint_2(capabilities >> 16), # higher capabilities flag 195 | uint_1(auth_plugin_data_len), 196 | str_fixed(10, bytes(10)), # reserved 197 | str_fixed(max(13, auth_plugin_data_len - 8), auth_data[8:]), 198 | ] 199 | if Capabilities.CLIENT_PLUGIN_AUTH in capabilities: 200 | parts.append(str_null(server_charset.encode(auth_plugin_name))) 201 | return _concat(*parts) 202 | 203 | 204 | def parse_handshake_response( 205 | capabilities: Capabilities, data: bytes 206 | ) -> Union[HandshakeResponse41, SSLRequest]: 207 | r = io.BytesIO(data) 208 | 209 | client_capabilities = Capabilities(read_uint_4(r)) 210 | 211 | capabilities = capabilities & client_capabilities 212 | 213 | max_packet_size = read_uint_4(r) 214 | client_charset = Collation(read_uint_1(r)).charset 215 | read_str_fixed(r, 23) 216 | 217 | if not peek(r): 218 | return SSLRequest( 219 | max_packet_size=max_packet_size, 220 | capabilities=capabilities, 221 | client_charset=client_charset, 222 | ) 223 | 224 | username = client_charset.decode(read_str_null(r)) 225 | 226 | if Capabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA in capabilities: 227 | auth_response = read_str_len(r) 228 | else: 229 | l_auth_response = read_uint_1(r) 230 | auth_response = read_str_fixed(r, l_auth_response) 231 | 232 | response = HandshakeResponse41( 233 | max_packet_size=max_packet_size, 234 | capabilities=capabilities, 235 | client_charset=client_charset, 236 | username=username, 237 | auth_response=auth_response, 238 | ) 239 | 240 | if Capabilities.CLIENT_CONNECT_WITH_DB in capabilities: 241 | response.database = client_charset.decode(read_str_null(r)) 242 | 243 | if Capabilities.CLIENT_PLUGIN_AUTH in capabilities: 
244 | response.client_plugin = client_charset.decode(read_str_null(r)) 245 | 246 | if Capabilities.CLIENT_CONNECT_ATTRS in capabilities: 247 | response.connect_attrs = _read_connect_attrs(r, client_charset) 248 | 249 | if Capabilities.CLIENT_ZSTD_COMPRESSION_ALGORITHM in capabilities: 250 | response.zstd_compression_level = read_uint_1(r) 251 | 252 | return response 253 | 254 | 255 | def parse_handshake_response_41( 256 | capabilities: Capabilities, data: bytes 257 | ) -> HandshakeResponse41: 258 | response = parse_handshake_response(capabilities, data) 259 | assert isinstance(response, HandshakeResponse41) 260 | return response 261 | 262 | 263 | def make_auth_more_data(data: bytes) -> bytes: 264 | return _concat(uint_1(1), str_rest(data)) # status tag 265 | 266 | 267 | def make_auth_switch_request( 268 | server_charset: CharacterSet, plugin_name: str, plugin_provided_data: bytes 269 | ) -> bytes: 270 | return _concat( 271 | uint_1(254), # status tag 272 | str_null(server_charset.encode(plugin_name)), 273 | str_rest(plugin_provided_data), 274 | ) 275 | 276 | 277 | def parse_com_change_user( 278 | capabilities: Capabilities, client_charset: CharacterSet, data: bytes 279 | ) -> ComChangeUser: 280 | r = io.BytesIO(data) 281 | username = client_charset.decode(read_str_null(r)) 282 | if Capabilities.CLIENT_SECURE_CONNECTION in capabilities: 283 | l_auth_response = read_uint_1(r) 284 | auth_response = read_str_fixed(r, l_auth_response) 285 | else: 286 | auth_response = read_str_null(r) 287 | database = client_charset.decode(read_str_null(r)) 288 | 289 | response = ComChangeUser( 290 | username=username, auth_response=auth_response, database=database 291 | ) 292 | 293 | if peek(r): # more data available 294 | if Capabilities.CLIENT_PROTOCOL_41 in capabilities: 295 | client_charset = Collation(read_uint_2(r)).charset 296 | response.client_charset = client_charset 297 | if Capabilities.CLIENT_PLUGIN_AUTH in capabilities: 298 | response.client_plugin = 
client_charset.decode(read_str_null(r)) 299 | if Capabilities.CLIENT_CONNECT_ATTRS in capabilities: 300 | response.connect_attrs = _read_connect_attrs(r, client_charset) 301 | 302 | return response 303 | 304 | 305 | def parse_com_query( 306 | capabilities: Capabilities, client_charset: CharacterSet, data: bytes 307 | ) -> ComQuery: 308 | r = io.BytesIO(data) 309 | 310 | if Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities: 311 | parameter_count = read_uint_len(r) 312 | read_uint_len(r) # parameter_set_count. Always 1. 313 | query_attrs = { 314 | k: v 315 | for k, v in _read_params(capabilities, client_charset, r, parameter_count) 316 | if k is not None 317 | } 318 | else: 319 | query_attrs = {} 320 | 321 | sql = r.read().decode(client_charset.codec) 322 | 323 | return ComQuery( 324 | sql=sql, 325 | query_attrs=query_attrs, 326 | ) 327 | 328 | 329 | def make_column_count(capabilities: Capabilities, column_count: int) -> bytes: 330 | parts = [] 331 | 332 | if Capabilities.CLIENT_OPTIONAL_RESULTSET_METADATA in capabilities: 333 | parts.append(uint_1(ResultsetMetadata.RESULTSET_METADATA_FULL)) 334 | 335 | parts.append(uint_len(column_count)) 336 | 337 | return _concat(*parts) 338 | 339 | 340 | # pylint: disable=too-many-arguments 341 | def make_column_definition_41( 342 | server_charset: CharacterSet, 343 | schema: Optional[str] = None, 344 | table: Optional[str] = None, 345 | org_table: Optional[str] = None, 346 | name: Optional[str] = None, 347 | org_name: Optional[str] = None, 348 | character_set: CharacterSet = CharacterSet.utf8mb4, 349 | column_length: int = 256, 350 | column_type: ColumnType = ColumnType.VARCHAR, 351 | flags: ColumnDefinition = ColumnDefinition(0), 352 | decimals: int = 0, 353 | is_com_field_list: bool = False, 354 | default: Optional[str] = None, 355 | ) -> bytes: 356 | schema = schema or "" 357 | table = table or "" 358 | org_table = org_table or table 359 | name = name or "" 360 | org_name = org_name or name 361 | parts = [ 362 | 
str_len(b"def"), 363 | str_len(server_charset.encode(schema)), 364 | str_len(server_charset.encode(table)), 365 | str_len(server_charset.encode(org_table)), 366 | str_len(server_charset.encode(name)), 367 | str_len(server_charset.encode(org_name)), 368 | uint_len(0x0C), # Length of the following fields 369 | uint_2(character_set), 370 | uint_4(column_length), 371 | uint_1(column_type), 372 | uint_2(flags), 373 | uint_1(decimals), 374 | uint_2(0), # filler 375 | ] 376 | if is_com_field_list: 377 | if default is None: 378 | parts.append(uint_len(0)) 379 | else: 380 | default_values = server_charset.encode(default) 381 | parts.extend([uint_len(len(default_values)), str_len(default_values)]) 382 | return _concat(*parts) 383 | 384 | 385 | def make_text_resultset_row( 386 | row: Sequence[Any], columns: Sequence[ResultColumn] 387 | ) -> bytes: 388 | parts = [] 389 | 390 | for value, column in zip(row, columns): 391 | if value is None: 392 | parts.append(b"\xfb") 393 | else: 394 | text = column.text_encode(value) 395 | parts.append(str_len(text)) 396 | 397 | return _concat(*parts) 398 | 399 | 400 | def make_com_stmt_prepare_ok(statement: PreparedStatement) -> bytes: 401 | return _concat( 402 | uint_1(0), # OK 403 | uint_4(statement.stmt_id), 404 | uint_2(0), # number of columns 405 | uint_2(statement.num_params), 406 | uint_1(0), # filler 407 | uint_2(0), # number of warnings 408 | ) 409 | 410 | 411 | def parse_com_stmt_send_long_data(data: bytes) -> ComStmtSendLongData: 412 | r = io.BytesIO(data) 413 | return ComStmtSendLongData( 414 | stmt_id=read_uint_4(r), 415 | param_id=read_uint_2(r), 416 | data=r.read(), 417 | ) 418 | 419 | 420 | def parse_com_stmt_execute( 421 | capabilities: Capabilities, 422 | client_charset: CharacterSet, 423 | data: bytes, 424 | get_stmt: Callable[[int], PreparedStatement], 425 | ) -> ComStmtExecute: 426 | r = io.BytesIO(data) 427 | stmt_id = read_uint_4(r) 428 | stmt = get_stmt(stmt_id) 429 | use_cursor, param_count_available = 
_read_cursor_flags(r) 430 | read_uint_4(r) # iteration count. Always 1. 431 | sql, query_attrs = _interpolate_params( 432 | capabilities, client_charset, r, stmt, param_count_available 433 | ) 434 | return ComStmtExecute( 435 | sql=sql, 436 | query_attrs=query_attrs, 437 | stmt=stmt, 438 | use_cursor=use_cursor, 439 | ) 440 | 441 | 442 | def make_binary_resultrow(row: Sequence[Any], columns: Sequence[ResultColumn]) -> bytes: 443 | column_count = len(row) 444 | 445 | null_bitmap = NullBitmap.new(column_count, offset=2) 446 | 447 | values = [] 448 | for i, (val, col) in enumerate(zip(row, columns)): 449 | if val is None: 450 | null_bitmap.flip(i) 451 | else: 452 | values.append(col.binary_encode(val)) 453 | 454 | values_data = b"".join(values) 455 | 456 | null_bitmap_data = bytes(null_bitmap) 457 | 458 | return b"".join( 459 | [ 460 | uint_1(0), # packet header 461 | str_fixed(len(null_bitmap_data), null_bitmap_data), 462 | str_fixed(len(values_data), values_data), 463 | ] 464 | ) 465 | 466 | 467 | def parse_handle_stmt_fetch(data: bytes) -> ComStmtFetch: 468 | r = io.BytesIO(data) 469 | return ComStmtFetch( 470 | stmt_id=read_uint_4(r), 471 | num_rows=read_uint_4(r), 472 | ) 473 | 474 | 475 | def parse_com_stmt_reset(data: bytes) -> ComStmtReset: 476 | r = io.BytesIO(data) 477 | return ComStmtReset(stmt_id=read_uint_4(r)) 478 | 479 | 480 | def parse_com_stmt_close(data: bytes) -> ComStmtClose: 481 | r = io.BytesIO(data) 482 | return ComStmtClose(stmt_id=read_uint_4(r)) 483 | 484 | 485 | def parse_com_init_db(client_charset: CharacterSet, data: bytes) -> str: 486 | return client_charset.decode(data) 487 | 488 | 489 | def parse_com_field_list(client_charset: CharacterSet, data: bytes) -> ComFieldList: 490 | r = io.BytesIO(data) 491 | return ComFieldList( 492 | table=client_charset.decode(read_str_null(r)), 493 | wildcard=client_charset.decode(read_str_rest(r)), 494 | ) 495 | 496 | 497 | def _read_cursor_flags(reader: io.BytesIO) -> Tuple[bool, bool]: 498 | flags = 
ComStmtExecuteFlags(read_uint_1(reader)) 499 | param_count_available = ComStmtExecuteFlags.PARAMETER_COUNT_AVAILABLE in flags 500 | 501 | if ComStmtExecuteFlags.CURSOR_TYPE_READ_ONLY in flags: 502 | return True, param_count_available 503 | if ComStmtExecuteFlags.CURSOR_TYPE_NO_CURSOR in flags: 504 | return False, param_count_available 505 | raise MysqlError(f"Unsupported cursor flags: {flags}", ErrorCode.NOT_SUPPORTED_YET) 506 | 507 | 508 | def _interpolate_params( 509 | capabilities: Capabilities, 510 | client_charset: CharacterSet, 511 | reader: io.BytesIO, 512 | stmt: PreparedStatement, 513 | param_count_available: bool, 514 | ) -> Tuple[str, Dict[str, str]]: 515 | sql = stmt.sql 516 | query_attrs = {} 517 | parameter_count = stmt.num_params 518 | 519 | if stmt.num_params > 0 or ( 520 | Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities and param_count_available 521 | ): 522 | if Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities: 523 | parameter_count = read_uint_len(reader) 524 | 525 | if parameter_count > 0: 526 | # When there are query attributes, they are combined with statement parameters. 527 | # The statement parameters will be first, query attributes second. 
528 | params = _read_params( 529 | capabilities, client_charset, reader, parameter_count, stmt.param_buffers 530 | ) 531 | 532 | for _, value in params[: stmt.num_params]: 533 | sql = REGEX_PARAM.sub(_encode_param_as_sql(value), sql, 1) 534 | 535 | query_attrs = {k: v for k, v in params[stmt.num_params :] if k is not None} 536 | 537 | return sql, query_attrs 538 | 539 | 540 | def _encode_param_as_sql(param: Any) -> str: 541 | if isinstance(param, str): 542 | return f"'{param}'" 543 | if param is None: 544 | return "NULL" 545 | if param is True: 546 | return "TRUE" 547 | if param is False: 548 | return "FALSE" 549 | return str(param) 550 | 551 | 552 | def _read_params( 553 | capabilities: Capabilities, 554 | client_charset: CharacterSet, 555 | reader: io.BytesIO, 556 | parameter_count: int, 557 | buffers: Optional[Dict[int, bytearray]] = None, 558 | ) -> Sequence[Tuple[Optional[str], Any]]: 559 | """ 560 | Read parameters from a stream. 561 | 562 | This is intended for reading query attributes from COM_QUERY and parameters from COM_STMT_EXECUTE. 563 | 564 | Returns: 565 | Name/value pairs. Name is None if there is no name, which is the case for stmt_execute, which 566 | combines statement parameters and query attributes. 567 | """ 568 | params: List[Tuple[Optional[str], Any]] = [] 569 | 570 | if parameter_count: 571 | null_bitmap = NullBitmap.from_buffer(reader, parameter_count) 572 | new_params_bound_flag = read_uint_1(reader) 573 | 574 | if not new_params_bound_flag: 575 | raise MysqlError( 576 | "Server requires the new-params-bound-flag to be set", 577 | ErrorCode.NOT_SUPPORTED_YET, 578 | ) 579 | 580 | param_types = [] 581 | 582 | for i in range(parameter_count): 583 | param_type, unsigned = _read_param_type(reader) 584 | 585 | if Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities: 586 | # Only query attributes have names 587 | # Statement parameters will have an empty name, e.g. 
b"\x00" 588 | param_name = client_charset.decode(read_str_len(reader)) 589 | else: 590 | param_name = "" 591 | 592 | param_types.append((param_name, param_type, unsigned)) 593 | 594 | for i, (param_name, param_type, unsigned) in enumerate(param_types): 595 | if null_bitmap.is_flipped(i): 596 | params.append((param_name, None)) 597 | elif buffers and i in buffers: 598 | params.append((param_name, client_charset.decode(buffers[i]))) 599 | else: 600 | params.append( 601 | ( 602 | param_name, 603 | _read_param_value(client_charset, reader, param_type, unsigned), 604 | ) 605 | ) 606 | 607 | return params 608 | 609 | 610 | def _read_param_type(reader: io.BytesIO) -> Tuple[ColumnType, bool]: 611 | param_type = ColumnType(read_uint_1(reader)) 612 | is_unsigned = (read_uint_1(reader) & 0x80) > 0 613 | return param_type, is_unsigned 614 | 615 | 616 | def _read_param_value( 617 | client_charset: CharacterSet, 618 | reader: io.BytesIO, 619 | param_type: ColumnType, 620 | unsigned: bool, 621 | ) -> Any: 622 | if param_type in { 623 | ColumnType.VARCHAR, 624 | ColumnType.VAR_STRING, 625 | ColumnType.STRING, 626 | ColumnType.BLOB, 627 | ColumnType.TINY_BLOB, 628 | ColumnType.MEDIUM_BLOB, 629 | ColumnType.LONG_BLOB, 630 | }: 631 | val = read_str_len(reader) 632 | return client_charset.decode(val) 633 | 634 | if param_type == ColumnType.TINY: 635 | return (read_uint_1 if unsigned else read_int_1)(reader) 636 | 637 | if param_type == ColumnType.BOOL: 638 | return read_uint_1(reader) 639 | 640 | if param_type in {ColumnType.SHORT, ColumnType.YEAR}: 641 | return (read_uint_2 if unsigned else read_int_2)(reader) 642 | 643 | if param_type in {ColumnType.LONG, ColumnType.INT24}: 644 | return (read_uint_4 if unsigned else read_int_4)(reader) 645 | 646 | if param_type == ColumnType.LONGLONG: 647 | return (read_uint_8 if unsigned else read_int_8)(reader) 648 | 649 | if param_type == ColumnType.FLOAT: 650 | return read_float(reader) 651 | 652 | if param_type == ColumnType.DOUBLE: 653 | 
return read_double(reader) 654 | 655 | if param_type == ColumnType.NULL: 656 | return None 657 | 658 | raise MysqlError( 659 | f"Unsupported parameter type: {param_type}", ErrorCode.NOT_SUPPORTED_YET 660 | ) 661 | 662 | 663 | def _read_connect_attrs( 664 | reader: io.BytesIO, client_charset: CharacterSet 665 | ) -> Dict[str, str]: 666 | connect_attrs = {} 667 | total_l = read_uint_len(reader) 668 | 669 | while total_l > 0: 670 | key = read_str_len(reader) 671 | value = read_str_len(reader) 672 | connect_attrs[client_charset.decode(key)] = client_charset.decode(value) 673 | 674 | item_l = len(str_len(key) + str_len(value)) 675 | total_l -= item_l 676 | return connect_attrs 677 | 678 | 679 | def _concat(*parts: bytes) -> bytes: 680 | return b"".join(parts) 681 | -------------------------------------------------------------------------------- /mysql_mimic/prepared.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import Optional, Dict, AsyncIterable 4 | 5 | # Borrowed from mysql-connector-python 6 | REGEX_PARAM = re.compile(r"""\?(?=(?:[^"'`]*["'`][^"'`]*["'`])*[^"'`]*$)""") 7 | 8 | 9 | @dataclass 10 | class PreparedStatement: 11 | stmt_id: int 12 | sql: str 13 | num_params: int 14 | param_buffers: Optional[Dict[int, bytearray]] = None 15 | cursor: Optional[AsyncIterable[bytes]] = None 16 | -------------------------------------------------------------------------------- /mysql_mimic/results.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | import struct 5 | from dataclasses import dataclass 6 | from datetime import datetime, date, timedelta 7 | from typing import ( 8 | Iterable, 9 | Sequence, 10 | Optional, 11 | Callable, 12 | Any, 13 | Union, 14 | Tuple, 15 | Dict, 16 | AsyncIterable, 17 | cast, 18 | ) 19 | 20 | from mysql_mimic.errors import MysqlError 21 | from 
mysql_mimic.types import ColumnType, str_len, uint_1, uint_2, uint_4 22 | from mysql_mimic.charset import CharacterSet 23 | from mysql_mimic.utils import aiterate 24 | 25 | Encoder = Callable[[Any, "ResultColumn"], bytes] 26 | 27 | 28 | class ResultColumn: 29 | """ 30 | Column data for a result set 31 | 32 | Args: 33 | name: column name 34 | type: column type 35 | character_set: column character set. Only relevant for string columns. 36 | text_encoder: Optionally override the function used to encode values for MySQL's text protocol 37 | binary_encoder: Optionally override the function used to encode values for MySQL's binary protocol 38 | """ 39 | 40 | def __init__( 41 | self, 42 | name: str, 43 | type: ColumnType, # pylint: disable=redefined-builtin 44 | character_set: CharacterSet = CharacterSet.utf8mb4, 45 | text_encoder: Optional[Encoder] = None, 46 | binary_encoder: Optional[Encoder] = None, 47 | ): 48 | self.name = name 49 | self.type = type 50 | self.character_set = character_set 51 | self.text_encoder = text_encoder or _TEXT_ENCODERS.get(type) or _unsupported 52 | self.binary_encoder = ( 53 | binary_encoder or _BINARY_ENCODERS.get(type) or _unsupported 54 | ) 55 | 56 | def text_encode(self, val: Any) -> bytes: 57 | return self.text_encoder(self, val) 58 | 59 | def binary_encode(self, val: Any) -> bytes: 60 | return self.binary_encoder(self, val) 61 | 62 | def __repr__(self) -> str: 63 | return f"ResultColumn({self.name} {self.type.name})" 64 | 65 | 66 | @dataclass 67 | class ResultSet: 68 | rows: Iterable[Sequence] | AsyncIterable[Sequence] 69 | columns: Sequence[ResultColumn] 70 | 71 | def __bool__(self) -> bool: 72 | return bool(self.columns) 73 | 74 | 75 | AllowedColumn = Union[ResultColumn, str] 76 | AllowedResult = Union[ 77 | ResultSet, 78 | Tuple[Sequence[Sequence[Any]], Sequence[AllowedColumn]], 79 | Tuple[AsyncIterable[Sequence[Any]], Sequence[AllowedColumn]], 80 | None, 81 | ] 82 | 83 | 84 | async def ensure_result_set(result: AllowedResult) -> 
ResultSet: 85 | if result is None: 86 | return ResultSet([], []) 87 | if isinstance(result, ResultSet): 88 | return result 89 | if isinstance(result, tuple): 90 | if len(result) != 2: 91 | raise MysqlError( 92 | f"Result tuple should be of size 2. Received: {len(result)}" 93 | ) 94 | rows = result[0] 95 | columns = result[1] 96 | 97 | return await _ensure_result_cols(rows, columns) 98 | 99 | raise MysqlError(f"Unexpected result set type: {type(result)}") 100 | 101 | 102 | async def _ensure_result_cols( 103 | rows: Sequence[Sequence[Any]] | AsyncIterable[Sequence[Any]], 104 | columns: Sequence[AllowedColumn], 105 | ) -> ResultSet: 106 | # Which columns need to be inferred? 107 | remaining = { 108 | col: i for i, col in enumerate(columns) if not isinstance(col, ResultColumn) 109 | } 110 | 111 | if not remaining: 112 | return ResultSet( 113 | rows=rows, 114 | columns=cast(Sequence[ResultColumn], columns), 115 | ) 116 | 117 | # Copy the columns 118 | columns = list(columns) 119 | 120 | arows = aiterate(rows) 121 | 122 | # Keep track of rows we've consumed from the iterator so we can add them back 123 | peeks = [] 124 | 125 | # Find the first non-null value for each column 126 | while remaining: 127 | try: 128 | peek = await arows.__anext__() 129 | except StopAsyncIteration: 130 | break 131 | 132 | peeks.append(peek) 133 | 134 | inferred = [] 135 | for name, i in remaining.items(): 136 | value = peek[i] 137 | if value is not None: 138 | type_ = infer_type(value) 139 | columns[i] = ResultColumn( 140 | name=str(name), 141 | type=type_, 142 | ) 143 | inferred.append(name) 144 | 145 | for name in inferred: 146 | remaining.pop(name) 147 | 148 | # If we failed to find a non-null value, set the type to NULL 149 | for name, i in remaining.items(): 150 | columns[i] = ResultColumn( 151 | name=str(name), 152 | type=ColumnType.NULL, 153 | ) 154 | 155 | # Add the consumed rows back in to the iterator 156 | async def gen_rows() -> AsyncIterable[Sequence[Any]]: 157 | for row in peeks: 
158 | yield row 159 | 160 | async for row in arows: 161 | yield row 162 | 163 | assert all(isinstance(col, ResultColumn) for col in columns) 164 | return ResultSet(rows=gen_rows(), columns=cast(Sequence[ResultColumn], columns)) 165 | 166 | 167 | def _binary_encode_tiny(col: ResultColumn, val: Any) -> bytes: 168 | return uint_1(int(bool(val))) 169 | 170 | 171 | def _binary_encode_str(col: ResultColumn, val: Any) -> bytes: 172 | if not isinstance(val, bytes): 173 | val = str(val) 174 | 175 | if isinstance(val, str): 176 | val = val.encode(col.character_set.codec) 177 | 178 | return str_len(val) 179 | 180 | 181 | def _binary_encode_date(col: ResultColumn, val: Any) -> bytes: 182 | if isinstance(val, (float, int)): 183 | val = datetime.fromtimestamp(val) 184 | 185 | year = val.year 186 | month = val.month 187 | day = val.day 188 | 189 | if isinstance(val, datetime): 190 | hour = val.hour 191 | minute = val.minute 192 | second = val.second 193 | microsecond = val.microsecond 194 | else: 195 | hour = minute = second = microsecond = 0 196 | 197 | if microsecond == 0: 198 | if hour == minute == second == 0: 199 | if year == month == day == 0: 200 | return uint_1(0) 201 | return b"".join([uint_1(4), uint_2(year), uint_1(month), uint_1(day)]) 202 | return b"".join( 203 | [ 204 | uint_1(7), 205 | uint_2(year), 206 | uint_1(month), 207 | uint_1(day), 208 | uint_1(hour), 209 | uint_1(minute), 210 | uint_1(second), 211 | ] 212 | ) 213 | return b"".join( 214 | [ 215 | uint_1(11), 216 | uint_2(year), 217 | uint_1(month), 218 | uint_1(day), 219 | uint_1(hour), 220 | uint_1(minute), 221 | uint_1(second), 222 | uint_4(microsecond), 223 | ] 224 | ) 225 | 226 | 227 | def _binary_encode_short(col: ResultColumn, val: Any) -> bytes: 228 | return struct.pack(" bytes: 232 | return struct.pack(" bytes: 236 | return struct.pack(" bytes: 240 | return struct.pack(" bytes: 244 | return struct.pack(" bytes: 248 | return struct.pack(" bytes: 252 | days = abs(val.days) 253 | hours, remainder = 
divmod(abs(val.seconds), 3600) 254 | minutes, seconds = divmod(remainder, 60) 255 | microseconds = val.microseconds 256 | is_negative = val.total_seconds() < 0 257 | 258 | if microseconds == 0: 259 | if days == hours == minutes == seconds == 0: 260 | return uint_1(0) 261 | return b"".join( 262 | [ 263 | uint_1(8), 264 | uint_1(is_negative), 265 | uint_4(days), 266 | uint_1(hours), 267 | uint_1(minutes), 268 | uint_1(seconds), 269 | ] 270 | ) 271 | return b"".join( 272 | [ 273 | uint_1(12), 274 | uint_1(is_negative), 275 | uint_4(days), 276 | uint_1(hours), 277 | uint_1(minutes), 278 | uint_1(seconds), 279 | uint_4(microseconds), 280 | ] 281 | ) 282 | 283 | 284 | def _text_encode_str(col: ResultColumn, val: Any) -> bytes: 285 | if not isinstance(val, bytes): 286 | val = str(val) 287 | 288 | if isinstance(val, str): 289 | val = val.encode(col.character_set.codec) 290 | 291 | return val 292 | 293 | 294 | def _text_encode_tiny(col: ResultColumn, val: Any) -> bytes: 295 | return _text_encode_str(col, int(val)) 296 | 297 | 298 | def _unsupported(col: ResultColumn, val: Any) -> bytes: 299 | raise MysqlError(f"Unsupported column type: {col.type}") 300 | 301 | 302 | def _noop(col: ResultColumn, val: Any) -> bytes: 303 | return val 304 | 305 | 306 | # Order matters 307 | # bool is a subclass of int 308 | # datetime is a subclass of date 309 | _PY_TO_MYSQL_TYPE = { 310 | bool: ColumnType.TINY, 311 | datetime: ColumnType.DATETIME, 312 | str: ColumnType.STRING, 313 | bytes: ColumnType.BLOB, 314 | int: ColumnType.LONGLONG, 315 | float: ColumnType.DOUBLE, 316 | date: ColumnType.DATE, 317 | timedelta: ColumnType.TIME, 318 | } 319 | 320 | 321 | _TEXT_ENCODERS: Dict[ColumnType, Encoder] = { 322 | ColumnType.DECIMAL: _text_encode_str, 323 | ColumnType.TINY: _text_encode_tiny, 324 | ColumnType.SHORT: _text_encode_str, 325 | ColumnType.LONG: _text_encode_str, 326 | ColumnType.FLOAT: _text_encode_str, 327 | ColumnType.DOUBLE: _text_encode_str, 328 | ColumnType.NULL: _unsupported, 329 | 
ColumnType.TIMESTAMP: _text_encode_str, 330 | ColumnType.LONGLONG: _text_encode_str, 331 | ColumnType.INT24: _text_encode_str, 332 | ColumnType.DATE: _text_encode_str, 333 | ColumnType.TIME: _text_encode_str, 334 | ColumnType.DATETIME: _text_encode_str, 335 | ColumnType.YEAR: _text_encode_str, 336 | ColumnType.NEWDATE: _text_encode_str, 337 | ColumnType.VARCHAR: _text_encode_str, 338 | ColumnType.BIT: _text_encode_str, 339 | ColumnType.TIMESTAMP2: _unsupported, 340 | ColumnType.DATETIME2: _unsupported, 341 | ColumnType.TIME2: _unsupported, 342 | ColumnType.TYPED_ARRAY: _unsupported, 343 | ColumnType.INVALID: _unsupported, 344 | ColumnType.BOOL: _binary_encode_tiny, 345 | ColumnType.JSON: _text_encode_str, 346 | ColumnType.NEWDECIMAL: _text_encode_str, 347 | ColumnType.ENUM: _text_encode_str, 348 | ColumnType.SET: _text_encode_str, 349 | ColumnType.TINY_BLOB: _text_encode_str, 350 | ColumnType.MEDIUM_BLOB: _text_encode_str, 351 | ColumnType.LONG_BLOB: _text_encode_str, 352 | ColumnType.BLOB: _text_encode_str, 353 | ColumnType.VAR_STRING: _text_encode_str, 354 | ColumnType.STRING: _text_encode_str, 355 | ColumnType.GEOMETRY: _text_encode_str, 356 | } 357 | 358 | _BINARY_ENCODERS: Dict[ColumnType, Encoder] = { 359 | ColumnType.DECIMAL: _binary_encode_str, 360 | ColumnType.TINY: _binary_encode_tiny, 361 | ColumnType.SHORT: _binary_encode_short, 362 | ColumnType.LONG: _binary_encode_long, 363 | ColumnType.FLOAT: _binary_encode_float, 364 | ColumnType.DOUBLE: _binary_encode_double, 365 | ColumnType.NULL: _unsupported, 366 | ColumnType.TIMESTAMP: _binary_encode_date, 367 | ColumnType.LONGLONG: _binary_encode_longlong, 368 | ColumnType.INT24: _binary_encode_long, 369 | ColumnType.DATE: _binary_encode_date, 370 | ColumnType.TIME: _binary_encode_timedelta, 371 | ColumnType.DATETIME: _binary_encode_date, 372 | ColumnType.YEAR: _binary_encode_short, 373 | ColumnType.NEWDATE: _unsupported, 374 | ColumnType.VARCHAR: _binary_encode_str, 375 | ColumnType.BIT: _binary_encode_str, 
376 | ColumnType.TIMESTAMP2: _unsupported, 377 | ColumnType.DATETIME2: _unsupported, 378 | ColumnType.TIME2: _unsupported, 379 | ColumnType.TYPED_ARRAY: _unsupported, 380 | ColumnType.INVALID: _unsupported, 381 | ColumnType.BOOL: _binary_encode_tiny, 382 | ColumnType.JSON: _binary_encode_str, 383 | ColumnType.NEWDECIMAL: _binary_encode_str, 384 | ColumnType.ENUM: _binary_encode_str, 385 | ColumnType.SET: _binary_encode_str, 386 | ColumnType.TINY_BLOB: _binary_encode_str, 387 | ColumnType.MEDIUM_BLOB: _binary_encode_str, 388 | ColumnType.LONG_BLOB: _binary_encode_str, 389 | ColumnType.BLOB: _binary_encode_str, 390 | ColumnType.VAR_STRING: _binary_encode_str, 391 | ColumnType.STRING: _binary_encode_str, 392 | ColumnType.GEOMETRY: _binary_encode_str, 393 | } 394 | 395 | 396 | def infer_type(val: Any) -> ColumnType: 397 | for py_type, my_type in _PY_TO_MYSQL_TYPE.items(): 398 | if isinstance(val, py_type): 399 | return my_type 400 | return ColumnType.VARCHAR 401 | 402 | 403 | class NullBitmap: 404 | """See https://dev.mysql.com/doc/internals/en/null-bitmap.html""" 405 | 406 | __slots__ = ("offset", "bitmap") 407 | 408 | def __init__(self, bitmap: bytearray, offset: int = 0): 409 | self.offset = offset 410 | self.bitmap = bitmap 411 | 412 | @classmethod 413 | def new(cls, num_bits: int, offset: int = 0) -> NullBitmap: 414 | bitmap = bytearray(cls._num_bytes(num_bits, offset)) 415 | return cls(bitmap, offset) 416 | 417 | @classmethod 418 | def from_buffer( 419 | cls, buffer: io.BytesIO, num_bits: int, offset: int = 0 420 | ) -> NullBitmap: 421 | bitmap = bytearray(buffer.read(cls._num_bytes(num_bits, offset))) 422 | return cls(bitmap, offset) 423 | 424 | @classmethod 425 | def _num_bytes(cls, num_bits: int, offset: int) -> int: 426 | return (num_bits + 7 + offset) // 8 427 | 428 | def flip(self, i: int) -> None: 429 | byte_position, bit_position = self._pos(i) 430 | self.bitmap[byte_position] |= 1 << bit_position 431 | 432 | def is_flipped(self, i: int) -> bool: 433 | 
byte_position, bit_position = self._pos(i) 434 | return bool(self.bitmap[byte_position] & (1 << bit_position)) 435 | 436 | def _pos(self, i: int) -> Tuple[int, int]: 437 | byte_position = (i + self.offset) // 8 438 | bit_position = (i + self.offset) % 8 439 | return byte_position, bit_position 440 | 441 | def __bytes__(self) -> bytes: 442 | return bytes(self.bitmap) 443 | 444 | def __repr__(self) -> str: 445 | return "".join(format(b, "08b") for b in self.bitmap) 446 | -------------------------------------------------------------------------------- /mysql_mimic/schema.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from collections import defaultdict 5 | from itertools import chain 6 | from dataclasses import dataclass 7 | from typing import Any, Optional, List, Dict, Iterable 8 | 9 | from sqlglot.executor import Table, execute 10 | from sqlglot import expressions as exp 11 | 12 | from mysql_mimic.constants import INFO_SCHEMA 13 | from mysql_mimic.results import AllowedResult 14 | from mysql_mimic.errors import MysqlError, ErrorCode 15 | from mysql_mimic.packets import ComFieldList 16 | from mysql_mimic.utils import dict_depth 17 | 18 | 19 | @dataclass 20 | class Column: 21 | name: str 22 | type: str 23 | table: str 24 | is_nullable: bool = True 25 | default: Optional[str] = None 26 | comment: Optional[str] = None 27 | schema: Optional[str] = None 28 | catalog: Optional[str] = "def" 29 | 30 | 31 | def mapping_to_columns(schema: dict) -> List[Column]: 32 | """Convert a schema mapping into a list of Column instances. 
33 | 34 | Example schema defining columns by type: 35 | { 36 | "customer_table" : { 37 | "first_name" : "TEXT" 38 | "last_name" : "TEXT" 39 | "id" : "INT" 40 | }, 41 | "sales_table" : { 42 | "ds" : "DATE" 43 | "customer_id" : "INT" 44 | "amount" : "DOUBLE" 45 | } 46 | } 47 | """ 48 | depth = dict_depth(schema) 49 | if depth < 2: 50 | return [] 51 | if depth == 2: 52 | # {table: {col: type}} 53 | schema = {"": schema} 54 | depth += 1 55 | if depth == 3: 56 | # {db: {table: {col: type}}} 57 | schema = {"def": schema} # def is the default MySQL catalog 58 | depth += 1 59 | if depth != 4: 60 | raise MysqlError("Invalid schema mapping") 61 | 62 | result = [] 63 | for catalog, dbs in schema.items(): 64 | for db, tables in dbs.items(): 65 | for table, cols in tables.items(): 66 | for column, coltype in cols.items(): 67 | result.append( 68 | Column( 69 | name=column, 70 | type=coltype, 71 | table=table, 72 | schema=db, 73 | catalog=catalog, 74 | ) 75 | ) 76 | 77 | return result 78 | 79 | 80 | def info_schema_tables(columns: Iterable[Column]) -> Dict[str, Dict[str, Table]]: 81 | """ 82 | Convert a list of Column instances into a mapping of SQLGlot Tables. 83 | 84 | These Tables are used by SQLGlot to execute INFORMATION_SCHEMA queries. 
85 | """ 86 | ordinal_positions: dict[Any, int] = defaultdict(lambda: 0) 87 | 88 | data = { 89 | db: {k: Table(tuple(v.keys())) for k, v in tables.items()} 90 | for db, tables in INFO_SCHEMA.items() 91 | } 92 | 93 | tables = set() 94 | dbs = set() 95 | catalogs = set() 96 | 97 | info_schema_cols = mapping_to_columns(INFO_SCHEMA) 98 | 99 | for column in chain(columns, info_schema_cols): 100 | tables.add((column.catalog, column.schema, column.table)) 101 | dbs.add((column.catalog, column.schema)) 102 | catalogs.add(column.catalog) 103 | key = (column.catalog, column.schema, column.table) 104 | ordinal_position = ordinal_positions[key] 105 | ordinal_positions[key] += 1 106 | data["information_schema"]["columns"].append( 107 | ( 108 | column.catalog, # table_catalog 109 | column.schema, # table_schema 110 | column.table, # table_name 111 | column.name, # column_name 112 | ordinal_position, # ordinal_position 113 | column.default, # column_default 114 | "YES" if column.is_nullable else "NO", # is_nullable 115 | column.type, # data_type 116 | None, # character_maximum_length 117 | None, # character_octet_length 118 | None, # numeric_precision 119 | None, # numeric_scale 120 | None, # datetime_precision 121 | "NULL", # character_set_name 122 | "NULL", # collation_name 123 | column.type, # column_type 124 | None, # column_key 125 | None, # extra 126 | None, # privileges 127 | column.comment, # column_comment 128 | None, # generation_expression 129 | None, # srs_id 130 | ) 131 | ) 132 | 133 | for catalog, db, table in sorted(tables): 134 | data["information_schema"]["tables"].append( 135 | ( 136 | catalog, # table_catalog 137 | db, # table_schema 138 | table, # table_name 139 | "SYSTEM TABLE" if db in INFO_SCHEMA else "BASE TABLE", # table_type 140 | "MinervaSQL", # engine 141 | "1.0", # version 142 | None, # row_format 143 | None, # table_rows 144 | None, # avg_row_length 145 | None, # data_length 146 | None, # max_data_length 147 | None, # index_length 148 | None, # 
data_free 149 | None, # auto_increment 150 | None, # create_time 151 | None, # update_time 152 | None, # check_time 153 | "utf8mb4_general_ci ", # table_collation 154 | None, # checksum 155 | None, # create_options 156 | None, # table_comment 157 | ) 158 | ) 159 | 160 | for catalog, db in sorted(dbs): 161 | data["information_schema"]["schemata"].append( 162 | ( 163 | catalog, # catalog_name 164 | db, # schema_name 165 | "utf8mb4", # default_character_set_name 166 | "utf8mb4_general_ci", # default_collation_name 167 | None, # sql_path 168 | ) 169 | ) 170 | 171 | return data 172 | 173 | 174 | def show_statement_to_info_schema_query( 175 | show: exp.Show, database: Optional[str] = None 176 | ) -> exp.Select: 177 | kind = show.name.upper() 178 | if kind == "COLUMNS": 179 | outputs = [ 180 | 'column_name AS "Field"', 181 | 'data_type AS "Type"', 182 | 'is_nullable AS "Null"', 183 | 'column_key AS "Key"', 184 | 'column_default AS "Default"', 185 | 'extra AS "Extra"', 186 | ] 187 | if show.args.get("full"): 188 | outputs.extend( 189 | [ 190 | 'collation_name AS "Collation"', 191 | 'privileges AS "Privileges"', 192 | 'column_comment AS "Comment"', 193 | ] 194 | ) 195 | table = show.text("target") 196 | if not table: 197 | raise MysqlError( 198 | "You have an error in your SQL syntax. 
Table name is missing.", 199 | code=ErrorCode.PARSE_ERROR, 200 | ) 201 | select = ( 202 | exp.select(*outputs) 203 | .from_("information_schema.columns") 204 | .where(f"table_name = '{table}'") 205 | ) 206 | db = show.text("db") or database 207 | if db: 208 | select = select.where(f"table_schema = '{db}'") 209 | like = show.text("like") 210 | if like: 211 | select = select.where(f"column_name LIKE '{like}'") 212 | elif kind == "TABLES": 213 | outputs = ['table_name AS "Table_name"'] 214 | if show.args.get("full"): 215 | outputs.extend(['table_type AS "Table_type"']) 216 | 217 | select = exp.select(*outputs).from_("information_schema.tables") 218 | db = show.text("db") or database 219 | if not db: 220 | raise MysqlError("No database selected.", code=ErrorCode.NO_DB_ERROR) 221 | select = select.where(f"table_schema = '{db}'") 222 | like = show.text("like") 223 | if like: 224 | select = select.where(f"table_name LIKE '{like}'") 225 | elif kind == "DATABASES": 226 | select = exp.select('schema_name AS "Database"').from_( 227 | "information_schema.schemata" 228 | ) 229 | like = show.text("like") 230 | if like: 231 | select = select.where(f"schema_name LIKE '{like}'") 232 | elif kind == "INDEX": 233 | outputs = [ 234 | '"table_name" AS Table', 235 | '"non_unique" AS Non_unique', 236 | '"index_name" AS Key_name', 237 | '"seq_in_index" AS Seq_in_index', 238 | '"column_name" AS Column_name', 239 | '"collation" AS Collation', 240 | '"cardinality" AS Cardinality', 241 | '"sub_part" AS Sub_part', 242 | '"packed" AS Packed', 243 | '"nullable" AS Null', 244 | '"index_type" AS Index_type', 245 | '"comment" AS Comment', 246 | '"index_comment" AS Index_comment', 247 | '"is_visible" AS Visible', 248 | '"expression" AS Expression', 249 | ] 250 | table = show.text("target") 251 | if not table: 252 | raise MysqlError( 253 | "You have an error in your SQL syntax. 
Table name is missing.", 254 | code=ErrorCode.PARSE_ERROR, 255 | ) 256 | select = ( 257 | exp.select(*outputs) 258 | .from_("information_schema.statistics") 259 | .where(f"table_name = '{table}'") 260 | ) 261 | db = show.text("db") or database 262 | if db: 263 | select = select.where(f"table_schema = '{db}'") 264 | else: 265 | raise MysqlError( 266 | f"Unsupported SHOW command: {kind}", code=ErrorCode.NOT_SUPPORTED_YET 267 | ) 268 | 269 | return select 270 | 271 | 272 | def com_field_list_to_show_statement(com_field_list: ComFieldList) -> str: 273 | show = exp.Show( 274 | this="COLUMNS", 275 | target=exp.to_identifier(com_field_list.table), 276 | ) 277 | if com_field_list.wildcard: 278 | show.set("like", exp.Literal.string(com_field_list.wildcard)) 279 | return show.sql(dialect="mysql") 280 | 281 | 282 | def like_to_regex(like: str) -> re.Pattern: 283 | like = like.replace("%", ".*?") 284 | like = like.replace("_", ".") 285 | return re.compile(like) 286 | 287 | 288 | class BaseInfoSchema: 289 | """ 290 | Base InfoSchema interface used by the `Session` class. 291 | """ 292 | 293 | async def query(self, expression: exp.Expression) -> AllowedResult: ... 294 | 295 | 296 | class InfoSchema(BaseInfoSchema): 297 | """ 298 | InfoSchema implementation that uses SQLGlot to execute queries. 
299 | """ 300 | 301 | def __init__(self, tables: Dict[str, Dict[str, Table]]): 302 | self.tables = tables 303 | 304 | async def query(self, expression: exp.Expression) -> AllowedResult: 305 | expression = self._preprocess(expression) 306 | result = execute(expression, schema=INFO_SCHEMA, tables=self.tables) 307 | return result.rows, result.columns 308 | 309 | @classmethod 310 | def from_mapping(cls, mapping: dict) -> InfoSchema: 311 | columns = mapping_to_columns(mapping) 312 | return cls(info_schema_tables(columns)) 313 | 314 | @classmethod 315 | def from_columns(cls, columns: List[Column]) -> InfoSchema: 316 | return cls(info_schema_tables(columns)) 317 | 318 | def _preprocess(self, expression: exp.Expression) -> exp.Expression: 319 | return expression.transform(_remove_collate) 320 | 321 | 322 | def ensure_info_schema(schema: dict | BaseInfoSchema) -> BaseInfoSchema: 323 | if isinstance(schema, BaseInfoSchema): 324 | return schema 325 | return InfoSchema.from_mapping(schema) 326 | 327 | 328 | def _remove_collate(node: exp.Expression) -> exp.Expression: 329 | """ 330 | SQLGlot's executor doesn't support collation specifiers. 331 | 332 | Just remove them for now. 
333 | """ 334 | if isinstance(node, exp.Collate): 335 | return node.left 336 | return node 337 | -------------------------------------------------------------------------------- /mysql_mimic/server.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import inspect 5 | import logging 6 | from ssl import SSLContext 7 | from socket import socket 8 | from typing import Callable, Any, Optional, Sequence, Awaitable 9 | 10 | from mysql_mimic import packets 11 | from mysql_mimic.auth import IdentityProvider, SimpleIdentityProvider 12 | from mysql_mimic.connection import Connection 13 | from mysql_mimic.control import Control, LocalControl, TooManyConnections 14 | from mysql_mimic.errors import ErrorCode 15 | from mysql_mimic.session import Session, BaseSession 16 | from mysql_mimic.constants import DEFAULT_SERVER_CAPABILITIES 17 | from mysql_mimic.stream import MysqlStream 18 | from mysql_mimic.types import Capabilities 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class MaxConnectionsExceeded(Exception): 25 | pass 26 | 27 | 28 | class MysqlServer: 29 | """ 30 | MySQL mimic server. 31 | 32 | Args: 33 | session_factory: Callable that takes no arguments and returns a session 34 | capabilities: server capability flags 35 | control: Control instance to use. Defaults to a LocalControl instance. 36 | identity_provider: Authentication plugins to register. Defaults to `SimpleIdentityProvider`, 37 | which just blindly accepts whatever `username` is given by the client. 
38 | ssl: SSLContext instance if this server should enable TLS over connections 39 | 40 | **kwargs: extra keyword args passed to the asyncio start server command 41 | """ 42 | 43 | def __init__( 44 | self, 45 | session_factory: Callable[[], BaseSession | Awaitable[BaseSession]] = Session, 46 | capabilities: Capabilities = DEFAULT_SERVER_CAPABILITIES, 47 | control: Control | None = None, 48 | identity_provider: IdentityProvider | None = None, 49 | ssl: SSLContext | None = None, 50 | **serve_kwargs: Any, 51 | ): 52 | self.session_factory = session_factory 53 | self.capabilities = capabilities 54 | self.identity_provider = identity_provider or SimpleIdentityProvider() 55 | self.ssl = ssl 56 | 57 | self.control = control or LocalControl() 58 | self._serve_kwargs = serve_kwargs 59 | self._server: Optional[asyncio.base_events.Server] = None 60 | 61 | async def _client_connected_cb( 62 | self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter 63 | ) -> None: 64 | stream = MysqlStream(reader, writer) 65 | 66 | try: 67 | if inspect.iscoroutinefunction(self.session_factory): 68 | session = await self.session_factory() 69 | else: 70 | session = self.session_factory() 71 | 72 | connection = Connection( 73 | stream=stream, 74 | session=session, 75 | control=self.control, 76 | server_capabilities=self.capabilities, 77 | identity_provider=self.identity_provider, 78 | ssl=self.ssl, 79 | ) 80 | 81 | except Exception: # pylint: disable=broad-except 82 | logger.exception("Failed to create connection") 83 | await stream.write( 84 | # Return an error so clients don't freeze 85 | packets.make_error(capabilities=self.capabilities) 86 | ) 87 | return 88 | 89 | try: 90 | connection_id = await self.control.add(connection) 91 | connection.connection_id = connection_id 92 | except TooManyConnections: 93 | await stream.write( 94 | connection.error( 95 | msg="Too many connections", 96 | code=ErrorCode.CON_COUNT_ERROR, 97 | ) 98 | ) 99 | return 100 | except Exception: # pylint: 
disable=broad-except 101 | logger.exception("Failed to register connection") 102 | await stream.write(connection.error(msg="Failed to register connection")) 103 | return 104 | 105 | try: 106 | return await connection.start() 107 | finally: 108 | writer.close() 109 | await self.control.remove(connection_id) 110 | 111 | async def start_server(self, **kwargs: Any) -> None: 112 | """ 113 | Start an asyncio socket server. 114 | 115 | Args: 116 | **kwargs: keyword args passed to `asyncio.start_server` 117 | """ 118 | kw = {} 119 | kw.update(self._serve_kwargs) 120 | kw.update(kwargs) 121 | if "port" not in kw: 122 | kw["port"] = 3306 123 | self._server = await asyncio.start_server(self._client_connected_cb, **kw) 124 | 125 | async def start_unix_server(self, **kwargs: Any) -> None: 126 | """ 127 | Start an asyncio unix socket server. 128 | 129 | Args: 130 | **kwargs: keyword args passed to `asyncio.start_unix_server` 131 | """ 132 | kw = {} 133 | kw.update(self._serve_kwargs) 134 | kw.update(kwargs) 135 | self._server = await asyncio.start_unix_server(self._client_connected_cb, **kw) 136 | 137 | async def serve_forever(self, **kwargs: Any) -> None: 138 | """ 139 | Start accepting connections until the coroutine is cancelled. 140 | 141 | If a server isn't running, this starts a server with `start_server`. 142 | 143 | Args: 144 | **kwargs: keyword args passed to `start_server` 145 | """ 146 | if not self._server: 147 | await self.start_server(**kwargs) 148 | assert self._server is not None 149 | await self._server.serve_forever() 150 | 151 | def close(self) -> None: 152 | """ 153 | Stop serving. 154 | 155 | The server is closed asynchronously - 156 | use the `wait_closed` coroutine to wait until the server is closed. 
157 | """ 158 | if self._server: 159 | self._server.close() 160 | 161 | async def wait_closed(self) -> None: 162 | """Wait until the `close` method completes.""" 163 | if self._server: 164 | await self._server.wait_closed() 165 | 166 | def sockets(self) -> Sequence[socket]: 167 | """Get sockets the server is listening on.""" 168 | if self._server: 169 | return self._server.sockets 170 | return () 171 | -------------------------------------------------------------------------------- /mysql_mimic/session.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from datetime import datetime, timezone as timezone_ 5 | from typing import ( 6 | Dict, 7 | List, 8 | TYPE_CHECKING, 9 | Optional, 10 | Callable, 11 | Awaitable, 12 | Type, 13 | Any, 14 | ) 15 | 16 | from sqlglot import Dialect 17 | from sqlglot.dialects import MySQL 18 | from sqlglot import expressions as exp 19 | from sqlglot.executor import execute 20 | 21 | from mysql_mimic.charset import CharacterSet 22 | from mysql_mimic.errors import ErrorCode, MysqlError 23 | from mysql_mimic.functions import Functions, mysql_datetime_function_mapping 24 | from mysql_mimic.intercept import ( 25 | setitem_kind, 26 | expression_to_value, 27 | TRANSACTION_CHARACTERISTICS, 28 | ) 29 | from mysql_mimic.schema import ( 30 | show_statement_to_info_schema_query, 31 | like_to_regex, 32 | BaseInfoSchema, 33 | ensure_info_schema, 34 | ) 35 | from mysql_mimic.constants import INFO_SCHEMA, KillKind 36 | from mysql_mimic.variable_processor import VariableProcessor 37 | from mysql_mimic.utils import find_dbs 38 | from mysql_mimic.variables import ( 39 | Variables, 40 | SessionVariables, 41 | GlobalVariables, 42 | DEFAULT, 43 | parse_timezone, 44 | ) 45 | from mysql_mimic.results import AllowedResult 46 | 47 | if TYPE_CHECKING: 48 | from mysql_mimic.connection import Connection 49 | 50 | 51 | Middleware = Callable[["Query"], 
Awaitable[AllowedResult]] 52 | 53 | 54 | def mysql_function_mapping(session: Session) -> Functions: 55 | # Information functions. 56 | # These will be replaced in the AST with their corresponding values. 57 | functions = { 58 | **mysql_datetime_function_mapping(session.timestamp), 59 | "CONNECTION_ID": lambda: session.connection.connection_id, 60 | "USER": lambda: session.variables.get("external_user"), 61 | "CURRENT_USER": lambda: session.username, 62 | "VERSION": lambda: session.variables.get("version"), 63 | "DATABASE": lambda: session.database, 64 | } 65 | # Synonyms 66 | functions.update( 67 | { 68 | "SYSTEM_USER": functions["USER"], 69 | "SESSION_USER": functions["USER"], 70 | "SCHEMA": functions["DATABASE"], 71 | } 72 | ) 73 | return functions 74 | 75 | 76 | @dataclass 77 | class Query: 78 | """ 79 | Convenience class that wraps the parameters to middleware. 80 | 81 | Args: 82 | expression: current query expression 83 | sql: the original SQL sent by the client 84 | attrs: query attributes 85 | _middlewares: subsequent middleware functions 86 | _query: the ultimate query method 87 | """ 88 | 89 | expression: exp.Expression 90 | sql: str 91 | attrs: Dict[str, str] 92 | _middlewares: list[Middleware] 93 | _query: Callable[[exp.Expression, str, dict[str, str]], Awaitable[AllowedResult]] 94 | 95 | async def next(self) -> AllowedResult: 96 | """ 97 | Call the next middleware in the chain of middlewares. 98 | 99 | Returns: 100 | The final query result. 101 | """ 102 | if not self._middlewares: 103 | return await self._query(self.expression, self.sql, self.attrs) 104 | q = Query( 105 | expression=self.expression, 106 | sql=self.sql, 107 | attrs=self.attrs, 108 | _middlewares=self._middlewares[1:], 109 | _query=self._query, 110 | ) 111 | return await self._middlewares[0](q) 112 | 113 | async def start(self) -> AllowedResult: 114 | """ 115 | Start the middleware chain. 
116 | 117 | This should only be called by the framework code 118 | """ 119 | return await self._middlewares[0](self) 120 | 121 | 122 | class BaseSession: 123 | """ 124 | Session interface. 125 | 126 | This defines what the Connection object depends on. 127 | 128 | Most applications to implement the abstract `Session`, not this class. 129 | """ 130 | 131 | variables: Variables 132 | username: Optional[str] 133 | database: Optional[str] 134 | 135 | async def handle_query(self, sql: str, attrs: Dict[str, str]) -> AllowedResult: 136 | """ 137 | Main entrypoint for queries. 138 | 139 | Args: 140 | sql: SQL statement 141 | attrs: Mapping of query attributes 142 | Returns: 143 | One of: 144 | - tuple(rows, column_names), where "rows" is a sequence of sequences 145 | and "column_names" is a sequence of strings with the same length 146 | as every row in "rows" 147 | - tuple(rows, result_columns), where "rows" is the same 148 | as above, and "result_columns" is a sequence of mysql_mimic.ResultColumn 149 | instances. 150 | - mysql_mimic.ResultSet instance 151 | """ 152 | 153 | async def init(self, connection: Connection) -> None: 154 | """ 155 | Called when connection phase is complete. 156 | """ 157 | 158 | async def close(self) -> None: 159 | """ 160 | Called when the client closes the connection. 161 | """ 162 | 163 | async def reset(self) -> None: 164 | """ 165 | Called when a client resets the connection and after a COM_CHANGE_USER command. 166 | """ 167 | 168 | async def use(self, database: str) -> None: 169 | """ 170 | Use a new default database. 171 | 172 | Called when a USE database_name command is received. 173 | 174 | Args: 175 | database: database name 176 | """ 177 | 178 | 179 | class Session(BaseSession): 180 | """ 181 | Abstract session. 182 | 183 | This automatically handles lots of behavior that many clients except, 184 | e.g. 
session variables, SHOW commands, queries to INFORMATION_SCHEMA, and more 185 | """ 186 | 187 | dialect: Type[Dialect] = MySQL 188 | 189 | def __init__(self, variables: Variables | None = None): 190 | self.variables = variables or SessionVariables(GlobalVariables()) 191 | 192 | # Query middlewares. 193 | # These allow queries to be intercepted or wrapped. 194 | self.middlewares: list[Middleware] = [ 195 | self._set_var_middleware, 196 | self._set_middleware, 197 | self._static_query_middleware, 198 | self._use_middleware, 199 | self._kill_middleware, 200 | self._show_middleware, 201 | self._describe_middleware, 202 | self._begin_middleware, 203 | self._commit_middleware, 204 | self._rollback_middleware, 205 | self._info_schema_middleware, 206 | ] 207 | 208 | # Current database 209 | self.database = None 210 | 211 | # Current authenticated user 212 | self.username = None 213 | 214 | # Time when query started 215 | self.timestamp: datetime = datetime.now() 216 | 217 | self._connection: Optional[Connection] = None 218 | 219 | async def query( 220 | self, expression: exp.Expression, sql: str, attrs: Dict[str, str] 221 | ) -> AllowedResult: 222 | """ 223 | Process a SQL query. 224 | 225 | Args: 226 | expression: parsed AST of the statement from client 227 | sql: original SQL statement from client 228 | attrs: arbitrary query attributes set by client 229 | Returns: 230 | One of: 231 | - tuple(rows, column_names), where "rows" is a sequence of sequences 232 | and "column_names" is a sequence of strings with the same length 233 | as every row in "rows" 234 | - tuple(rows, result_columns), where "rows" is the same 235 | as above, and "result_columns" is a sequence of mysql_mimic.ResultColumn 236 | instances. 237 | - mysql_mimic.ResultSet instance 238 | """ 239 | return [], [] 240 | 241 | async def schema(self) -> dict | BaseInfoSchema: 242 | """ 243 | Provide the database schema. 244 | 245 | This is used to serve INFORMATION_SCHEMA and SHOW queries. 
246 | 247 | Returns: 248 | One of: 249 | - Mapping of: 250 | {table: {column: column_type}} or 251 | {db: {table: {column: column_type}}} or 252 | {catalog: {db: {table: {column: column_type}}}} 253 | - Instance of `BaseInfoSchema` 254 | """ 255 | return {} 256 | 257 | @property 258 | def connection(self) -> Connection: 259 | """ 260 | Get the connection associated with this session. 261 | """ 262 | if self._connection is None: 263 | raise AttributeError("Session is not yet bound") 264 | return self._connection 265 | 266 | async def init(self, connection: Connection) -> None: 267 | """ 268 | Called when connection phase is complete. 269 | """ 270 | self._connection = connection 271 | 272 | async def close(self) -> None: 273 | """ 274 | Called when the client closes the connection. 275 | """ 276 | self._connection = None 277 | 278 | async def handle_query(self, sql: str, attrs: Dict[str, str]) -> AllowedResult: 279 | self.timestamp = datetime.now(tz=self.timezone()) 280 | result = None 281 | for expression in self._parse(sql): 282 | if not expression: 283 | continue 284 | q = Query( 285 | expression=expression, 286 | sql=sql, 287 | attrs=attrs, 288 | _middlewares=self.middlewares, 289 | _query=self.query, 290 | ) 291 | result = await q.start() 292 | return result 293 | 294 | async def use(self, database: str) -> None: 295 | self.database = database 296 | 297 | def _parse(self, sql: str) -> List[exp.Expression]: 298 | return [e for e in self.dialect().parse(sql) if e] 299 | 300 | async def _query_info_schema(self, expression: exp.Expression) -> AllowedResult: 301 | return await ensure_info_schema(await self.schema()).query(expression) 302 | 303 | async def _set_var_middleware(self, q: Query) -> AllowedResult: 304 | """Handles SET_VAR hints and replaces functions defined in the _functions mapping with their mapped values.""" 305 | with VariableProcessor( 306 | mysql_function_mapping(self), self.variables, q.expression 307 | ).set_variables(): 308 | return await 
q.next() 309 | 310 | async def _use_middleware(self, q: Query) -> AllowedResult: 311 | """Intercept USE statements""" 312 | if isinstance(q.expression, exp.Use): 313 | await self.use(q.expression.this.name) 314 | return [], [] 315 | return await q.next() 316 | 317 | async def _kill_middleware(self, q: Query) -> AllowedResult: 318 | """Intercept KILL statements""" 319 | if isinstance(q.expression, exp.Kill): 320 | control = self.connection.control 321 | if control: 322 | kind = KillKind[q.expression.text("kind").upper() or "CONNECTION"] 323 | this = q.expression.this.name 324 | 325 | try: 326 | connection_id = int(this) 327 | except ValueError as e: 328 | raise MysqlError( 329 | f"Invalid KILL connection ID: {this}", 330 | code=ErrorCode.PARSE_ERROR, 331 | ) from e 332 | 333 | await control.kill(connection_id, kind) 334 | return [], [] 335 | return await q.next() 336 | 337 | async def _show_middleware(self, q: Query) -> AllowedResult: 338 | """Intercept SHOW statements""" 339 | if isinstance(q.expression, exp.Show): 340 | return await self._show(q.expression) 341 | return await q.next() 342 | 343 | async def _show(self, expression: exp.Show) -> AllowedResult: 344 | kind = expression.name.upper() 345 | if kind == "VARIABLES": 346 | return self._show_variables(expression) 347 | if kind == "STATUS": 348 | return self._show_status(expression) 349 | if kind == "WARNINGS": 350 | return self._show_warnings(expression) 351 | if kind == "ERRORS": 352 | return self._show_errors(expression) 353 | select = show_statement_to_info_schema_query(expression, self.database) 354 | return await self._query_info_schema(select) 355 | 356 | async def _describe_middleware(self, q: Query) -> AllowedResult: 357 | """Intercept DESCRIBE statements""" 358 | if isinstance(q.expression, exp.Describe): 359 | if isinstance(q.expression.this, exp.Select): 360 | # Mysql parse treats EXPLAIN SELECT as a DESCRIBE SELECT statement 361 | return await q.next() 362 | name = q.expression.this.name 363 | 
show = self.dialect().parse(f"SHOW COLUMNS FROM {name}")[0] 364 | return await self._show(show) if isinstance(show, exp.Show) else None 365 | return await q.next() 366 | 367 | async def _rollback_middleware(self, q: Query) -> AllowedResult: 368 | """Intercept ROLLBACK statements""" 369 | if isinstance(q.expression, exp.Rollback): 370 | return [], [] 371 | return await q.next() 372 | 373 | async def _commit_middleware(self, q: Query) -> AllowedResult: 374 | """Intercept COMMIT statements""" 375 | if isinstance(q.expression, exp.Commit): 376 | return [], [] 377 | return await q.next() 378 | 379 | async def _begin_middleware(self, q: Query) -> AllowedResult: 380 | """Intercept BEGIN statements""" 381 | if isinstance(q.expression, exp.Transaction): 382 | return [], [] 383 | return await q.next() 384 | 385 | async def _static_query_middleware(self, q: Query) -> AllowedResult: 386 | """ 387 | Handle static queries (e.g. SELECT 1). 388 | 389 | These very common, as many clients execute commands like SELECT DATABASE() when connecting. 
390 | """ 391 | if isinstance(q.expression, exp.Select) and not any( 392 | q.expression.args.get(a) 393 | for a in set(exp.Select.arg_types) - {"expressions", "limit", "hint"} 394 | ): 395 | result = execute(q.expression) 396 | return result.rows, result.columns 397 | return await q.next() 398 | 399 | async def _set_middleware(self, q: Query) -> AllowedResult: 400 | """Intercept SET statements""" 401 | if isinstance(q.expression, exp.Set): 402 | expressions = q.expression.expressions 403 | for item in expressions: 404 | assert isinstance(item, exp.SetItem) 405 | 406 | kind = setitem_kind(item) 407 | 408 | if kind == "VARIABLE": 409 | self._set_variable(item) 410 | elif kind == "CHARACTER SET": 411 | self._set_charset(item) 412 | elif kind == "NAMES": 413 | self._set_names(item) 414 | elif kind == "TRANSACTION": 415 | self._set_transaction(item) 416 | else: 417 | raise MysqlError( 418 | f"Unsupported SET statement: {kind}", 419 | code=ErrorCode.NOT_SUPPORTED_YET, 420 | ) 421 | 422 | return [], [] 423 | return await q.next() 424 | 425 | async def _info_schema_middleware(self, q: Query) -> AllowedResult: 426 | """Intercept queries to INFORMATION_SCHEMA tables""" 427 | dbs = find_dbs(q.expression) 428 | if (self.database and self.database.lower() in INFO_SCHEMA) or ( 429 | dbs and all(db.lower() in INFO_SCHEMA for db in dbs) 430 | ): 431 | return await self._query_info_schema(q.expression) 432 | return await q.next() 433 | 434 | def _set_variable(self, setitem: exp.SetItem) -> None: 435 | assignment = setitem.this 436 | left = assignment.left 437 | 438 | if isinstance(left, exp.SessionParameter): 439 | scope = left.text("kind") or "SESSION" 440 | name = left.name 441 | elif isinstance(left, exp.Parameter): 442 | raise MysqlError( 443 | "User-defined variables not supported yet", 444 | code=ErrorCode.NOT_SUPPORTED_YET, 445 | ) 446 | else: 447 | scope = setitem.text("kind") or "SESSION" 448 | name = left.name 449 | 450 | scope = scope.upper() 451 | value = 
expression_to_value(assignment.right) 452 | 453 | if scope in {"SESSION", "LOCAL"}: 454 | self.variables.set(name, value) 455 | else: 456 | raise MysqlError( 457 | f"Cannot SET variable {name} with scope {scope}", 458 | code=ErrorCode.NOT_SUPPORTED_YET, 459 | ) 460 | 461 | def _set_charset(self, item: exp.SetItem) -> None: 462 | charset_name: Any 463 | charset_conn: Any 464 | if item.name == "DEFAULT": 465 | charset_name = DEFAULT 466 | charset_conn = DEFAULT 467 | else: 468 | charset_name = item.name 469 | charset_conn = self.variables.get("character_set_database") 470 | self.variables.set("character_set_client", charset_name) 471 | self.variables.set("character_set_results", charset_name) 472 | self.variables.set("character_set_connection", charset_conn) 473 | 474 | def _set_names(self, item: exp.SetItem) -> None: 475 | charset_name: Any 476 | collation_name: Any 477 | if item.name == "DEFAULT": 478 | charset_name = DEFAULT 479 | collation_name = DEFAULT 480 | else: 481 | charset_name = item.name 482 | collation_name = ( 483 | item.text("collate") 484 | or CharacterSet[charset_name].default_collation.name 485 | ) 486 | self.variables.set("character_set_client", charset_name) 487 | self.variables.set("character_set_connection", charset_name) 488 | self.variables.set("character_set_results", charset_name) 489 | self.variables.set("collation_connection", collation_name) 490 | 491 | def _set_transaction(self, item: exp.SetItem) -> None: 492 | characteristics = [e.name.upper() for e in item.expressions] 493 | for characteristic in characteristics: 494 | variable, value = TRANSACTION_CHARACTERISTICS[characteristic] 495 | self.variables.set(variable, value) 496 | 497 | def _show_variables(self, show: exp.Show) -> AllowedResult: 498 | rows = [(k, None if v is None else str(v)) for k, v in self.variables.list()] 499 | like = show.text("like") 500 | if like: 501 | rows = [(k, v) for k, v in rows if like_to_regex(like).match(k)] 502 | return rows, ["Variable_name", "Value"] 
503 | 504 | def _show_status(self, show: exp.Show) -> AllowedResult: 505 | return [], ["Variable_name", "Value"] 506 | 507 | def _show_warnings(self, show: exp.Show) -> AllowedResult: 508 | return [], ["Level", "Code", "Message"] 509 | 510 | def _show_errors(self, show: exp.Show) -> AllowedResult: 511 | return [], ["Level", "Code", "Message"] 512 | 513 | def timezone(self) -> timezone_: 514 | tz = self.variables.get("time_zone") 515 | return parse_timezone(tz) 516 | -------------------------------------------------------------------------------- /mysql_mimic/stream.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import struct 3 | from ssl import SSLContext 4 | 5 | from mysql_mimic.errors import MysqlError, ErrorCode 6 | from mysql_mimic.types import uint_3, uint_1 7 | from mysql_mimic.utils import seq 8 | 9 | 10 | class ConnectionClosed(Exception): 11 | pass 12 | 13 | 14 | class MysqlStream: 15 | def __init__( 16 | self, 17 | reader: asyncio.StreamReader, 18 | writer: asyncio.StreamWriter, 19 | buffer_size: int = 2**15, 20 | ): 21 | self.reader = reader 22 | self.writer = writer 23 | self.seq = seq(256) 24 | self._buffer = bytearray() 25 | self._buffer_size = buffer_size 26 | 27 | async def read(self) -> bytes: 28 | data = b"" 29 | while True: 30 | header = await self.reader.read(4) 31 | 32 | if not header: 33 | raise ConnectionClosed() 34 | 35 | i = struct.unpack("> 24 38 | 39 | expected = next(self.seq) 40 | if sequence_id != expected: 41 | raise MysqlError( 42 | f"Expected seq({expected}) got seq({sequence_id})", 43 | ErrorCode.MALFORMED_PACKET, 44 | ) 45 | 46 | if payload_length == 0: 47 | return data 48 | 49 | data += await self.reader.readexactly(payload_length) 50 | 51 | if payload_length < 0xFFFFFF: 52 | return data 53 | 54 | async def write(self, data: bytes, drain: bool = True) -> None: 55 | while True: 56 | # Grab first 0xFFFFFF bytes to send 57 | payload = data[:0xFFFFFF] 58 | data = 
data[0xFFFFFF:] 59 | 60 | payload_length = uint_3(len(payload)) 61 | sequence_id = uint_1(next(self.seq)) 62 | packet = payload_length + sequence_id + payload 63 | 64 | self._buffer.extend(packet) 65 | if drain or len(self._buffer) >= self._buffer_size: 66 | await self.drain() 67 | 68 | # We are done unless len(send) == 0xFFFFFF 69 | if len(payload) != 0xFFFFFF: 70 | return 71 | 72 | async def drain(self) -> None: 73 | if self._buffer: 74 | self.writer.write(self._buffer) 75 | self._buffer.clear() 76 | await self.writer.drain() 77 | 78 | def reset_seq(self) -> None: 79 | self.seq.reset() 80 | 81 | async def start_tls(self, ssl: SSLContext) -> None: 82 | transport = self.writer.transport 83 | protocol = transport.get_protocol() 84 | loop = asyncio.get_event_loop() 85 | new_transport = await loop.start_tls( 86 | transport=transport, 87 | protocol=protocol, 88 | sslcontext=ssl, 89 | server_side=True, 90 | ) 91 | 92 | # This seems to be the easiest way to wrap the socket created by asyncio 93 | self.writer._transport = new_transport # type: ignore # pylint: disable=protected-access 94 | self.reader._transport = new_transport # type: ignore # pylint: disable=protected-access 95 | -------------------------------------------------------------------------------- /mysql_mimic/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import io 3 | import struct 4 | 5 | from enum import IntEnum, IntFlag, auto 6 | 7 | 8 | class ColumnType(IntEnum): 9 | DECIMAL = 0x00 10 | TINY = 0x01 11 | SHORT = 0x02 12 | LONG = 0x03 13 | FLOAT = 0x04 14 | DOUBLE = 0x05 15 | NULL = 0x06 16 | TIMESTAMP = 0x07 17 | LONGLONG = 0x08 18 | INT24 = 0x09 19 | DATE = 0x0A 20 | TIME = 0x0B 21 | DATETIME = 0x0C 22 | YEAR = 0x0D 23 | NEWDATE = 0x0E 24 | VARCHAR = 0x0F 25 | BIT = 0x10 26 | TIMESTAMP2 = 0x11 27 | DATETIME2 = 0x12 28 | TIME2 = 0x13 29 | TYPED_ARRAY = 0x14 30 | INVALID = 0xF3 31 | BOOL = 0xF4 32 | JSON = 0xF5 33 | 
NEWDECIMAL = 0xF6 34 | ENUM = 0xF7 35 | SET = 0xF8 36 | TINY_BLOB = 0xF9 37 | MEDIUM_BLOB = 0xFA 38 | LONG_BLOB = 0xFB 39 | BLOB = 0xFC 40 | VAR_STRING = 0xFD 41 | STRING = 0xFE 42 | GEOMETRY = 0xFF 43 | 44 | 45 | class Commands(IntEnum): 46 | COM_SLEEP = 0x00 47 | COM_QUIT = 0x01 48 | COM_INIT_DB = 0x02 49 | COM_QUERY = 0x03 50 | COM_FIELD_LIST = 0x04 51 | COM_CREATE_DB = 0x05 52 | COM_DROP_DB = 0x06 53 | COM_REFRESH = 0x07 54 | COM_SHUTDOWN = 0x08 55 | COM_STATISTICS = 0x09 56 | COM_PROCESS_INFO = 0x0A 57 | COM_CONNECT = 0x0B 58 | COM_PROCESS_KILL = 0x0C 59 | COM_DEBUG = 0x0D 60 | COM_PING = 0x0E 61 | COM_TIME = 0x0F 62 | COM_DELAYED_INSERT = 0x10 63 | COM_CHANGE_USER = 0x11 64 | COM_BINLOG_DUMP = 0x12 65 | COM_TABLE_DUMP = 0x13 66 | COM_CONNECT_OUT = 0x14 67 | COM_REGISTER_SLAVE = 0x15 68 | COM_STMT_PREPARE = 0x16 69 | COM_STMT_EXECUTE = 0x17 70 | COM_STMT_SEND_LONG_DATA = 0x18 71 | COM_STMT_CLOSE = 0x19 72 | COM_STMT_RESET = 0x1A 73 | COM_SET_OPTION = 0x1B 74 | COM_STMT_FETCH = 0x1C 75 | COM_DAEMON = 0x1D 76 | COM_BINLOG_DUMP_GTID = 0x1E 77 | COM_RESET_CONNECTION = 0x1F 78 | 79 | 80 | class ColumnDefinition(IntFlag): 81 | NOT_NULL_FLAG = auto() 82 | PRI_KEY_FLAG = auto() 83 | UNIQUE_KEY_FLAG = auto() 84 | MULTIPLE_KEY_FLAG = auto() 85 | BLOB_FLAG = auto() 86 | UNSIGNED_FLAG = auto() 87 | ZEROFILL_FLAG = auto() 88 | BINARY_FLAG = auto() 89 | ENUM_FLAG = auto() 90 | AUTO_INCREMENT_FLAG = auto() 91 | TIMESTAMP_FLAG = auto() 92 | SET_FLAG = auto() 93 | NO_DEFAULT_VALUE_FLAG = auto() 94 | ON_UPDATE_NOW_FLAG = auto() 95 | NUM_FLAG = auto() 96 | PART_KEY_FLAG = auto() 97 | GROUP_FLAG = auto() 98 | UNIQUE_FLAG = auto() 99 | BINCMP_FLAG = auto() 100 | GET_FIXED_FIELDS_FLAG = auto() 101 | FIELD_IN_PART_FUNC_FLAG = auto() 102 | FIELD_IN_ADD_INDEX = auto() 103 | FIELD_IS_RENAMED = auto() 104 | FIELD_FLAGS_STORAGE_MEDIA = auto() 105 | FIELD_FLAGS_STORAGE_MEDIA_MASK = auto() 106 | FIELD_FLAGS_COLUMN_FORMAT = auto() 107 | FIELD_FLAGS_COLUMN_FORMAT_MASK = auto() 108 | 
FIELD_IS_DROPPED = auto() 109 | EXPLICIT_NULL_FLAG = auto() 110 | NOT_SECONDARY_FLAG = auto() 111 | FIELD_IS_INVISIBLE = auto() 112 | 113 | 114 | class Capabilities(IntFlag): 115 | CLIENT_LONG_PASSWORD = auto() 116 | CLIENT_FOUND_ROWS = auto() 117 | CLIENT_LONG_FLAG = auto() 118 | CLIENT_CONNECT_WITH_DB = auto() 119 | CLIENT_NO_SCHEMA = auto() 120 | CLIENT_COMPRESS = auto() 121 | CLIENT_ODBC = auto() 122 | CLIENT_LOCAL_FILES = auto() 123 | CLIENT_IGNORE_SPACE = auto() 124 | CLIENT_PROTOCOL_41 = auto() 125 | CLIENT_INTERACTIVE = auto() 126 | CLIENT_SSL = auto() 127 | CLIENT_IGNORE_SIGPIPE = auto() 128 | CLIENT_TRANSACTIONS = auto() 129 | CLIENT_RESERVED = auto() 130 | CLIENT_SECURE_CONNECTION = auto() 131 | CLIENT_MULTI_STATEMENTS = auto() 132 | CLIENT_MULTI_RESULTS = auto() 133 | CLIENT_PS_MULTI_RESULTS = auto() 134 | CLIENT_PLUGIN_AUTH = auto() 135 | CLIENT_CONNECT_ATTRS = auto() 136 | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = auto() 137 | CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = auto() 138 | CLIENT_SESSION_TRACK = auto() 139 | CLIENT_DEPRECATE_EOF = auto() 140 | CLIENT_OPTIONAL_RESULTSET_METADATA = auto() 141 | CLIENT_ZSTD_COMPRESSION_ALGORITHM = auto() 142 | CLIENT_QUERY_ATTRIBUTES = auto() 143 | MULTI_FACTOR_AUTHENTICATION = auto() 144 | CLIENT_CAPABILITY_EXTENSION = auto() 145 | CLIENT_SSL_VERIFY_SERVER_CERT = auto() 146 | CLIENT_REMEMBER_OPTIONS = auto() 147 | 148 | 149 | class ServerStatus(IntFlag): 150 | SERVER_STATUS_IN_TRANS = 0x0001 151 | SERVER_STATUS_AUTOCOMMIT = 0x0002 152 | SERVER_MORE_RESULTS_EXISTS = 0x0008 153 | SERVER_STATUS_NO_GOOD_INDEX_USED = 0x0010 154 | SERVER_STATUS_NO_INDEX_USED = 0x0020 155 | SERVER_STATUS_CURSOR_EXISTS = 0x0040 156 | SERVER_STATUS_LAST_ROW_SENT = 0x0080 157 | SERVER_STATUS_DB_DROPPED = 0x0100 158 | SERVER_STATUS_NO_BACKSLASH_ESCAPES = 0x0200 159 | SERVER_STATUS_METADATA_CHANGED = 0x0400 160 | SERVER_QUERY_WAS_SLOW = 0x0800 161 | SERVER_PS_OUT_PARAMS = 0x1000 162 | SERVER_STATUS_IN_TRANS_READONLY = 0x2000 163 | 
SERVER_SESSION_STATE_CHANGED = 0x4000 164 | 165 | 166 | class ResultsetMetadata(IntEnum): 167 | RESULTSET_METADATA_NONE = 0 168 | RESULTSET_METADATA_FULL = 1 169 | 170 | 171 | class ComStmtExecuteFlags(IntFlag): 172 | CURSOR_TYPE_NO_CURSOR = 0x00 173 | CURSOR_TYPE_READ_ONLY = 0x01 174 | CURSOR_TYPE_FOR_UPDATE = 0x02 175 | CURSOR_TYPE_SCROLLABLE = 0x04 176 | PARAMETER_COUNT_AVAILABLE = 0x08 177 | 178 | 179 | def uint_len(i: int) -> bytes: 180 | if i < 251: 181 | return struct.pack(" bytes: 191 | return struct.pack(" bytes: 195 | return struct.pack(" bytes: 199 | return struct.pack("> 16) 200 | 201 | 202 | def uint_4(i: int) -> bytes: 203 | return struct.pack(" bytes: 207 | return struct.pack("> 32) 208 | 209 | 210 | def uint_8(i: int) -> bytes: 211 | return struct.pack(" bytes: 215 | return struct.pack(f"<{l}s", s) 216 | 217 | 218 | def str_null(s: bytes) -> bytes: 219 | l = len(s) 220 | return struct.pack(f"<{l}sB", s, 0) 221 | 222 | 223 | def str_len(s: bytes) -> bytes: 224 | l = len(s) 225 | return uint_len(l) + str_fixed(l, s) 226 | 227 | 228 | def str_rest(s: bytes) -> bytes: 229 | l = len(s) 230 | return str_fixed(l, s) 231 | 232 | 233 | def read_int_1(reader: io.BytesIO) -> int: 234 | data = reader.read(1) 235 | return struct.unpack(" int: 239 | data = reader.read(1) 240 | return struct.unpack(" int: 244 | data = reader.read(2) 245 | return struct.unpack(" int: 249 | data = reader.read(2) 250 | return struct.unpack(" int: 254 | data = reader.read(3) 255 | t = struct.unpack(" int: 260 | data = reader.read(4) 261 | return struct.unpack(" int: 265 | data = reader.read(4) 266 | return struct.unpack(" int: 270 | data = reader.read(6) 271 | t = struct.unpack(" int: 276 | data = reader.read(8) 277 | return struct.unpack(" int: 281 | data = reader.read(8) 282 | return struct.unpack(" float: 286 | data = reader.read(4) 287 | return struct.unpack(" float: 291 | data = reader.read(8) 292 | return struct.unpack(" int: 296 | i = read_uint_1(reader) 297 | 298 | if i == 
0xFE: 299 | return read_uint_8(reader) 300 | 301 | if i == 0xFD: 302 | return read_uint_3(reader) 303 | 304 | if i == 0xFC: 305 | return read_uint_2(reader) 306 | 307 | return i 308 | 309 | 310 | def read_str_fixed(reader: io.BytesIO, l: int) -> bytes: 311 | return reader.read(l) 312 | 313 | 314 | def read_str_null(reader: io.BytesIO) -> bytes: 315 | data = b"" 316 | while True: 317 | b = reader.read(1) 318 | if b == b"\x00": 319 | return data 320 | data += b 321 | 322 | 323 | def read_str_len(reader: io.BytesIO) -> bytes: 324 | l = read_uint_len(reader) 325 | return read_str_fixed(reader, l) 326 | 327 | 328 | def read_str_rest(reader: io.BytesIO) -> bytes: 329 | return reader.read() 330 | 331 | 332 | def peek(reader: io.BytesIO, num_bytes: int = 1) -> bytes: 333 | pos = reader.tell() 334 | val = reader.read(num_bytes) 335 | reader.seek(pos) 336 | return val 337 | -------------------------------------------------------------------------------- /mysql_mimic/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import inspect 5 | import sys 6 | from collections.abc import Iterator # pylint: disable=import-error 7 | import random 8 | from typing import List, TypeVar, AsyncIterable, Iterable, AsyncIterator, cast 9 | import string 10 | 11 | from sqlglot import expressions as exp 12 | from sqlglot.optimizer.scope import traverse_scope 13 | 14 | 15 | T = TypeVar("T") 16 | 17 | # MySQL Connector/J uses ASCII to decode nonce 18 | SAFE_NONCE_CHARS = (string.ascii_letters + string.digits).encode() 19 | 20 | 21 | class seq(Iterator): 22 | """Auto-incrementing sequence with an optional maximum size""" 23 | 24 | def __init__(self, size: int | None = None): 25 | self.size = size 26 | self.value = 0 27 | 28 | def __next__(self) -> int: 29 | value = self.value 30 | self.value = self.value + 1 31 | if self.size: 32 | self.value = self.value % self.size 33 | return value 34 | 35 | 
def reset(self) -> None: 36 | self.value = 0 37 | 38 | 39 | def xor(a: bytes, b: bytes) -> bytes: 40 | # Fast XOR implementation, according to https://stackoverflow.com/questions/29408173/byte-operations-xor-in-python 41 | a, b = a[: len(b)], b[: len(a)] 42 | int_b = int.from_bytes(b, sys.byteorder) 43 | int_a = int.from_bytes(a, sys.byteorder) 44 | int_enc = int_b ^ int_a 45 | return int_enc.to_bytes(len(b), sys.byteorder) 46 | 47 | 48 | def nonce(nbytes: int) -> bytes: 49 | return bytes( 50 | bytearray( 51 | [random.SystemRandom().choice(SAFE_NONCE_CHARS) for _ in range(nbytes)] 52 | ) 53 | ) 54 | 55 | 56 | def find_tables(expression: exp.Expression) -> List[exp.Table]: 57 | """Find all tables in an expression""" 58 | if isinstance( 59 | expression, (exp.Select, exp.Subquery, exp.Union, exp.Except, exp.Intersect) 60 | ): 61 | return [ 62 | source 63 | for scope in traverse_scope(expression) 64 | for source in scope.sources.values() 65 | if isinstance(source, exp.Table) 66 | ] 67 | return [] 68 | 69 | 70 | def find_dbs(expression: exp.Expression) -> List[str]: 71 | """Find all database names in an expression""" 72 | return [table.text("db") for table in find_tables(expression)] 73 | 74 | 75 | def dict_depth(d: dict) -> int: 76 | """ 77 | Get the nesting depth of a dictionary. 
78 | For example: 79 | >>> dict_depth(None) 80 | 0 81 | >>> dict_depth({}) 82 | 1 83 | >>> dict_depth({"a": "b"}) 84 | 1 85 | >>> dict_depth({"a": {}}) 86 | 2 87 | >>> dict_depth({"a": {"b": {}}}) 88 | 3 89 | Args: 90 | d (dict): dictionary 91 | Returns: 92 | int: depth 93 | """ 94 | try: 95 | return 1 + dict_depth(next(iter(d.values()))) 96 | except AttributeError: 97 | # d doesn't have attribute "values" 98 | return 0 99 | except StopIteration: 100 | # d.values() returns an empty sequence 101 | return 1 102 | 103 | 104 | async def aiterate(iterable: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T]: 105 | """Iterate either an async iterable or a regular iterable""" 106 | if inspect.isasyncgen(iterable): 107 | async for item in iterable: 108 | yield item 109 | else: 110 | for item in cast(Iterable, iterable): 111 | yield item 112 | 113 | 114 | async def cooperative_iterate( 115 | iterable: AsyncIterable[T], batch_size: int = 10_000 116 | ) -> AsyncIterator[T]: 117 | """ 118 | Iterate an async iterable in a cooperative manner, yielding control back to the event loop every `batch_size` iterations 119 | """ 120 | i = 0 121 | async for item in iterable: 122 | if i != 0 and i % batch_size == 0: 123 | await asyncio.sleep(0) 124 | yield item 125 | i += 1 126 | -------------------------------------------------------------------------------- /mysql_mimic/variable_processor.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Dict, Mapping, Generator, MutableMapping, Any 3 | 4 | from sqlglot import expressions as exp 5 | 6 | from mysql_mimic.intercept import value_to_expression, expression_to_value 7 | 8 | variable_constants = { 9 | "CURRENT_USER", 10 | "CURRENT_TIME", 11 | "CURRENT_TIMESTAMP", 12 | "CURRENT_DATE", 13 | } 14 | 15 | 16 | def _get_var_assignments(expression: exp.Expression) -> Dict[str, str]: 17 | """Returns a dictionary of session variables to replace, as 
indicated by SET_VAR hints.""" 18 | hints = expression.find_all(exp.Hint) 19 | if not hints: 20 | return {} 21 | 22 | assignments = {} 23 | 24 | # Iterate in reverse order so higher SET_VAR hints get priority 25 | for hint in reversed(list(hints)): 26 | set_var_hint = None 27 | 28 | for e in hint.expressions: 29 | if isinstance(e, exp.Func) and e.name == "SET_VAR": 30 | set_var_hint = e 31 | for eq in e.expressions: 32 | assignments[eq.left.name] = expression_to_value(eq.right) 33 | 34 | if set_var_hint: 35 | set_var_hint.pop() 36 | 37 | # Remove the hint entirely if SET_VAR was the only expression 38 | if not hint.expressions: 39 | hint.pop() 40 | 41 | return assignments 42 | 43 | 44 | class VariableProcessor: 45 | """ 46 | This class modifies the query in two ways: 47 | 1. Processing SET_VAR hints for system variables in the query text 48 | 2. Replacing functions in the query with their replacements defined in the functions argument. 49 | original values. 50 | """ 51 | 52 | def __init__( 53 | self, functions: Mapping, variables: MutableMapping, expression: exp.Expression 54 | ): 55 | self._functions = functions 56 | self._variables = variables 57 | self._expression = expression 58 | 59 | # Stores the original system variable values. 60 | self._orig: Dict[str, Any] = {} 61 | 62 | @contextmanager 63 | def set_variables(self) -> Generator[exp.Expression, None, None]: 64 | assignments = _get_var_assignments(self._expression) 65 | self._orig = {k: self._variables.get(k) for k in assignments} 66 | for k, v in assignments.items(): 67 | self._variables[k] = v 68 | 69 | self._replace_variables() 70 | 71 | yield self._expression 72 | 73 | for k, v in self._orig.items(): 74 | self._variables[k] = v 75 | 76 | def _replace_variables(self) -> None: 77 | """Replaces certain functions in the query with literals provided from the mapping in _functions, 78 | and session parameters with the values of the session variables. 
79 | """ 80 | if isinstance(self._expression, exp.Set): 81 | for setitem in self._expression.expressions: 82 | if isinstance(setitem.this, exp.Binary): 83 | # In the case of statements like: SET @@foo = @@bar 84 | # We only want to replace variables on the right 85 | setitem.this.set( 86 | "expression", 87 | setitem.this.expression.transform(self._transform, copy=True), 88 | ) 89 | else: 90 | self._expression.transform(self._transform, copy=False) 91 | 92 | def _transform(self, node: exp.Expression) -> exp.Expression: 93 | new_node = None 94 | 95 | if isinstance(node, exp.Func): 96 | if isinstance(node, exp.Anonymous): 97 | func_name = node.name.upper() 98 | else: 99 | func_name = node.sql_name() 100 | func = self._functions.get(func_name) 101 | if func: 102 | value = func() 103 | new_node = value_to_expression(value) 104 | elif isinstance(node, exp.Column) and node.sql() in variable_constants: 105 | value = self._functions[node.sql()]() 106 | new_node = value_to_expression(value) 107 | elif isinstance(node, exp.SessionParameter): 108 | value = self._variables.get(node.name) 109 | new_node = value_to_expression(value) 110 | 111 | if ( 112 | new_node 113 | and isinstance(node.parent, exp.Select) 114 | and node.arg_key == "expressions" 115 | ): 116 | new_node = exp.alias_(new_node, exp.to_identifier(node.sql())) 117 | 118 | return new_node or node 119 | -------------------------------------------------------------------------------- /mysql_mimic/variables.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | import re 5 | from datetime import timezone, timedelta 6 | from functools import lru_cache 7 | from typing import Any, Callable, Tuple, Iterator, MutableMapping 8 | 9 | from mysql_mimic.charset import CharacterSet, Collation 10 | from mysql_mimic.errors import MysqlError, ErrorCode 11 | 12 | 13 | class Default: ... 


VariableType = Callable[[Any], Any]
VariableSchema = Tuple[VariableType, Any, bool]


DEFAULT = Default()

SYSTEM_VARIABLES: dict[str, VariableSchema] = {
    # name: (type, default, dynamic)
    "auto_increment_increment": (int, 1, True),
    "autocommit": (bool, True, True),
    "character_set_client": (str, CharacterSet.utf8mb4.name, True),
    "character_set_connection": (str, CharacterSet.utf8mb4.name, True),
    "character_set_database": (str, CharacterSet.utf8mb4.name, True),
    "character_set_results": (str, CharacterSet.utf8mb4.name, True),
    "character_set_server": (str, CharacterSet.utf8mb4.name, True),
    "collation_connection": (str, Collation.utf8mb4_general_ci.name, True),
    "collation_database": (str, Collation.utf8mb4_general_ci.name, True),
    "collation_server": (str, Collation.utf8mb4_general_ci.name, True),
    "external_user": (str, "", False),
    "init_connect": (str, "", True),
    "interactive_timeout": (int, 28800, True),
    "license": (str, "MIT", False),
    "lower_case_table_names": (int, 0, True),
    "max_allowed_packet": (int, 67108864, True),
    "max_execution_time": (int, 0, True),
    "net_buffer_length": (int, 16384, True),
    "net_write_timeout": (int, 28800, True),
    "performance_schema": (bool, False, False),
    "sql_auto_is_null": (bool, False, True),
    "sql_mode": (str, "ANSI", True),
    "sql_select_limit": (int, None, True),
    "system_time_zone": (str, "UTC", False),
    "time_zone": (str, "UTC", True),
    "transaction_read_only": (bool, False, True),
    "transaction_isolation": (str, "READ-COMMITTED", True),
    "version": (str, "8.0.29", False),
    "version_comment": (str, "mysql-mimic", False),
    "wait_timeout": (int, 28800, True),
    "event_scheduler": (str, "OFF", True),
    "default_storage_engine": (str, "mysql-mimic", True),
    "default_tmp_storage_engine": (str, "mysql-mimic", True),
}

# Textual spellings MySQL accepts for boolean system variables (compared case-insensitively).
_BOOL_STRINGS = {
    "on": True,
    "true": True,
    "1": True,
    "off": False,
    "false": False,
    "0": False,
}


def _parse_bool(name: str, value: str) -> bool:
    """
    Convert a textual boolean (ON/OFF, TRUE/FALSE, 1/0) for variable ``name``.

    Plain ``bool(value)`` is wrong here because every non-empty string,
    including "OFF" and "0", is truthy.

    Raises:
        MysqlError: if ``value`` is not a recognised boolean spelling.
    """
    try:
        return _BOOL_STRINGS[value.strip().lower()]
    except KeyError:
        raise MysqlError(
            f"Variable '{name}' can't be set to the value of '{value}'"
        ) from None


class Variables(abc.ABC, MutableMapping[str, Any]):
    """
    Abstract class for MySQL system variables.

    Subclasses provide ``schema``, a mapping of lower-case variable names to
    ``(type, default, dynamic)`` tuples. Values that were never set explicitly
    fall back to the schema default.
    """

    def __init__(self) -> None:
        # Current variable values
        self._values: dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any | None:
        try:
            return self.get_variable(key)
        except MysqlError as e:
            # Carry the key so KeyError messages identify the missing variable
            raise KeyError(key) from e

    def __setitem__(self, key: str, value: Any) -> None:
        return self.set(key, value)

    def __delitem__(self, key: str) -> None:
        raise MysqlError(f"Cannot delete session variable {key}.")

    def __iter__(self) -> Iterator[str]:
        return self.schema.__iter__()

    def __len__(self) -> int:
        return len(self.schema)

    def get_schema(self, name: str) -> VariableSchema:
        """Return the ``(type, default, dynamic)`` schema for ``name``, or raise MysqlError."""
        schema = self.schema.get(name)
        if not schema:
            raise MysqlError(
                f"Unknown variable: {name}", code=ErrorCode.UNKNOWN_SYSTEM_VARIABLE
            )
        return schema

    def set(self, name: str, value: Any, force: bool = False) -> None:
        """
        Set variable ``name`` (case-insensitive) to ``value``.

        ``DEFAULT`` or ``None`` resets the variable to its schema default.
        Non-dynamic variables can only be changed with ``force=True``.
        """
        name = name.lower()
        type_, default, dynamic = self.get_schema(name)

        if not dynamic and not force:
            raise MysqlError(
                f"Variable is not dynamic: {name}", code=ErrorCode.PARSE_ERROR
            )

        if value is DEFAULT or value is None:
            self._values[name] = default
        elif type_ is bool and isinstance(value, str):
            self._values[name] = _parse_bool(name, value)
        else:
            self._values[name] = type_(value)

    def get_variable(self, name: str) -> Any | None:
        """Return the current value of ``name`` (case-insensitive), or its schema default."""
        name = name.lower()
        if name in self._values:
            return self._values[name]
        _, default, _ = self.get_schema(name)

        return default

    def list(self) -> list[tuple[str, Any]]:
        """Return ``(name, value)`` pairs for every variable in the schema, sorted by name."""
        return [(name, self.get(name)) for name in sorted(self.schema)]

    @property
    @abc.abstractmethod
    def schema(self) -> dict[str, VariableSchema]: ...
123 | 124 | 125 | class GlobalVariables(Variables): 126 | def __init__(self, schema: dict[str, VariableSchema] | None = None): 127 | self._schema = schema or SYSTEM_VARIABLES 128 | super().__init__() 129 | 130 | @property 131 | def schema(self) -> dict[str, VariableSchema]: 132 | return self._schema 133 | 134 | 135 | class SessionVariables(Variables): 136 | def __init__(self, global_variables: Variables): 137 | self.global_variables = global_variables 138 | super().__init__() 139 | 140 | @property 141 | def schema(self) -> dict[str, VariableSchema]: 142 | return self.global_variables.schema 143 | 144 | 145 | RE_TIMEZONE = re.compile(r"^(?P[+-])(?P\d\d):(?P\d\d)") 146 | 147 | 148 | @lru_cache(maxsize=48) 149 | def parse_timezone(tz: str) -> timezone: 150 | if tz.lower() == "utc": 151 | return timezone.utc 152 | match = RE_TIMEZONE.match(tz) 153 | if not match: 154 | raise MysqlError(msg=f"Invalid timezone: {tz}") 155 | offset = timedelta( 156 | hours=int(match.group("hours")), minutes=int(match.group("minutes")) 157 | ) 158 | if match.group("sign") == "-": 159 | offset = offset * -1 160 | return timezone(offset) 161 | -------------------------------------------------------------------------------- /mysql_mimic/version.py: -------------------------------------------------------------------------------- 1 | """mysql-mimic version information""" 2 | 3 | __version__ = "2.6.4" 4 | 5 | 6 | def main(name: str) -> None: 7 | if name == "__main__": 8 | print(__version__) 9 | 10 | 11 | main(__name__) 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] 7 | 8 | #[tool.pytest.ini_options] 9 | #log_cli = true 10 | #log_cli_level = "DEBUG" 11 | 12 | [tool.mypy] 13 | 
ignore_missing_imports = true 14 | disallow_untyped_defs = true 15 | disallow_incomplete_defs = true 16 | exclude = ['build'] 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # Import __version__ 4 | exec(open("mysql_mimic/version.py").read()) 5 | 6 | setup( 7 | name="mysql-mimic", 8 | version=__version__, 9 | description="A python implementation of the mysql server protocol", 10 | long_description=open("README.md").read(), 11 | long_description_content_type="text/markdown", 12 | url="https://github.com/kelsin/mysql-mimic", 13 | author="Christopher Giroir", 14 | author_email="kelsin@valefor.com", 15 | license="MIT", 16 | packages=find_packages(include=["mysql_mimic", "mysql_mimic.*"]), 17 | python_requires=">=3.6", 18 | install_requires=["sqlglot"], 19 | extras_require={ 20 | "dev": [ 21 | "aiomysql", 22 | "mypy", 23 | "mysql-connector-python", 24 | "black", 25 | "coverage", 26 | "freezegun", 27 | "gssapi", 28 | "k5test", 29 | "pylint", 30 | "pytest", 31 | "pytest-asyncio", 32 | "sphinx", 33 | "sqlalchemy", 34 | "twine", 35 | "wheel", 36 | ], 37 | "krb5": ["gssapi"], 38 | }, 39 | classifiers=[ 40 | "Development Status :: 3 - Alpha", 41 | "Intended Audience :: Developers", 42 | "Intended Audience :: Science/Research", 43 | "Operating System :: OS Independent", 44 | "Programming Language :: SQL", 45 | "Programming Language :: Python :: 3 :: Only", 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/barakalon/mysql-mimic/570cd6cb862a52fabbbdf15d801f4855d530c11d/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import asyncio 3 | import functools 4 | from contextvars import Context, copy_context 5 | from ssl import SSLContext 6 | from typing import ( 7 | Optional, 8 | List, 9 | Dict, 10 | Any, 11 | Callable, 12 | TypeVar, 13 | Awaitable, 14 | Sequence, 15 | AsyncGenerator, 16 | Type, 17 | ) 18 | 19 | import aiomysql 20 | import mysql.connector 21 | from mysql.connector.abstracts import MySQLConnectionAbstract 22 | from mysql.connector.connection import ( 23 | MySQLCursorPrepared, 24 | MySQLCursorDict, 25 | ) 26 | from mysql.connector.cursor import MySQLCursor 27 | from sqlglot import expressions as exp 28 | from sqlglot.executor import execute 29 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine 30 | import pytest 31 | import pytest_asyncio 32 | 33 | from mysql_mimic import MysqlServer, Session 34 | from mysql_mimic.auth import ( 35 | User, 36 | AuthPlugin, 37 | IdentityProvider, 38 | ) 39 | from mysql_mimic.connection import Connection 40 | from mysql_mimic.results import AllowedResult 41 | from mysql_mimic.schema import InfoSchema 42 | 43 | 44 | class PreparedDictCursor(MySQLCursorPrepared): 45 | def fetchall(self) -> Any: 46 | rows = super().fetchall() 47 | 48 | if rows is not None: 49 | return [dict(zip(self.column_names, row)) for row in rows] 50 | 51 | return None 52 | 53 | 54 | class MockSession(Session): 55 | def __init__(self) -> None: 56 | super().__init__() 57 | self.ctx: Context | None = None 58 | self.return_value: Any = None 59 | self.execute = False 60 | self.echo = False 61 | self.last_query_attrs: Optional[Dict[str, str]] = None 62 | self.users: Optional[Dict[str, User]] = None 63 | 64 | # tests can 
set an Event that the session will use to wait before continuing 65 | self.pause: Optional[asyncio.Event] = None 66 | # the session will set this event while waiting on `pause` 67 | self.waiting = asyncio.Event() 68 | 69 | async def init(self, connection: Connection) -> None: 70 | await super().init(connection) 71 | self.ctx = copy_context() 72 | 73 | async def query( 74 | self, expression: exp.Expression, sql: str, attrs: Dict[str, str] 75 | ) -> AllowedResult: 76 | if self.pause: 77 | self.waiting.set() 78 | await self.pause.wait() 79 | self.waiting.clear() 80 | 81 | self.last_query_attrs = attrs 82 | if self.echo: 83 | return [(sql,)], ["sql"] 84 | if self.execute: 85 | result = execute(expression) 86 | return result.rows, result.columns 87 | return self.return_value 88 | 89 | async def schema(self) -> dict | InfoSchema: 90 | return { 91 | "db": { 92 | "x": { 93 | "a": "TEXT", 94 | "b": "TEXT", 95 | }, 96 | "y": { 97 | "b": "TEXT", 98 | "c": "TEXT", 99 | }, 100 | } 101 | } 102 | 103 | 104 | class MockIdentityProvider(IdentityProvider): 105 | def __init__(self, auth_plugins: List[AuthPlugin], users: Dict[str, User]): 106 | self.auth_plugins = auth_plugins 107 | self.users = users 108 | 109 | def get_plugins(self) -> Sequence[AuthPlugin]: 110 | return self.auth_plugins 111 | 112 | async def get_user(self, username: str) -> Optional[User]: 113 | return self.users.get(username) 114 | 115 | 116 | T = TypeVar("T") 117 | 118 | 119 | async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: 120 | loop = asyncio.get_running_loop() 121 | func_call = functools.partial(func, *args, **kwargs) 122 | return await loop.run_in_executor(None, func_call) 123 | 124 | 125 | @pytest_asyncio.fixture 126 | async def session() -> MockSession: 127 | return MockSession() 128 | 129 | 130 | @pytest.fixture 131 | def session_factory(session: MockSession) -> Callable[[], MockSession]: 132 | return lambda: session 133 | 134 | 135 | @pytest.fixture 136 | def auth_plugins() -> 
Optional[List[AuthPlugin]]: 137 | return None 138 | 139 | 140 | @pytest.fixture 141 | def users() -> Dict[str, User]: 142 | return {} 143 | 144 | 145 | @pytest.fixture 146 | def identity_provider( 147 | auth_plugins: Optional[List[AuthPlugin]], users: Dict[str, User] 148 | ) -> Optional[MockIdentityProvider]: 149 | if auth_plugins: 150 | return MockIdentityProvider(auth_plugins, users) 151 | return None 152 | 153 | 154 | @pytest.fixture 155 | def ssl() -> Optional[SSLContext]: 156 | return None 157 | 158 | 159 | @pytest_asyncio.fixture 160 | async def server( 161 | session_factory: Callable[[], MockSession], 162 | identity_provider: Optional[MockIdentityProvider], 163 | ssl: Optional[SSLContext], 164 | ) -> AsyncGenerator[MysqlServer, None]: 165 | srv = MysqlServer( 166 | session_factory=session_factory, 167 | identity_provider=identity_provider, 168 | ssl=ssl, 169 | ) 170 | try: 171 | await srv.start_server(port=3307) 172 | except OSError as e: 173 | if e.errno == 48: 174 | # Port already in use. 175 | # This should only happen if there is a race condition between concurrent test runs. 176 | # Getting a new free port can be slow, so we optimistically try to use the free_port fixture first. 
177 | await srv.start_server(port=0) 178 | else: 179 | raise 180 | asyncio.create_task(srv.serve_forever()) 181 | try: 182 | yield srv 183 | finally: 184 | srv.close() 185 | await srv.wait_closed() 186 | 187 | 188 | @pytest.fixture 189 | def port(server: MysqlServer) -> int: 190 | return server.sockets()[0].getsockname()[1] 191 | 192 | 193 | ConnectFixture = Callable[..., Awaitable[MySQLConnectionAbstract]] 194 | 195 | 196 | @pytest.fixture 197 | def connect(port: int) -> ConnectFixture: 198 | async def conn(**kwargs: Any) -> MySQLConnectionAbstract: 199 | return await to_thread( 200 | mysql.connector.connect, use_pure=True, port=port, **kwargs # type: ignore 201 | ) 202 | 203 | return conn 204 | 205 | 206 | @pytest_asyncio.fixture 207 | async def mysql_connector_conn( 208 | connect: ConnectFixture, 209 | ) -> AsyncGenerator[MySQLConnectionAbstract, None]: 210 | conn = await connect(user="levon_helm") 211 | # Different versions of mysql-connector use various default collations. 212 | # Execute the following query to ensure consistent test results across different versions. 
213 | await query(conn, "SET NAMES 'utf8mb4' COLLATE 'utf8mb4_general_ci'") 214 | try: 215 | yield conn 216 | finally: 217 | conn.close() 218 | 219 | 220 | @pytest_asyncio.fixture 221 | async def aiomysql_conn(port: int) -> aiomysql.Connection: 222 | async with aiomysql.connect(port=port, user="levon_helm") as conn: 223 | yield conn 224 | 225 | 226 | @pytest_asyncio.fixture 227 | async def sqlalchemy_engine( 228 | port: int, 229 | ) -> AsyncGenerator[AsyncEngine, None]: 230 | engine = create_async_engine(url=f"mysql+aiomysql://levon_helm@127.0.0.1:{port}") 231 | try: 232 | yield engine 233 | finally: 234 | await engine.dispose() 235 | 236 | 237 | async def query( 238 | conn: MySQLConnectionAbstract, 239 | sql: str, 240 | cursor_class: Type[MySQLCursor] = MySQLCursorDict, 241 | params: Sequence[Any] | None = None, 242 | query_attributes: Dict[str, str] | None = None, 243 | ) -> Sequence[Any]: 244 | cursor = await to_thread(conn.cursor, cursor_class=cursor_class) 245 | if query_attributes: 246 | for key, value in query_attributes.items(): 247 | cursor.add_attribute(key, value) 248 | await to_thread(cursor.execute, sql, *(p for p in [params] if p)) 249 | result = await to_thread(cursor.fetchall) 250 | await to_thread(cursor.close) 251 | return result 252 | -------------------------------------------------------------------------------- /tests/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barakalon/mysql-mimic/570cd6cb862a52fabbbdf15d801f4855d530c11d/tests/fixtures/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/certificate.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEIzCCAwugAwIBAgIUUMROPFb3ZMssWfl9/tJVMB0Zl0owDQYJKoZIhvcNAQEL 3 | BQAwgZ8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQH 4 | 
DA1TYW4gRnJhbmNpc2NvMRQwEgYDVQQKDAtteXNxbC1taW1pYzEUMBIGA1UECwwL 5 | bXlzcWwtbWltaWMxFDASBgNVBAMMC215c3FsLW1pbWljMSEwHwYJKoZIhvcNAQkB 6 | FhJleGFtcGxlQGRvbWFpbi5jb20wIBcNMjIxMDI0MTgyNzQ0WhgPMjEyMjA5MzAx 7 | ODI3NDRaMIGfMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQG 8 | A1UEBwwNU2FuIEZyYW5jaXNjbzEUMBIGA1UECgwLbXlzcWwtbWltaWMxFDASBgNV 9 | BAsMC215c3FsLW1pbWljMRQwEgYDVQQDDAtteXNxbC1taW1pYzEhMB8GCSqGSIb3 10 | DQEJARYSZXhhbXBsZUBkb21haW4uY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A 11 | MIIBCgKCAQEAsnmS0k1Q0f+mv7izbT5K8zqYP54KwyW000qMhYLW/Ir5fGAw2QHg 12 | Q1vMBY+1oaHcO1DxjriG4l8P8ho9AShAD63Q3PVUiy3Prxi3PaZ/jPsI5rN/8s7s 13 | TXai6Po6gD56uYtZACl4cjF2ob3Vy/qzIPitW3D7UVEL+nqDEZUSmbFnT+NAMaCv 14 | Bq+Zf94vQSpQXgUSbTNAuFqjwMLeb8VX31e5yFOvhRd1Y65MOwmeCX0hMZ+XtcPc 15 | xtgCyQ9uT2OaKcqcngE/LtGnC4UJy0u5bcI8pPgwenGWNhQ6LrqLWAUizEXoprY8 16 | PUVoE42Itm7j1ODKm71OpNPqEjXxvs924wIDAQABo1MwUTAdBgNVHQ4EFgQUSVIG 17 | PQ1FTdvJeup85cCLpp/5wGIwHwYDVR0jBBgwFoAUSVIGPQ1FTdvJeup85cCLpp/5 18 | wGIwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAJcoALjwItgWa 19 | Qpr/hJCHhtf1gS6DCbtQoc+Owvz0iJkC+zqeaFlOYwVQz/n8iQ45g03cLMxHfMEv 20 | lj1ctHfVGI+2exp2hesMmU+CT49D9qFS2vQnaKxkeUqv5ia2d+V49jE6fj3CxIPP 21 | T3yI9+K2Ojx+7/WVVxCgz+eIZJbTCugoypDksldfuY4mx48E9qefOlBNcYyBzHt4 22 | 4iymCFFrfQHfMX+PnKDGHaoh/T/oZxdZRyDxNqTLfiyZ1PtAbueeT4Jf6CmamQZh 23 | IadFPXnA1YlEqreMk6nDrF8pqgOosHIngLhhUXHAVj/Br3UaDTaUGzlrlL7rLpRL 24 | oToCkplUgg== 25 | -----END CERTIFICATE----- 26 | -------------------------------------------------------------------------------- /tests/fixtures/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCyeZLSTVDR/6a/ 3 | uLNtPkrzOpg/ngrDJbTTSoyFgtb8ivl8YDDZAeBDW8wFj7Whodw7UPGOuIbiXw/y 4 | Gj0BKEAPrdDc9VSLLc+vGLc9pn+M+wjms3/yzuxNdqLo+jqAPnq5i1kAKXhyMXah 5 | vdXL+rMg+K1bcPtRUQv6eoMRlRKZsWdP40AxoK8Gr5l/3i9BKlBeBRJtM0C4WqPA 6 | 
wt5vxVffV7nIU6+FF3Vjrkw7CZ4JfSExn5e1w9zG2ALJD25PY5opypyeAT8u0acL 7 | hQnLS7ltwjyk+DB6cZY2FDouuotYBSLMReimtjw9RWgTjYi2buPU4MqbvU6k0+oS 8 | NfG+z3bjAgMBAAECggEBAI1wPUO+k/soUByGImPDxyAE4p0gAUVv/2KnJL+11exj 9 | sp23mV6Q1wpqmEAcCIQkQuUbG6PQZszFK1zhIFFndYU3aVuCbNKzpnAL9UO9TD4M 10 | v5wcypxBEhG9oBNkIrJ5UUbzwL+ZHePZgTtitykk74qEqNXbrr9drFF/f5mSeyAi 11 | nUkizI/aGJVp2+NVpsFvZd0M5gM1tEWoBq/5fny/b1oWM/blh1ZwWQII3bTcDABH 12 | 1h30oTuGWYyeQExUjiS4SDDJ8BH9PZxUGQhQsK1ZuJczJ5B+jh0OWicdD+P+V2EB 13 | npoWQVHKwqwfDQisx2T04TAb65QcbM7V/EwgQPpcd5ECgYEA3bQbff5bIvbgJwCo 14 | 1bHxj3MMmu1Qs+Km9nR8FbEtFaF1lbXzPLCj98VDpT3tBNw4VwpbVYn5rd4ZSs6h 15 | jlK2zEakFRvU/qgtTrUviW4oeSrEtUt8jI0QkdobUWPM8HuZih6Lgm5DTU6HwIBR 16 | dxkhojx7c1zwSm66Bojvn/CjVakCgYEAzhWHAxkNAj97XvjUe9XsV41jiksheNS2 17 | 7vAm1PASq9MGFpSazq2NeXTxuChvMK8vN9r2FngnFKW4hLGMi8wpReT+S1C32sNP 18 | ZWRq6w2JWgQvh7dNSNmn0mBc7zwUB6sFeS72q3Ge355dPJiuomRLxZmm+P3suHIV 19 | b+4Fj6q0p6sCgYEAqj5EojJwn1+97pVGEJqM6N+qvUkgoJGaLkRyiGG+Qg7y8RyA 20 | BImLz5ZuBHSSDhphNQ1h50SFMusKtvQHAPgpIKHaG898dnSEHh1pvHmXoLujw6eM 21 | o40rPSSjt5MQa1YuJ+6eqHCtQ67a9YpThEYLGr6g+YxThISUWrJKd6HceskCgYB0 22 | gWMUc0MRdEYQyOeHIsc8L+iINDU2FDtfFVE+rIJBtUkJ1vU1xpPmiCBnFiTWBxPQ 23 | pe7dgQvG9nE8QwvLtJ3Yr767YWSvPh9SmNSBEeQGibs9JHmCp9niayve673/H8Y2 24 | XkCBZ/iDPwpCyaZglAbqLRViSltbYtOPtaZbNAxxhQKBgH8Gj0yg3A5DwWEyeUKR 25 | KJBS2rPiKgrhJs8Kd8GZZUb3H5WGqzfgrRK1p9j5Ug8UXSWbB9e5jg4ymVtcAAhc 26 | tRz4rCvYNY8fHlA2TfzrOeEuuxtoMFyxd26eFjZvS/w2VdFrIQZbSvDkV4hMqhpS 27 | CCslSFIIZOMCzqgQtM+NTQJ6 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /tests/fixtures/queries.py: -------------------------------------------------------------------------------- 1 | DATA_GRIP_PARAMETERS = """ 2 | SELECT 3 | SPECIFIC_SCHEMA AS PROCEDURE_CAT, 4 | NULL AS `PROCEDURE_SCHEM`, 5 | SPECIFIC_NAME AS `PROCEDURE_NAME`, 6 | IFNULL(PARAMETER_NAME, '') AS `COLUMN_NAME`, 7 | CASE 8 | WHEN PARAMETER_MODE = 'IN' 9 | THEN 1 10 | WHEN PARAMETER_MODE = 
'OUT' 11 | THEN 4 12 | WHEN PARAMETER_MODE = 'INOUT' 13 | THEN 2 14 | WHEN ORDINAL_POSITION = 0 15 | THEN 5 16 | ELSE 0 17 | END AS `COLUMN_TYPE`, 18 | CASE 19 | WHEN UPPER(DATA_TYPE) = 'DECIMAL' 20 | THEN 3 21 | WHEN UPPER(DATA_TYPE) = 'DECIMAL UNSIGNED' 22 | THEN 3 23 | WHEN UPPER(DATA_TYPE) = 'TINYINT' 24 | THEN -6 25 | WHEN UPPER(DATA_TYPE) = 'TINYINT UNSIGNED' 26 | THEN -6 27 | WHEN UPPER(DATA_TYPE) = 'BOOLEAN' 28 | THEN 16 29 | WHEN UPPER(DATA_TYPE) = 'SMALLINT' 30 | THEN 5 31 | WHEN UPPER(DATA_TYPE) = 'SMALLINT UNSIGNED' 32 | THEN 5 33 | WHEN UPPER(DATA_TYPE) = 'INT' 34 | THEN 4 35 | WHEN UPPER(DATA_TYPE) = 'INT UNSIGNED' 36 | THEN 4 37 | WHEN UPPER(DATA_TYPE) = 'FLOAT' 38 | THEN 7 39 | WHEN UPPER(DATA_TYPE) = 'FLOAT UNSIGNED' 40 | THEN 7 41 | WHEN UPPER(DATA_TYPE) = 'DOUBLE' 42 | THEN 8 43 | WHEN UPPER(DATA_TYPE) = 'DOUBLE UNSIGNED' 44 | THEN 8 45 | WHEN UPPER(DATA_TYPE) = 'NULL' 46 | THEN 0 47 | WHEN UPPER(DATA_TYPE) = 'TIMESTAMP' 48 | THEN 93 49 | WHEN UPPER(DATA_TYPE) = 'BIGINT' 50 | THEN -5 51 | WHEN UPPER(DATA_TYPE) = 'BIGINT UNSIGNED' 52 | THEN -5 53 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT' 54 | THEN 4 55 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT UNSIGNED' 56 | THEN 4 57 | WHEN UPPER(DATA_TYPE) = 'DATE' 58 | THEN 91 59 | WHEN UPPER(DATA_TYPE) = 'TIME' 60 | THEN 92 61 | WHEN UPPER(DATA_TYPE) = 'DATETIME' 62 | THEN 93 63 | WHEN UPPER(DATA_TYPE) = 'YEAR' 64 | THEN 91 65 | WHEN UPPER(DATA_TYPE) = 'VARCHAR' 66 | THEN 12 67 | WHEN UPPER(DATA_TYPE) = 'VARBINARY' 68 | THEN -3 69 | WHEN UPPER(DATA_TYPE) = 'BIT' 70 | THEN -7 71 | WHEN UPPER(DATA_TYPE) = 'JSON' 72 | THEN -1 73 | WHEN UPPER(DATA_TYPE) = 'ENUM' 74 | THEN 1 75 | WHEN UPPER(DATA_TYPE) = 'SET' 76 | THEN 1 77 | WHEN UPPER(DATA_TYPE) = 'TINYBLOB' 78 | THEN -3 79 | WHEN UPPER(DATA_TYPE) = 'TINYTEXT' 80 | THEN 12 81 | WHEN UPPER(DATA_TYPE) = 'MEDIUMBLOB' 82 | THEN -4 83 | WHEN UPPER(DATA_TYPE) = 'MEDIUMTEXT' 84 | THEN -1 85 | WHEN UPPER(DATA_TYPE) = 'LONGBLOB' 86 | THEN -4 87 | WHEN UPPER(DATA_TYPE) = 'LONGTEXT' 
88 | THEN -1 89 | WHEN UPPER(DATA_TYPE) = 'BLOB' 90 | THEN -4 91 | WHEN UPPER(DATA_TYPE) = 'TEXT' 92 | THEN -1 93 | WHEN UPPER(DATA_TYPE) = 'CHAR' 94 | THEN 1 95 | WHEN UPPER(DATA_TYPE) = 'BINARY' 96 | THEN -2 97 | WHEN UPPER(DATA_TYPE) = 'GEOMETRY' 98 | THEN -2 99 | WHEN UPPER(DATA_TYPE) = 'UNKNOWN' 100 | THEN 1111 101 | WHEN UPPER(DATA_TYPE) = 'POINT' 102 | THEN -2 103 | WHEN UPPER(DATA_TYPE) = 'LINESTRING' 104 | THEN -2 105 | WHEN UPPER(DATA_TYPE) = 'POLYGON' 106 | THEN -2 107 | WHEN UPPER(DATA_TYPE) = 'MULTIPOINT' 108 | THEN -2 109 | WHEN UPPER(DATA_TYPE) = 'MULTILINESTRING' 110 | THEN -2 111 | WHEN UPPER(DATA_TYPE) = 'MULTIPOLYGON' 112 | THEN -2 113 | WHEN UPPER(DATA_TYPE) = 'GEOMETRYCOLLECTION' 114 | THEN -2 115 | WHEN UPPER(DATA_TYPE) = 'GEOMCOLLECTION' 116 | THEN -2 117 | ELSE 1111 118 | END AS `DATA_TYPE`, 119 | UPPER( 120 | CASE 121 | WHEN LOCATE('UNSIGNED', UPPER(DTD_IDENTIFIER)) <> 0 122 | AND LOCATE('UNSIGNED', UPPER(DATA_TYPE)) = 0 123 | AND LOCATE('SET', UPPER(DATA_TYPE)) <> 1 124 | AND LOCATE('ENUM', UPPER(DATA_TYPE)) <> 1 125 | THEN CONCAT(DATA_TYPE, ' UNSIGNED') 126 | WHEN UPPER(DATA_TYPE) = 'POINT' 127 | THEN 'GEOMETRY' 128 | WHEN UPPER(DATA_TYPE) = 'LINESTRING' 129 | THEN 'GEOMETRY' 130 | WHEN UPPER(DATA_TYPE) = 'POLYGON' 131 | THEN 'GEOMETRY' 132 | WHEN UPPER(DATA_TYPE) = 'MULTIPOINT' 133 | THEN 'GEOMETRY' 134 | WHEN UPPER(DATA_TYPE) = 'MULTILINESTRING' 135 | THEN 'GEOMETRY' 136 | WHEN UPPER(DATA_TYPE) = 'MULTIPOLYGON' 137 | THEN 'GEOMETRY' 138 | WHEN UPPER(DATA_TYPE) = 'GEOMETRYCOLLECTION' 139 | THEN 'GEOMETRY' 140 | WHEN UPPER(DATA_TYPE) = 'GEOMCOLLECTION' 141 | THEN 'GEOMETRY' 142 | ELSE UPPER(DATA_TYPE) 143 | END 144 | ) AS TYPE_NAME, 145 | CASE 146 | WHEN LCASE(DATA_TYPE) = 'date' 147 | THEN 0 148 | WHEN LCASE(DATA_TYPE) = 'time' 149 | OR LCASE(DATA_TYPE) = 'datetime' 150 | OR LCASE(DATA_TYPE) = 'timestamp' 151 | THEN DATETIME_PRECISION 152 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT' AND LOCATE('UNSIGNED', UPPER(DTD_IDENTIFIER)) <> 0 153 | THEN 
8 154 | WHEN UPPER(DATA_TYPE) = 'JSON' 155 | THEN 1073741824 156 | ELSE NUMERIC_PRECISION 157 | END AS `PRECISION`, 158 | CASE 159 | WHEN LCASE(DATA_TYPE) = 'date' 160 | THEN 10 161 | WHEN LCASE(DATA_TYPE) = 'time' 162 | THEN 8 + ( 163 | CASE 164 | WHEN DATETIME_PRECISION > 0 165 | THEN DATETIME_PRECISION + 1 166 | ELSE DATETIME_PRECISION 167 | END 168 | ) 169 | WHEN LCASE(DATA_TYPE) = 'datetime' OR LCASE(DATA_TYPE) = 'timestamp' 170 | THEN 19 + ( 171 | CASE 172 | WHEN DATETIME_PRECISION > 0 173 | THEN DATETIME_PRECISION + 1 174 | ELSE DATETIME_PRECISION 175 | END 176 | ) 177 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT' AND LOCATE('UNSIGNED', UPPER(DTD_IDENTIFIER)) <> 0 178 | THEN 8 179 | WHEN UPPER(DATA_TYPE) = 'JSON' 180 | THEN 1073741824 181 | WHEN CHARACTER_MAXIMUM_LENGTH IS NULL 182 | THEN NUMERIC_PRECISION 183 | WHEN CHARACTER_MAXIMUM_LENGTH > 2147483647 184 | THEN 2147483647 185 | ELSE CHARACTER_MAXIMUM_LENGTH 186 | END AS LENGTH, 187 | NUMERIC_SCALE AS `SCALE`, 188 | 10 AS RADIX, 189 | 1 AS `NULLABLE`, 190 | NULL AS `REMARKS`, 191 | NULL AS `COLUMN_DEF`, 192 | NULL AS `SQL_DATA_TYPE`, 193 | NULL AS `SQL_DATETIME_SUB`, 194 | CASE 195 | WHEN CHARACTER_OCTET_LENGTH > 2147483647 196 | THEN 2147483647 197 | ELSE CHARACTER_OCTET_LENGTH 198 | END AS `CHAR_OCTET_LENGTH`, 199 | ORDINAL_POSITION, 200 | 'YES' AS `IS_NULLABLE`, 201 | SPECIFIC_NAME 202 | FROM INFORMATION_SCHEMA.PARAMETERS 203 | WHERE 204 | SPECIFIC_SCHEMA = 'magic' AND SPECIFIC_NAME LIKE '%' 205 | ORDER BY 206 | SPECIFIC_SCHEMA, 207 | SPECIFIC_NAME, 208 | ROUTINE_TYPE, 209 | ORDINAL_POSITION 210 | """ 211 | 212 | DATA_GRIP_VARIABLES = """ 213 | /* mysql-connector-java-8.0.25 (Revision: 08be9e9b4cba6aa115f9b27b215887af40b159e0) */SELECT 214 | @@session.auto_increment_increment AS auto_increment_increment, 215 | @@character_set_client AS character_set_client, 216 | @@character_set_connection AS character_set_connection, 217 | @@character_set_results AS character_set_results, 218 | @@character_set_server AS 
character_set_server, 219 | @@collation_server AS collation_server, 220 | @@collation_connection AS collation_connection, 221 | @@init_connect AS init_connect, 222 | @@interactive_timeout AS interactive_timeout, 223 | @@license AS license, 224 | @@lower_case_table_names AS lower_case_table_names, 225 | @@max_allowed_packet AS max_allowed_packet, 226 | @@net_write_timeout AS net_write_timeout, 227 | @@performance_schema AS performance_schema, 228 | @@sql_mode AS sql_mode, 229 | @@system_time_zone AS system_time_zone, 230 | @@time_zone AS time_zone, 231 | @@transaction_isolation AS transaction_isolation, 232 | @@wait_timeout AS wait_timeout 233 | """ 234 | 235 | DATA_GRIP_TABLES = """ 236 | /* ApplicationName=DataGrip 2022.2.5 */ 237 | SELECT 238 | TABLE_SCHEMA AS TABLE_CAT, 239 | NULL AS TABLE_SCHEM, 240 | TABLE_NAME, 241 | CASE 242 | WHEN TABLE_TYPE = 'BASE TABLE' 243 | THEN CASE 244 | WHEN TABLE_SCHEMA = 'mysql' OR TABLE_SCHEMA = 'performance_schema' 245 | THEN 'SYSTEM TABLE' 246 | ELSE 'TABLE' 247 | END 248 | WHEN TABLE_TYPE = 'TEMPORARY' 249 | THEN 'LOCAL_TEMPORARY' 250 | ELSE TABLE_TYPE 251 | END AS TABLE_TYPE, 252 | TABLE_COMMENT AS REMARKS, 253 | NULL AS TYPE_CAT, 254 | NULL AS TYPE_SCHEM, 255 | NULL AS TYPE_NAME, 256 | NULL AS SELF_REFERENCING_COL_NAME, 257 | NULL AS REF_GENERATION 258 | FROM INFORMATION_SCHEMA.TABLES 259 | WHERE 260 | TABLE_SCHEMA = 'db' AND TABLE_NAME LIKE '%' 261 | HAVING 262 | TABLE_TYPE IN ('LOCAL TEMPORARY', 'SYSTEM TABLE', 'SYSTEM VIEW', 'TABLE', 'VIEW') 263 | ORDER BY 264 | TABLE_TYPE, 265 | TABLE_SCHEMA, 266 | TABLE_NAME 267 | """ 268 | 269 | TABLEAU_INDEXES_1 = """ 270 | SELECT 271 | A.REFERENCED_TABLE_SCHEMA AS PKTABLE_CAT, 272 | NULL AS PKTABLE_SCHEM, 273 | A.REFERENCED_TABLE_NAME AS PKTABLE_NAME, 274 | A.REFERENCED_COLUMN_NAME AS PKCOLUMN_NAME, 275 | A.TABLE_SCHEMA AS FKTABLE_CAT, 276 | NULL AS FKTABLE_SCHEM, 277 | A.TABLE_NAME AS FKTABLE_NAME, 278 | A.COLUMN_NAME AS FKCOLUMN_NAME, 279 | A.ORDINAL_POSITION AS KEY_SEQ, 280 | CASE 
281 | WHEN R.UPDATE_RULE = 'CASCADE' THEN 0 282 | WHEN R.UPDATE_RULE = 'SET NULL' THEN 2 283 | WHEN R.UPDATE_RULE = 'SET DEFAULT' THEN 4 284 | WHEN R.UPDATE_RULE = 'SET RESTRICT' THEN 1 285 | WHEN R.UPDATE_RULE = 'SET NO ACTION' THEN 3 286 | ELSE 3 287 | END AS UPDATE_RULE, 288 | CASE 289 | WHEN R.DELETE_RULE = 'CASCADE' THEN 0 290 | WHEN R.DELETE_RULE = 'SET NULL' THEN 2 291 | WHEN R.DELETE_RULE = 'SET DEFAULT' THEN 4 292 | WHEN R.DELETE_RULE = 'SET RESTRICT' THEN 1 293 | WHEN R.DELETE_RULE = 'SET NO ACTION' THEN 3 294 | ELSE 3 END AS DELETE_RULE, 295 | A.CONSTRAINT_NAME AS FK_NAME, 296 | 'PRIMARY' AS PK_NAME, 297 | 7 AS DEFERRABILITY 298 | FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE A 299 | JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE D 300 | ON ( 301 | D.TABLE_SCHEMA=A.REFERENCED_TABLE_SCHEMA 302 | AND D.TABLE_NAME=A.REFERENCED_TABLE_NAME 303 | AND D.COLUMN_NAME=A.REFERENCED_COLUMN_NAME) 304 | JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS R 305 | ON ( 306 | R.CONSTRAINT_NAME = A.CONSTRAINT_NAME 307 | AND R.TABLE_NAME = A.TABLE_NAME 308 | AND R.CONSTRAINT_SCHEMA = A.TABLE_SCHEMA) 309 | WHERE 310 | D.CONSTRAINT_NAME='PRIMARY' 311 | AND A.TABLE_SCHEMA = 'magic' 312 | ORDER BY FKTABLE_CAT, FKTABLE_NAME, KEY_SEQ, PKTABLE_NAME 313 | """ 314 | 315 | TABLEAU_INDEXES_2 = """ 316 | SELECT COLUMN_NAME, 317 | REFERENCED_TABLE_NAME, 318 | REFERENCED_COLUMN_NAME, 319 | ORDINAL_POSITION, 320 | CONSTRAINT_NAME 321 | FROM information_schema.KEY_COLUMN_USAGE 322 | WHERE REFERENCED_TABLE_NAME IS NOT NULL 323 | AND TABLE_NAME = 'x' 324 | AND TABLE_SCHEMA = 'db' 325 | """ 326 | 327 | TABLEAU_COLUMNS = """ 328 | SELECT `table_name`, `column_name` 329 | FROM `information_schema`.`columns` 330 | WHERE `data_type`='enum' AND `table_schema`='db' 331 | """ 332 | 333 | TABLEAU_TABLES = """ 334 | SELECT TABLE_NAME,TABLE_COMMENT,IF(TABLE_TYPE='BASE TABLE', 'TABLE', TABLE_TYPE),TABLE_SCHEMA 335 | FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'db' AND ( TABLE_TYPE='BASE TABLE' OR 
TABLE_TYPE='VIEW' ) 336 | ORDER BY TABLE_SCHEMA, TABLE_NAME 337 | """ 338 | 339 | TABLEAU_SCHEMATA = """ 340 | SELECT NULL, NULL, NULL, SCHEMA_NAME 341 | FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME LIKE '%' 342 | ORDER BY SCHEMA_NAME 343 | """ 344 | -------------------------------------------------------------------------------- /tests/test_auth.py: -------------------------------------------------------------------------------- 1 | from contextlib import closing 2 | from typing import Optional, Tuple, Dict 3 | 4 | import pytest 5 | from mysql.connector import DatabaseError 6 | from mysql.connector.abstracts import MySQLConnectionAbstract 7 | from mysql.connector.plugins.mysql_clear_password import MySQLClearPasswordAuthPlugin 8 | 9 | from mysql_mimic import User 10 | from mysql_mimic.auth import ( 11 | NativePasswordAuthPlugin, 12 | AbstractClearPasswordAuthPlugin, 13 | NoLoginAuthPlugin, 14 | ) 15 | from tests.conftest import query, to_thread, ConnectFixture 16 | 17 | # mysql.connector throws an error if you try to use mysql_clear_password without SSL. 18 | # That's silly, since SSL termination doesn't have to be handled by MySQL. 19 | # But it's extra silly in tests. 
20 | MySQLClearPasswordAuthPlugin.requires_ssl = False # type: ignore 21 | MySQLConnectionAbstract.is_secure = True # type: ignore 22 | 23 | SIMPLE_AUTH_USER = "levon_helm" 24 | 25 | PASSWORD_AUTH_USER = "rick_danko" 26 | PASSWORD_AUTH_PASSWORD = "nazareth" 27 | PASSWORD_AUTH_OLD_PASSWORD = "cannonball" 28 | PASSWORD_AUTH_PLUGIN = NativePasswordAuthPlugin.client_plugin_name 29 | 30 | 31 | class TestPlugin(AbstractClearPasswordAuthPlugin): 32 | name = "mysql_clear_password" 33 | 34 | async def check(self, username: str, password: str) -> Optional[str]: 35 | return username if username == password else None 36 | 37 | 38 | TEST_PLUGIN_AUTH_USER = "garth_hudson" 39 | TEST_PLUGIN_AUTH_PASSWORD = TEST_PLUGIN_AUTH_USER 40 | TEST_PLUGIN_AUTH_PLUGIN = TestPlugin.client_plugin_name 41 | 42 | NO_LOGIN_USER = "carmen_and_the_devil" 43 | NO_LOGIN_PLUGIN = NoLoginAuthPlugin.name 44 | 45 | UNKNOWN_PLUGIN_USER = "richard_manuel" 46 | NO_PLUGIN_USER = "miss_moses" 47 | 48 | 49 | @pytest.fixture(autouse=True) 50 | def users() -> Dict[str, User]: 51 | return { 52 | SIMPLE_AUTH_USER: User( 53 | name=SIMPLE_AUTH_USER, 54 | auth_string=None, 55 | auth_plugin=NativePasswordAuthPlugin.name, 56 | ), 57 | PASSWORD_AUTH_USER: User( 58 | name=PASSWORD_AUTH_USER, 59 | auth_string=NativePasswordAuthPlugin.create_auth_string( 60 | PASSWORD_AUTH_PASSWORD 61 | ), 62 | old_auth_string=NativePasswordAuthPlugin.create_auth_string( 63 | PASSWORD_AUTH_OLD_PASSWORD 64 | ), 65 | auth_plugin=NativePasswordAuthPlugin.name, 66 | ), 67 | TEST_PLUGIN_AUTH_USER: User( 68 | name=TEST_PLUGIN_AUTH_USER, 69 | auth_string=TEST_PLUGIN_AUTH_PASSWORD, 70 | auth_plugin=TestPlugin.name, 71 | ), 72 | UNKNOWN_PLUGIN_USER: User(name=UNKNOWN_PLUGIN_USER, auth_plugin="unknown"), 73 | NO_PLUGIN_USER: User(name=NO_PLUGIN_USER, auth_plugin=NO_LOGIN_PLUGIN), 74 | } 75 | 76 | 77 | @pytest.mark.asyncio 78 | @pytest.mark.parametrize( 79 | "auth_plugins,username,password,auth_plugin", 80 | [ 81 | ( 82 | [NativePasswordAuthPlugin()], 
83 | PASSWORD_AUTH_USER, 84 | PASSWORD_AUTH_PASSWORD, 85 | PASSWORD_AUTH_PLUGIN, 86 | ), 87 | ( 88 | [TestPlugin()], 89 | TEST_PLUGIN_AUTH_USER, 90 | TEST_PLUGIN_AUTH_PASSWORD, 91 | TEST_PLUGIN_AUTH_PLUGIN, 92 | ), 93 | ([NativePasswordAuthPlugin()], SIMPLE_AUTH_USER, None, None), 94 | ([TestPlugin(), NativePasswordAuthPlugin()], SIMPLE_AUTH_USER, None, None), 95 | (None, SIMPLE_AUTH_USER, None, None), 96 | ], 97 | ) 98 | async def test_auth( 99 | connect: ConnectFixture, 100 | username: str, 101 | password: Optional[str], 102 | auth_plugin: Optional[str], 103 | ) -> None: 104 | kwargs = {"user": username, "password": password, "auth_plugin": auth_plugin} 105 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 106 | with closing(await connect(**kwargs)) as conn: 107 | assert await query(conn=conn, sql="SELECT USER() AS a") == [{"a": username}] 108 | 109 | 110 | @pytest.mark.asyncio 111 | @pytest.mark.parametrize( 112 | "auth_plugins", 113 | [ 114 | [NativePasswordAuthPlugin()], 115 | ], 116 | ) 117 | async def test_auth_secondary_password( 118 | connect: ConnectFixture, 119 | ) -> None: 120 | with closing( 121 | await connect( 122 | user=PASSWORD_AUTH_USER, 123 | password=PASSWORD_AUTH_OLD_PASSWORD, 124 | auth_plugin=PASSWORD_AUTH_PLUGIN, 125 | ) 126 | ) as conn: 127 | assert await query(conn=conn, sql="SELECT USER() AS a") == [ 128 | {"a": PASSWORD_AUTH_USER} 129 | ] 130 | 131 | 132 | @pytest.mark.asyncio 133 | @pytest.mark.parametrize( 134 | "auth_plugins,user1,user2", 135 | [ 136 | ( 137 | [NativePasswordAuthPlugin(), TestPlugin()], 138 | (PASSWORD_AUTH_USER, PASSWORD_AUTH_PASSWORD, PASSWORD_AUTH_PLUGIN), 139 | (TEST_PLUGIN_AUTH_USER, TEST_PLUGIN_AUTH_PASSWORD, TEST_PLUGIN_AUTH_PLUGIN), 140 | ), 141 | ( 142 | [NativePasswordAuthPlugin()], 143 | (PASSWORD_AUTH_USER, PASSWORD_AUTH_PASSWORD, PASSWORD_AUTH_PLUGIN), 144 | (PASSWORD_AUTH_USER, PASSWORD_AUTH_PASSWORD, PASSWORD_AUTH_PLUGIN), 145 | ), 146 | ], 147 | ) 148 | async def test_change_user( 149 | 
connect: ConnectFixture, 150 | user1: Tuple[str, str, str], 151 | user2: Tuple[str, str, str], 152 | ) -> None: 153 | kwargs1 = {"user": user1[0], "password": user1[1], "auth_plugin": user1[2]} 154 | kwargs1 = {k: v for k, v in kwargs1.items() if v is not None} 155 | 156 | with closing(await connect(**kwargs1)) as conn: 157 | # mysql.connector doesn't have great support for COM_CHANGE_USER 158 | # Here, we have to manually override the auth plugin to use 159 | conn._auth_plugin = user2[2] # pylint: disable=protected-access 160 | 161 | await to_thread(conn.cmd_change_user, username=user2[0], password=user2[1]) 162 | assert await query(conn=conn, sql="SELECT USER() AS a") == [{"a": user2[0]}] 163 | 164 | 165 | @pytest.mark.asyncio 166 | @pytest.mark.parametrize( 167 | "auth_plugins,username,password,auth_plugin,msg", 168 | [ 169 | ([NativePasswordAuthPlugin()], None, None, None, "User does not exist"), 170 | ( 171 | [TestPlugin()], 172 | PASSWORD_AUTH_USER, 173 | PASSWORD_AUTH_PASSWORD, 174 | PASSWORD_AUTH_PLUGIN, 175 | "Access denied", 176 | ), 177 | # This test doesn't work with newer versions of mysql-connector-python. 
178 | # ([NoLoginAuthPlugin()], NO_PLUGIN_USER, None, None, "Access denied"), 179 | ( 180 | [NativePasswordAuthPlugin(), NoLoginAuthPlugin()], 181 | NO_PLUGIN_USER, 182 | None, 183 | None, 184 | "Access denied", 185 | ), 186 | ], 187 | ) 188 | async def test_access_denied( 189 | connect: ConnectFixture, 190 | username: Optional[str], 191 | password: Optional[str], 192 | auth_plugin: Optional[str], 193 | msg: str, 194 | ) -> None: 195 | kwargs = {"user": username, "password": password, "auth_plugin": auth_plugin} 196 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 197 | 198 | with pytest.raises(DatabaseError) as ctx: 199 | await connect(**kwargs) 200 | 201 | assert msg in str(ctx.value) 202 | -------------------------------------------------------------------------------- /tests/test_control.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from contextlib import closing 3 | from typing import Callable 4 | 5 | import pytest 6 | from mysql.connector import DatabaseError 7 | 8 | from tests.conftest import MockSession, query, ConnectFixture 9 | 10 | 11 | @pytest.fixture 12 | def session_factory(session: MockSession) -> Callable[[], MockSession]: 13 | called_once = False 14 | 15 | def factory() -> MockSession: 16 | nonlocal called_once 17 | if not called_once: 18 | called_once = True 19 | return session 20 | return MockSession() 21 | 22 | return factory 23 | 24 | 25 | @pytest.mark.asyncio 26 | async def test_kill_connection( 27 | session: MockSession, 28 | connect: ConnectFixture, 29 | ) -> None: 30 | session.pause = asyncio.Event() 31 | 32 | with closing(await connect()) as conn1: 33 | with closing(await connect()) as conn2: 34 | q1 = asyncio.create_task(query(conn1, "SELECT a FROM x")) 35 | 36 | # Wait until we know the session is paused 37 | await session.waiting.wait() 38 | 39 | await query(conn2, f"KILL {session.connection.connection_id}") 40 | 41 | with pytest.raises(DatabaseError) as ctx: 42 
| await q1 43 | 44 | assert "Session was killed" in str(ctx.value.msg) 45 | 46 | session.pause.clear() 47 | 48 | with pytest.raises(DatabaseError) as ctx: 49 | await query(conn1, "SELECT 1") 50 | 51 | assert "Connection not available" in str(ctx.value.msg) 52 | 53 | 54 | @pytest.mark.asyncio 55 | async def test_kill_query( 56 | session: MockSession, 57 | connect: ConnectFixture, 58 | ) -> None: 59 | session.pause = asyncio.Event() 60 | 61 | with closing(await connect()) as conn1: 62 | with closing(await connect()) as conn2: 63 | q1 = asyncio.create_task(query(conn1, "SELECT a FROM x")) 64 | 65 | # Wait until we know the session is paused 66 | await session.waiting.wait() 67 | 68 | await query(conn2, f"KILL QUERY {session.connection.connection_id}") 69 | 70 | with pytest.raises(DatabaseError) as ctx: 71 | await q1 72 | 73 | assert "Query was killed" in str(ctx.value.msg) 74 | 75 | session.pause.clear() 76 | 77 | assert (await query(conn1, "SELECT 1")) == [{"1": 1}] 78 | -------------------------------------------------------------------------------- /tests/test_result.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | 5 | from mysql_mimic import ColumnType 6 | from mysql_mimic.errors import MysqlError 7 | from mysql_mimic.results import ensure_result_set 8 | 9 | 10 | async def gen_rows() -> Any: 11 | yield 1, None, None 12 | yield None, "2", None 13 | yield None, None, None 14 | 15 | 16 | @pytest.mark.asyncio 17 | @pytest.mark.parametrize( 18 | "result, column_types", 19 | [ 20 | ( 21 | (gen_rows(), ["a", "b", "c"]), 22 | [ColumnType.LONGLONG, ColumnType.STRING, ColumnType.NULL], 23 | ), 24 | ], 25 | ) 26 | async def test_ensure_result_set_columns(result: Any, column_types: Any) -> None: 27 | result_set = await ensure_result_set(result) 28 | assert [c.type for c in result_set.columns] == column_types 29 | 30 | 31 | @pytest.mark.asyncio 32 | @pytest.mark.parametrize( 33 | "result", 
34 | [ 35 | [1, 2], 36 | ([[1, 2]], ["a", "b"], ["a", "b"]), 37 | ], 38 | ) 39 | async def test_ensure_result_set__invalid(result: Any) -> None: 40 | with pytest.raises(MysqlError): 41 | await ensure_result_set(result) 42 | -------------------------------------------------------------------------------- /tests/test_schema.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlglot import expressions as exp 3 | 4 | from mysql_mimic.results import ensure_result_set 5 | from mysql_mimic.schema import ( 6 | Column, 7 | InfoSchema, 8 | mapping_to_columns, 9 | show_statement_to_info_schema_query, 10 | ) 11 | from mysql_mimic.utils import aiterate 12 | 13 | 14 | @pytest.mark.asyncio 15 | def test_mapping_to_columns() -> None: 16 | schema = { 17 | "table_1": { 18 | "col_1": "TEXT", 19 | "col_2": "INT", 20 | } 21 | } 22 | 23 | columns = mapping_to_columns(schema=schema) 24 | 25 | assert columns[0] == Column( 26 | name="col_1", type="TEXT", table="table_1", schema="", catalog="def" 27 | ) 28 | assert columns[1] == Column( 29 | name="col_2", type="INT", table="table_1", schema="", catalog="def" 30 | ) 31 | 32 | 33 | @pytest.mark.asyncio 34 | async def test_info_schema_from_columns() -> None: 35 | input_columns = [ 36 | Column( 37 | name="col_1", 38 | type="TEXT", 39 | table="table_1", 40 | schema="my_db", 41 | catalog="def", 42 | comment="This is a comment", 43 | ), 44 | Column( 45 | name="col_1", type="TEXT", table="table_2", schema="my_db", catalog="def" 46 | ), 47 | ] 48 | schema = InfoSchema.from_columns(columns=input_columns) 49 | table_query = show_statement_to_info_schema_query(exp.Show(this="TABLES"), "my_db") 50 | result_set = await ensure_result_set(await schema.query(table_query)) 51 | table_names = [row[0] async for row in aiterate(result_set.rows)] 52 | assert table_names == ["table_1", "table_2"] 53 | 54 | column_query = show_statement_to_info_schema_query( 55 | exp.Show(this="COLUMNS", full=True, 
target="table_1"), "my_db" 56 | ) 57 | column_results = await ensure_result_set(await schema.query(column_query)) 58 | columns = [row async for row in aiterate(column_results.rows)] 59 | assert columns == [ 60 | ( 61 | "col_1", 62 | "TEXT", 63 | "YES", 64 | None, 65 | None, 66 | None, 67 | "NULL", 68 | None, 69 | "This is a comment", 70 | ) 71 | ] 72 | -------------------------------------------------------------------------------- /tests/test_ssl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import closing 3 | from ssl import SSLContext, PROTOCOL_TLS_SERVER 4 | 5 | import pytest 6 | 7 | from tests.conftest import MockSession, ConnectFixture, query 8 | 9 | CURRENT_DIR = os.path.dirname(__file__) 10 | SSL_CERT = os.path.join(CURRENT_DIR, "fixtures/certificate.pem") 11 | SSL_KEY = os.path.join(CURRENT_DIR, "fixtures/key.pem") 12 | 13 | 14 | @pytest.fixture(autouse=True) 15 | def ssl() -> SSLContext: 16 | sslcontext = SSLContext(PROTOCOL_TLS_SERVER) 17 | sslcontext.load_cert_chain(certfile=SSL_CERT, keyfile=SSL_KEY) 18 | return sslcontext 19 | 20 | 21 | @pytest.mark.asyncio 22 | @pytest.mark.skipif( 23 | os.getenv("CI") == "true", 24 | reason=""" 25 | There seems to be a sequencing issue with the asyncio server starting tls. 26 | https://stackoverflow.com/questions/74187508/sequencing-issue-when-upgrading-python-asyncio-connection-to-tls 27 | Didn't have time to debug this, so skipping in CI for now. 
28 | """, 29 | ) 30 | async def test_ssl( 31 | session: MockSession, 32 | connect: ConnectFixture, 33 | port: int, 34 | ) -> None: 35 | with closing(await connect()) as conn: 36 | results = await query(conn, "SELECT 1 AS a") 37 | assert results == [{"a": 1}] 38 | -------------------------------------------------------------------------------- /tests/test_stream.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from mysql_mimic.errors import MysqlError 4 | from mysql_mimic.stream import MysqlStream 5 | 6 | 7 | class MockReader: 8 | def __init__(self, data: bytes): 9 | self.data = data 10 | self.pos = 0 11 | 12 | async def read(self, n: int) -> bytes: 13 | start = self.pos 14 | end = self.pos + n 15 | self.pos = end 16 | 17 | return self.data[start:end] 18 | 19 | readexactly = read 20 | 21 | 22 | class MockWriter: 23 | def __init__(self) -> None: 24 | self.data = b"" 25 | 26 | def write(self, data: bytes) -> None: 27 | self.data += data 28 | 29 | async def drain(self) -> None: 30 | return 31 | 32 | 33 | def test_seq() -> None: 34 | s1 = MysqlStream(reader=None, writer=None) # type: ignore 35 | assert next(s1.seq) == 0 36 | assert next(s1.seq) == 1 37 | assert next(s1.seq) == 2 38 | s2 = MysqlStream(reader=None, writer=None) # type: ignore 39 | assert next(s2.seq) == 0 40 | assert next(s2.seq) == 1 41 | assert next(s2.seq) == 2 42 | 43 | 44 | def test_reset_seq() -> None: 45 | s = MysqlStream(reader=None, writer=None) # type: ignore 46 | assert next(s.seq) == 0 47 | assert next(s.seq) == 1 48 | s.reset_seq() 49 | assert next(s.seq) == 0 50 | assert next(s.seq) == 1 51 | 52 | 53 | @pytest.mark.asyncio 54 | async def test_bad_seq_read() -> None: 55 | reader = MockReader(b"\x00\x00\x00\x01") 56 | s = MysqlStream(reader=reader, writer=None) # type: ignore 57 | with pytest.raises(MysqlError): 58 | await s.read() 59 | 60 | 61 | @pytest.mark.asyncio 62 | async def test_empty_read() -> None: 63 | reader = 
MockReader(b"\x00\x00\x00\x00") 64 | s = MysqlStream(reader=reader, writer=None) # type: ignore 65 | assert await s.read() == b"" 66 | assert next(s.seq) == 1 67 | 68 | 69 | @pytest.mark.asyncio 70 | async def test_small_read() -> None: 71 | reader = MockReader(b"\x01\x00\x00\x00k") 72 | s = MysqlStream(reader=reader, writer=None) # type: ignore 73 | assert await s.read() == b"k" 74 | assert next(s.seq) == 1 75 | 76 | 77 | @pytest.mark.asyncio 78 | async def test_medium_read() -> None: 79 | reader = MockReader(b"\xff\xff\x00\x00" + bytes(0xFFFF)) 80 | s = MysqlStream(reader=reader, writer=None) # type: ignore 81 | assert await s.read() == bytes(0xFFFF) 82 | assert next(s.seq) == 1 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_edge_read() -> None: 87 | reader = MockReader(b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x00\x00\x00\x01") 88 | s = MysqlStream(reader=reader, writer=None) # type: ignore 89 | assert await s.read() == bytes(0xFFFFFF) 90 | assert next(s.seq) == 2 91 | 92 | 93 | @pytest.mark.asyncio 94 | async def test_large_read() -> None: 95 | reader = MockReader( 96 | b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x06\x00\x00\x01" + b"kelsin" 97 | ) 98 | s = MysqlStream(reader=reader, writer=None) # type: ignore 99 | assert await s.read() == bytes(0xFFFFFF) + b"kelsin" 100 | assert next(s.seq) == 2 101 | 102 | 103 | @pytest.mark.asyncio 104 | async def test_empty_write() -> None: 105 | writer = MockWriter() 106 | s = MysqlStream(reader=None, writer=writer) # type: ignore 107 | await s.write(b"") 108 | assert writer.data == b"\x00\x00\x00\x00" 109 | assert next(s.seq) == 1 110 | 111 | 112 | @pytest.mark.asyncio 113 | async def test_small_write() -> None: 114 | writer = MockWriter() 115 | s = MysqlStream(reader=None, writer=writer) # type: ignore 116 | await s.write(b"kelsin") 117 | assert writer.data == b"\x06\x00\x00\x00kelsin" 118 | assert next(s.seq) == 1 119 | 120 | 121 | @pytest.mark.asyncio 122 | async def test_medium_write() -> None: 123 | writer 
= MockWriter() 124 | s = MysqlStream(reader=None, writer=writer) # type: ignore 125 | await s.write(bytes(0xFFFF)) 126 | assert writer.data == b"\xff\xff\x00\x00" + bytes(0xFFFF) 127 | assert next(s.seq) == 1 128 | 129 | 130 | @pytest.mark.asyncio 131 | async def test_edge_write() -> None: 132 | writer = MockWriter() 133 | s = MysqlStream(reader=None, writer=writer) # type: ignore 134 | await s.write(bytes(0xFFFFFF)) 135 | assert writer.data == b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x00\x00\x00\x01" 136 | assert next(s.seq) == 2 137 | 138 | 139 | @pytest.mark.asyncio 140 | async def test_large_write() -> None: 141 | writer = MockWriter() 142 | s = MysqlStream(reader=None, writer=writer) # type: ignore 143 | await s.write(bytes(0xFFFFFF) + b"kelsin") 144 | assert ( 145 | writer.data == b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x06\x00\x00\x01kelsin" 146 | ) 147 | assert next(s.seq) == 2 148 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | import io 2 | import struct 3 | 4 | import pytest 5 | 6 | from mysql_mimic import types 7 | 8 | 9 | def test_column_definition() -> None: 10 | assert types.ColumnDefinition.NOT_NULL_FLAG == 1 11 | assert types.ColumnDefinition.PRI_KEY_FLAG == 2 12 | assert types.ColumnDefinition.FIELD_IS_INVISIBLE == 1 << 30 13 | assert ( 14 | types.ColumnDefinition.NOT_NULL_FLAG | types.ColumnDefinition.PRI_KEY_FLAG == 3 15 | ) 16 | 17 | 18 | def test_capabilities() -> None: 19 | assert types.Capabilities.CLIENT_LONG_PASSWORD == 1 20 | assert types.Capabilities.CLIENT_FOUND_ROWS == 2 21 | assert types.Capabilities.CLIENT_PROTOCOL_41 == 512 22 | assert types.Capabilities.CLIENT_DEPRECATE_EOF == 1 << 24 23 | assert ( 24 | types.Capabilities.CLIENT_LONG_PASSWORD | types.Capabilities.CLIENT_FOUND_ROWS 25 | == 3 26 | ) 27 | 28 | 29 | def test_server_status() -> None: 30 | assert 
types.ServerStatus.SERVER_STATUS_IN_TRANS == 1 31 | assert types.ServerStatus.SERVER_STATUS_AUTOCOMMIT == 2 32 | assert types.ServerStatus.SERVER_STATUS_LAST_ROW_SENT == 128 33 | assert ( 34 | types.ServerStatus.SERVER_STATUS_IN_TRANS 35 | | types.ServerStatus.SERVER_STATUS_AUTOCOMMIT 36 | == 3 37 | ) 38 | 39 | 40 | def test_int_1() -> None: 41 | assert types.uint_1(0) == b"\x00" 42 | assert types.uint_1(1) == b"\x01" 43 | assert types.uint_1(255) == b"\xff" 44 | pytest.raises(struct.error, types.uint_1, 256) 45 | 46 | 47 | def test_int_2() -> None: 48 | assert types.uint_2(0) == b"\x00\x00" 49 | assert types.uint_2(1) == b"\x01\x00" 50 | assert types.uint_2(255) == b"\xff\x00" 51 | assert types.uint_2(2**8) == b"\x00\x01" 52 | assert types.uint_2(2**16 - 1) == b"\xff\xff" 53 | pytest.raises(struct.error, types.uint_2, 2**16) 54 | 55 | 56 | def test_int_3() -> None: 57 | assert types.uint_3(0) == b"\x00\x00\x00" 58 | assert types.uint_3(1) == b"\x01\x00\x00" 59 | assert types.uint_3(255) == b"\xff\x00\x00" 60 | assert types.uint_3(2**8) == b"\x00\x01\x00" 61 | assert types.uint_3(2**16) == b"\x00\x00\x01" 62 | pytest.raises(struct.error, types.uint_3, 2**24) 63 | 64 | 65 | def test_int_4() -> None: 66 | assert types.uint_4(0) == b"\x00\x00\x00\x00" 67 | assert types.uint_4(1) == b"\x01\x00\x00\x00" 68 | assert types.uint_4(255) == b"\xff\x00\x00\x00" 69 | assert types.uint_4(2**8) == b"\x00\x01\x00\x00" 70 | assert types.uint_4(2**16) == b"\x00\x00\x01\x00" 71 | assert types.uint_4(2**24) == b"\x00\x00\x00\x01" 72 | pytest.raises(struct.error, types.uint_4, 2**32) 73 | 74 | 75 | def test_int_6() -> None: 76 | assert types.uint_6(0) == b"\x00\x00\x00\x00\x00\x00" 77 | assert types.uint_6(1) == b"\x01\x00\x00\x00\x00\x00" 78 | assert types.uint_6(255) == b"\xff\x00\x00\x00\x00\x00" 79 | assert types.uint_6(2**8) == b"\x00\x01\x00\x00\x00\x00" 80 | assert types.uint_6(2**16) == b"\x00\x00\x01\x00\x00\x00" 81 | assert types.uint_6(2**24) == b"\x00\x00\x00\x01\x00\x00" 
# NOTE(review): the lines below are a line-numbered dump of tests/test_types.py collapsed
# onto a few physical lines ("NN |" marks each original source line). This span opens
# partway through test_int_6 (its header is outside this view) and ends by starting
# tests/test_variables.py.
#
# Encoder expectations pinned on this line:
# - uint_6 / uint_8: fixed-width unsigned ints, least significant byte first
#   (e.g. uint_8(1) == b"\x01" + 7 NUL bytes); 2**48 and 2**64 respectively are
#   expected to raise struct.error.
# - uint_len: length-encoded int. The asserts pin 0, 1, 250 -> a single byte;
#   251 and 2**8 -> 0xfc + 2 bytes; 2**16 -> 0xfd + 3 bytes; 2**24, 2**32, 2**56 ->
#   0xfe + 8 bytes; 2**64 -> struct.error.
# - str_fixed(n, s): truncates s to n bytes, or right-pads it with NUL bytes up to n.
# - str_null(s): appends a single NUL terminator (b"" -> b"\x00").
# - str_len(s) begins at the end of this line and continues on the next one.
82 | assert types.uint_6(2**32) == b"\x00\x00\x00\x00\x01\x00" 83 | assert types.uint_6(2**40) == b"\x00\x00\x00\x00\x00\x01" 84 | pytest.raises(struct.error, types.uint_6, 2**48) 85 | 86 | 87 | def test_int_8() -> None: 88 | assert types.uint_8(0) == b"\x00\x00\x00\x00\x00\x00\x00\x00" 89 | assert types.uint_8(1) == b"\x01\x00\x00\x00\x00\x00\x00\x00" 90 | assert types.uint_8(255) == b"\xff\x00\x00\x00\x00\x00\x00\x00" 91 | assert types.uint_8(2**8) == b"\x00\x01\x00\x00\x00\x00\x00\x00" 92 | assert types.uint_8(2**16) == b"\x00\x00\x01\x00\x00\x00\x00\x00" 93 | assert types.uint_8(2**24) == b"\x00\x00\x00\x01\x00\x00\x00\x00" 94 | assert types.uint_8(2**32) == b"\x00\x00\x00\x00\x01\x00\x00\x00" 95 | assert types.uint_8(2**56) == b"\x00\x00\x00\x00\x00\x00\x00\x01" 96 | pytest.raises(struct.error, types.uint_8, 2**64) 97 | 98 | 99 | def test_int_len() -> None: 100 | assert types.uint_len(0) == b"\x00" 101 | assert types.uint_len(1) == b"\x01" 102 | assert types.uint_len(250) == b"\xfa" 103 | assert types.uint_len(251) == b"\xfc\xfb\x00" 104 | assert types.uint_len(2**8) == b"\xfc\x00\x01" 105 | assert types.uint_len(2**16) == b"\xfd\x00\x00\x01" 106 | assert types.uint_len(2**24) == b"\xfe\x00\x00\x00\x01\x00\x00\x00\x00" 107 | assert types.uint_len(2**32) == b"\xfe\x00\x00\x00\x00\x01\x00\x00\x00" 108 | assert types.uint_len(2**56) == b"\xfe\x00\x00\x00\x00\x00\x00\x00\x01" 109 | pytest.raises(struct.error, types.uint_len, 2**64) 110 | 111 | 112 | def test_str_fixed() -> None: 113 | assert types.str_fixed(0, b"kelsin") == b"" 114 | assert types.str_fixed(1, b"kelsin") == b"k" 115 | assert types.str_fixed(6, b"kelsin") == b"kelsin" 116 | assert types.str_fixed(8, b"kelsin") == b"kelsin\x00\x00" 117 | 118 | 119 | def test_str_null() -> None: 120 | assert types.str_null(b"") == b"\x00" 121 | assert types.str_null(b"kelsin") == b"kelsin\x00" 122 | 123 | 124 | def test_str_len() -> None: 125 | assert types.str_len(b"") == b"\x00" 126 | assert types.str_len(b"kelsin")
# Continues test_str_len: the output is a length prefix followed by the bytes
# (b"kelsin" -> b"\x06kelsin"; 256 bytes -> b"\xfc\x00\x01" prefix, the same
# bytes test_int_len expects for uint_len(2**8)).
# - str_rest(s): returns s unchanged (no prefix, no terminator).
# - read_uint_1/2/3/4/6/8: decode fixed-width unsigned ints from a binary stream
#   (io.BytesIO in these tests), least significant byte first
#   (e.g. b"\x00\x01" -> 256 for read_uint_2).
# test_read_int_len begins at the end of this line.
== b"\x06kelsin" 127 | 128 | big_str = bytes(256) 129 | assert types.str_len(big_str) == b"\xfc\x00\x01" + big_str 130 | 131 | 132 | def test_str_rest() -> None: 133 | assert types.str_rest(b"") == b"" 134 | assert types.str_rest(b"kelsin") == b"kelsin" 135 | 136 | big_str = bytes(256) 137 | assert types.str_rest(big_str) == big_str 138 | 139 | 140 | def test_read_int_1() -> None: 141 | reader = io.BytesIO(b"\x01") 142 | assert types.read_uint_1(reader) == 1 143 | reader = io.BytesIO(b"\x00\x00") 144 | assert types.read_uint_1(reader) == 0 145 | assert types.read_uint_1(reader) == 0 146 | 147 | 148 | def test_read_int_2() -> None: 149 | reader = io.BytesIO(b"\x00\x01") 150 | assert types.read_uint_2(reader) == 256 151 | 152 | 153 | def test_read_int_3() -> None: 154 | reader = io.BytesIO(b"\x00\x00\x01") 155 | assert types.read_uint_3(reader) == 2**16 156 | 157 | 158 | def test_read_int_4() -> None: 159 | reader = io.BytesIO(b"\x01\x00\x00\x00") 160 | assert types.read_uint_4(reader) == 1 161 | reader = io.BytesIO(b"\x00\x00\x01\x00") 162 | assert types.read_uint_4(reader) == 2**16 163 | reader = io.BytesIO(b"\x00\x00\x00\x01") 164 | assert types.read_uint_4(reader) == 2**24 165 | 166 | 167 | def test_read_int_6() -> None: 168 | reader = io.BytesIO(b"\x01\x00\x00\x00\x00\x00") 169 | assert types.read_uint_6(reader) == 1 170 | reader = io.BytesIO(b"\x00\x00\x00\x00\x01\x00") 171 | assert types.read_uint_6(reader) == 2**32 172 | reader = io.BytesIO(b"\x00\x00\x00\x00\x00\x01") 173 | assert types.read_uint_6(reader) == 2**40 174 | 175 | 176 | def test_read_int_8() -> None: 177 | reader = io.BytesIO(b"\x01\x00\x00\x00\x00\x00\x00\x00") 178 | assert types.read_uint_8(reader) == 1 179 | reader = io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x01\x00") 180 | assert types.read_uint_8(reader) == 2**48 181 | reader = io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x01") 182 | assert types.read_uint_8(reader) == 2**56 183 | 184 | 185 | def test_read_int_len() -> None: 186 | reader =
# Continues test_read_int_len: read_uint_len decodes the single-byte form and the
# 0xfc / 0xfd / 0xfe prefixed forms (the same encodings test_int_len pins for uint_len).
# String readers:
# - read_str_fixed(reader, n): returns exactly n bytes; successive calls on the same
#   stream return successive slices (b"k", b"el", b"sin" from b"kelsin").
# - read_str_null: returns the bytes before the first NUL (b"kelsin\x00foo" -> b"kelsin").
# - read_str_len: reads a length-encoded prefix, then that many bytes.
# - read_str_rest: returns everything remaining, including embedded NUL bytes.
# This line ends with the start of tests/test_variables.py; its import continues on the next line.
io.BytesIO(b"\x01") 187 | assert types.read_uint_len(reader) == 1 188 | reader = io.BytesIO(b"\xfa") 189 | assert types.read_uint_len(reader) == 250 190 | reader = io.BytesIO(b"\xfc\x00\x01") 191 | assert types.read_uint_len(reader) == 256 192 | reader = io.BytesIO(b"\xfd\x00\x00\x01") 193 | assert types.read_uint_len(reader) == 2**16 194 | reader = io.BytesIO(b"\xfe\x00\x00\x00\x00\x00\x00\x00\x01") 195 | assert types.read_uint_len(reader) == 2**56 196 | 197 | 198 | def test_read_str_fixed() -> None: 199 | reader = io.BytesIO(b"kelsin") 200 | assert types.read_str_fixed(reader, 1) == b"k" 201 | assert types.read_str_fixed(reader, 2) == b"el" 202 | assert types.read_str_fixed(reader, 3) == b"sin" 203 | 204 | 205 | def test_read_str_null() -> None: 206 | reader = io.BytesIO(b"\x00") 207 | assert types.read_str_null(reader) == b"" 208 | reader = io.BytesIO(b"kelsin\x00") 209 | assert types.read_str_null(reader) == b"kelsin" 210 | reader = io.BytesIO(b"kelsin\x00foo") 211 | assert types.read_str_null(reader) == b"kelsin" 212 | 213 | 214 | def test_read_str_len() -> None: 215 | reader = io.BytesIO(b"\x01kelsin") 216 | assert types.read_str_len(reader) == b"k" 217 | 218 | big_str = bytes(256) 219 | reader = io.BytesIO(b"\xfc\x00\x01" + big_str) 220 | assert types.read_str_len(reader) == big_str 221 | 222 | 223 | def test_read_str_rest() -> None: 224 | reader = io.BytesIO(b"kelsin") 225 | assert types.read_str_rest(reader) == b"kelsin" 226 | reader = io.BytesIO(b"kelsin\x00foo") 227 | assert types.read_str_rest(reader) == b"kelsin\x00foo" 228 | -------------------------------------------------------------------------------- /tests/test_variables.py: -------------------------------------------------------------------------------- 1 | from datetime import timezone, timedelta 2 | from typing import Dict 3 | 4 | import pytest 5 | 6 | from mysql_mimic.errors import MysqlError 7 | from mysql_mimic.variables import ( 8 | parse_timezone, 9 | Variables, 10 | VariableSchema,
# NOTE(review): this line holds the rest of tests/test_variables.py and all of
# tests/test_version.py, in the same line-numbered dump form ("NN |" marks each
# original source line). The leading ")" closes the mysql_mimic.variables import
# that begins on the previous line.
# - TestVars: minimal Variables subclass whose schema declares one variable, "foo",
#   as the tuple (str, "bar", True); a fresh instance reports "bar" for it. The
#   meaning of the third tuple field is not visible here (presumably a
#   settable/dynamic flag; TODO confirm against mysql_mimic.variables).
# - NOTE(review): pytest collects classes named Test* by default. If Variables defines
#   __init__, pytest will emit a "cannot collect test class" warning for this helper.
#   Consider adding `__test__ = False` to TestVars (confirm with a test run).
# - test_parse_timezone: "UTC" and "+HH:MM" / "-HH:MM" strings compare equal to
#   fixed-offset datetime.timezone objects, and an unparseable name raises MysqlError.
#   The repeated "-01:00" call is meant to exercise a cache inside parse_timezone
#   (the cache itself is not visible here).
# - test_variable_mapping: a Variables instance is truthy and has len 1, and item
#   get/set agree with get_variable. An unknown key raises KeyError on read and
#   MysqlError on assignment.
# - test_version / test_main: __version__ is a str, and version.main("__main__")
#   writes exactly the version followed by a newline to stdout.
) 12 | 13 | 14 | class TestVars(Variables): 15 | 16 | @property 17 | def schema(self) -> Dict[str, VariableSchema]: 18 | return {"foo": (str, "bar", True)} 19 | 20 | 21 | def test_parse_timezone() -> None: 22 | assert timezone(timedelta()) == parse_timezone("UTC") 23 | assert timezone(timedelta()) == parse_timezone("+00:00") 24 | assert timezone(timedelta(hours=1)) == parse_timezone("+01:00") 25 | assert timezone(timedelta(hours=-1)) == parse_timezone("-01:00") 26 | 27 | # Implicitly test cache 28 | assert timezone(timedelta(hours=-1)) == parse_timezone("-01:00") 29 | 30 | with pytest.raises(MysqlError): 31 | parse_timezone("whoops") 32 | 33 | 34 | def test_variable_mapping() -> None: 35 | test_vars = TestVars() 36 | assert test_vars 37 | assert len(test_vars) == 1 38 | 39 | assert test_vars.get_variable("foo") == "bar" 40 | assert test_vars["foo"] == "bar" 41 | 42 | test_vars["foo"] = "hello" 43 | assert test_vars.get_variable("foo") == "hello" 44 | assert test_vars["foo"] == "hello" 45 | 46 | with pytest.raises(KeyError): 47 | assert test_vars["world"] 48 | 49 | with pytest.raises(MysqlError): 50 | test_vars["world"] = "hello" 51 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | import io 2 | import contextlib 3 | 4 | from mysql_mimic import version 5 | 6 | 7 | def test_version() -> None: 8 | assert isinstance(version.__version__, str) 9 | 10 | 11 | def test_main() -> None: 12 | out = io.StringIO() 13 | with contextlib.redirect_stdout(out): 14 | version.main("__main__") 15 | 16 | assert f"{version.__version__}\n" == out.getvalue() 17 | --------------------------------------------------------------------------------