├── .editorconfig
├── .github
└── workflows
│ ├── publish.yml
│ └── tests.yml
├── .gitignore
├── .pylintrc
├── LICENSE
├── Makefile
├── README.md
├── examples
├── auth_clear_password.py
├── auth_native_password.py
├── dbapi_proxy.py
├── simple.py
└── streaming_results.py
├── integration
├── mysql-connector-j
│ ├── Makefile
│ ├── pom.xml
│ └── src
│ │ ├── main
│ │ └── java
│ │ │ └── .gitkeep
│ │ └── test
│ │ └── java
│ │ └── com
│ │ └── mysql_mimic
│ │ └── integration
│ │ └── IntegrationTest.java
└── run.py
├── mysql_mimic
├── __init__.py
├── auth.py
├── charset.py
├── connection.py
├── constants.py
├── context.py
├── control.py
├── errors.py
├── functions.py
├── intercept.py
├── packets.py
├── prepared.py
├── results.py
├── schema.py
├── server.py
├── session.py
├── stream.py
├── types.py
├── utils.py
├── variable_processor.py
├── variables.py
└── version.py
├── pyproject.toml
├── requirements.txt
├── setup.py
└── tests
├── __init__.py
├── conftest.py
├── fixtures
├── __init__.py
├── certificate.pem
├── key.pem
└── queries.py
├── test_auth.py
├── test_control.py
├── test_query.py
├── test_result.py
├── test_schema.py
├── test_ssl.py
├── test_stream.py
├── test_types.py
├── test_variables.py
└── test_version.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 |
11 | [Makefile]
12 | indent_style = tab
13 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | push:
5 | tags:
6 | - "v*"
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v2
13 | - name: Set up Python 3.9
14 | uses: actions/setup-python@v5
15 | with:
16 | python-version: '3.9'
17 | cache: pip
18 | cache-dependency-path: setup.py
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | make deps
23 | - name: Build
24 | run: |
25 | make build
26 | - name: Publish
27 | env:
28 | TWINE_USERNAME: __token__
29 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
30 | run: |
31 | make publish
32 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | build:
11 | strategy:
12 | matrix:
13 | version: [
14 | { python: "3.7", ubuntu: "ubuntu-22.04" },
15 | { python: "3.8", ubuntu: "ubuntu-22.04" },
16 | { python: "3.9", ubuntu: "ubuntu-latest" },
17 | { python: "3.10", ubuntu: "ubuntu-latest" },
18 | { python: "3.11", ubuntu: "ubuntu-latest" },
19 | { python: "3.12", ubuntu: "ubuntu-latest" } ]
20 | runs-on: ${{ matrix.version.ubuntu }}
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Set up Kerberos
24 | run: sudo apt-get install -y libkrb5-dev krb5-kdc krb5-admin-server
25 | - name: Set up Python ${{ matrix.version.python }}
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: ${{ matrix.version.python }}
29 | cache: pip
30 | cache-dependency-path: setup.py
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | make deps
35 | - name: Test
36 | run: |
37 | make test
38 | - name: Lint
39 | if: matrix.version.python != '3.7'
40 | run: |
41 | make lint
42 | - name: Format
43 | if: matrix.version.python != '3.7'
44 | run: |
45 | make format-check
46 | - name: Type annotations
47 | if: matrix.version.python != '3.7'
48 | run: |
49 | make types
50 | mysql-connector-j:
51 | name: Integration (mysql-connector-j)
52 | runs-on: ubuntu-latest
53 | steps:
54 | - uses: actions/checkout@v2
55 | - name: Set up Kerberos
56 | run: sudo apt-get install -y libkrb5-dev krb5-kdc krb5-admin-server
57 | - name: Set up Python
58 | uses: actions/setup-python@v5
59 | with:
60 | python-version: '3.10'
61 | cache: pip
62 | cache-dependency-path: setup.py
63 | - name: Install Python dependencies
64 | run: |
65 | python -m pip install --upgrade pip
66 | make deps
67 | - name: Set up Java
68 | uses: actions/setup-java@v3
69 | with:
70 | distribution: 'temurin'
71 | java-version: '17'
72 | cache: 'maven'
73 | - name: Test mysql-connector-j
74 | run: |
75 | python integration/run.py integration/mysql-connector-j/
76 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | venv*/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # PyCharm
133 | .idea/
134 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 |
3 | disable=
4 | import-outside-toplevel,
5 | invalid-name,
6 | line-too-long,
7 | missing-class-docstring,
8 | missing-function-docstring,
9 | missing-module-docstring,
10 | too-few-public-methods,
11 | too-many-instance-attributes,
12 | too-many-public-methods,
13 | too-many-branches,
14 | too-many-return-statements,
15 | unused-argument,
16 | redefined-outer-name,
17 | too-many-statements,
18 | multiple-statements,
19 |
20 | [FORMAT]
21 |
22 | max-args=10
23 | max-line-length=130
24 | max-positional-arguments=15
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Christopher Giroir
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .DEFAULT_GOAL := deps
2 |
3 | .PHONY: deps format format-check run lint types test check build publish clean
4 |
5 | deps:
6 | pip install --progress-bar off -e .[dev]
7 |
8 | format:
9 | python -m black .
10 |
11 | format-check:
12 | python -m black --check .
13 |
14 | run:
15 | python -m mysql_mimic.server
16 |
17 | lint:
18 | python -m pylint mysql_mimic/ tests/
19 |
20 | types:
21 | python -m mypy -p mysql_mimic -p tests
22 |
23 | test:
24 | coverage run --source=mysql_mimic -m pytest
25 | coverage report
26 | coverage html
27 |
28 | check: format-check lint types test
29 |
30 | build: clean
31 | python setup.py sdist bdist_wheel
32 |
33 | publish: build
34 | twine upload dist/*
35 |
36 | clean:
37 | rm -rf build dist
38 | find . -type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MySQL-Mimic
2 |
3 | [](https://github.com/kelsin/mysql-mimic/actions/workflows/tests.yml)
4 |
5 | Pure-python implementation of the MySQL server [wire protocol](https://dev.mysql.com/doc/internals/en/client-server-protocol.html).
6 |
7 | This can be used to create applications that act as a MySQL server.
8 |
9 | ## Installation
10 |
11 | ```shell
12 | pip install mysql-mimic
13 | ```
14 |
15 | ## Usage
16 |
17 | A minimal use case might look like this:
18 |
19 | ```python
20 | import asyncio
21 |
22 | from mysql_mimic import MysqlServer, Session
23 |
24 |
25 | class MySession(Session):
26 | async def query(self, expression, sql, attrs):
27 | print(f"Parsed abstract syntax tree: {expression}")
28 | print(f"Original SQL string: {sql}")
29 | print(f"Query attributes: {attrs}")
30 | print(f"Currently authenticated user: {self.username}")
31 | print(f"Currently selected database: {self.database}")
32 | return [("a", 1), ("b", 2)], ["col1", "col2"]
33 |
34 | async def schema(self):
35 | # Optionally provide the database schema.
36 | # This is used to serve INFORMATION_SCHEMA and SHOW queries.
37 | return {
38 | "table": {
39 | "col1": "TEXT",
40 | "col2": "INT",
41 | }
42 | }
43 |
44 | if __name__ == "__main__":
45 | server = MysqlServer(session_factory=MySession)
46 | asyncio.run(server.serve_forever())
47 | ```
48 |
49 | Using [sqlglot](https://github.com/tobymao/sqlglot), the abstract `Session` class handles queries to metadata, variables, etc. that many MySQL clients expect.
50 |
51 | To bypass this default behavior, you can implement the [`mysql_mimic.session.BaseSession`](mysql_mimic/session.py) interface.
52 |
53 | See [examples](./examples) for more examples.
54 |
55 | ## Authentication
56 |
57 | MySQL-mimic has built in support for several standard MySQL authentication plugins:
58 | - [mysql_native_password](https://dev.mysql.com/doc/refman/8.0/en/native-pluggable-authentication.html)
59 | - The client sends hashed passwords to the server, and the server stores hashed passwords. See the documentation for more details on how this works.
60 | - [example](examples/auth_native_password.py)
61 | - [mysql_clear_password](https://dev.mysql.com/doc/refman/8.0/en/cleartext-pluggable-authentication.html)
62 | - The client sends passwords to the server as clear text, without hashing or encryption.
63 | - This is typically used as the client plugin for a custom server plugin. As such, MySQL-mimic provides an abstract class, [`mysql_mimic.auth.AbstractClearPasswordAuthPlugin`](mysql_mimic/auth.py), which can be extended.
64 | - [example](examples/auth_clear_password.py)
65 | - [mysql_no_login](https://dev.mysql.com/doc/refman/8.0/en/no-login-pluggable-authentication.html)
66 | - The server prevents clients from directly authenticating as an account. See the documentation for relevant use cases.
67 | - [authentication_kerberos](https://dev.mysql.com/doc/mysql-security-excerpt/8.0/en/kerberos-pluggable-authentication.html)
68 | - Kerberos uses tickets together with symmetric-key cryptography, enabling authentication without sending passwords over the network. Kerberos authentication supports userless and passwordless scenarios.
69 |
70 |
71 | By default, a session naively accepts whatever username the client provides.
72 |
73 | Plugins are provided to the server by implementing [`mysql_mimic.IdentityProvider`](mysql_mimic/auth.py), which configures all available plugins and a callback for fetching users.
74 |
75 | Custom plugins can be created by extending [`mysql_mimic.auth.AuthPlugin`](mysql_mimic/auth.py).
76 |
77 | ## Development
78 |
79 | You can install dependencies with `make deps`.
80 |
81 | You can format your code with `make format`.
82 |
83 | You can lint with `make lint`.
84 |
85 | You can check type annotations with `make types`.
86 |
87 | You can run tests with `make test`. This will build a coverage report in `./htmlcov/index.html`.
88 |
89 | You can run all the checks with `make check`.
90 |
91 | You can build a pip package with `make build`.
92 |
--------------------------------------------------------------------------------
/examples/auth_clear_password.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import asyncio
3 |
4 | from mysql_mimic import (
5 | MysqlServer,
6 | IdentityProvider,
7 | User,
8 | )
9 | from mysql_mimic.auth import AbstractClearPasswordAuthPlugin
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | # Storing plain text passwords is not safe.
15 | # This is just done for demonstration purposes.
16 | USERS = {
17 | "user1": "password1",
18 | "user2": "password2",
19 | }
20 |
21 |
22 | class CustomAuthPlugin(AbstractClearPasswordAuthPlugin):
23 | name = "custom_plugin"
24 |
25 | async def check(self, username, password):
26 | return username if USERS.get(username) == password else None
27 |
28 |
29 | class CustomIdentityProvider(IdentityProvider):
30 | def get_plugins(self):
31 | return [CustomAuthPlugin()]
32 |
33 | async def get_user(self, username):
34 | # Because we're storing users/passwords in an external system (the USERS dictionary, in this case),
35 | # we just assume all users exist.
36 | return User(name=username, auth_plugin=CustomAuthPlugin.name)
37 |
38 |
39 | async def main():
40 | logging.basicConfig(level=logging.DEBUG)
41 | identity_provider = CustomIdentityProvider()
42 | server = MysqlServer(identity_provider=identity_provider)
43 | await server.serve_forever()
44 |
45 |
46 | if __name__ == "__main__":
47 | asyncio.run(main())
48 | # To connect with the MySQL command line interface:
49 | # mysql -h127.0.0.1 -P3306 -uuser1 -ppassword1 --enable-cleartext-plugin
50 |
--------------------------------------------------------------------------------
/examples/auth_native_password.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import asyncio
3 |
4 | from mysql_mimic import (
5 | MysqlServer,
6 | IdentityProvider,
7 | NativePasswordAuthPlugin,
8 | User,
9 | )
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class CustomIdentityProvider(IdentityProvider):
15 | def __init__(self, passwords):
16 | # Storing passwords in plain text isn't safe.
17 | # This is done for demonstration purposes.
18 | # It's better to store the password hash, as returned by `NativePasswordAuthPlugin.create_auth_string`
19 | self.passwords = passwords
20 |
21 | def get_plugins(self):
22 | return [NativePasswordAuthPlugin()]
23 |
24 | async def get_user(self, username):
25 | password = self.passwords.get(username)
26 | if password:
27 | return User(
28 | name=username,
29 | auth_string=NativePasswordAuthPlugin.create_auth_string(password),
30 | auth_plugin=NativePasswordAuthPlugin.name,
31 | )
32 | return None
33 |
34 |
35 | async def main():
36 | logging.basicConfig(level=logging.DEBUG)
37 | identity_provider = CustomIdentityProvider(passwords={"user": "password"})
38 | server = MysqlServer(identity_provider=identity_provider)
39 | await server.serve_forever()
40 |
41 |
42 | if __name__ == "__main__":
43 | asyncio.run(main())
44 |
--------------------------------------------------------------------------------
/examples/dbapi_proxy.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import asyncio
3 | import sqlite3
4 |
5 | from mysql_mimic import MysqlServer, Session
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class DbapiProxySession(Session):
11 | def __init__(self):
12 | super().__init__()
13 | self.conn = sqlite3.connect(":memory:")
14 |
15 | async def query(self, expression, sql, attrs):
16 | cursor = self.conn.cursor()
17 | cursor.execute(expression.sql(dialect="sqlite"))
18 | try:
19 | rows = cursor.fetchall()
20 | if cursor.description:
21 | columns = [c[0] for c in cursor.description]
22 | return rows, columns
23 | return None
24 | finally:
25 | cursor.close()
26 |
27 |
28 | async def main():
29 | logging.basicConfig(level=logging.DEBUG)
30 | server = MysqlServer(session_factory=DbapiProxySession)
31 | await server.serve_forever()
32 |
33 |
34 | if __name__ == "__main__":
35 | asyncio.run(main())
36 |
--------------------------------------------------------------------------------
/examples/simple.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import asyncio
3 | from sqlglot.executor import execute
4 |
5 | from mysql_mimic import MysqlServer, Session
6 |
7 |
8 | SCHEMA = {
9 | "test": {
10 | "x": {
11 | "a": "INT",
12 | }
13 | }
14 | }
15 |
16 | TABLES = {
17 | "test": {
18 | "x": [
19 | {"a": 1},
20 | {"a": 2},
21 | {"a": 3},
22 | ]
23 | }
24 | }
25 |
26 |
27 | class MySession(Session):
28 | async def query(self, expression, sql, attrs):
29 | result = execute(expression, schema=SCHEMA, tables=TABLES)
30 | return result.rows, result.columns
31 |
32 | async def schema(self):
33 | return SCHEMA
34 |
35 |
36 | async def main():
37 | logging.basicConfig(level=logging.DEBUG)
38 | server = MysqlServer(session_factory=MySession)
39 | await server.serve_forever()
40 |
41 |
42 | if __name__ == "__main__":
43 | asyncio.run(main())
44 |
--------------------------------------------------------------------------------
/examples/streaming_results.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import asyncio
3 |
4 | from mysql_mimic import MysqlServer, Session
5 |
6 |
7 | class MySession(Session):
8 | async def generate_rows(self, n):
9 | for i in range(n):
10 | if i % 100 == 0:
11 | logging.info("Pretending to fetch another batch of results...")
12 | await asyncio.sleep(1)
13 | yield i,
14 |
15 | async def query(self, expression, sql, attrs):
16 | return self.generate_rows(1000), ["a"]
17 |
18 |
19 | async def main():
20 | logging.basicConfig(level=logging.DEBUG)
21 | server = MysqlServer(session_factory=MySession)
22 | await server.serve_forever()
23 |
24 |
25 | if __name__ == "__main__":
26 | asyncio.run(main())
27 |
--------------------------------------------------------------------------------
/integration/mysql-connector-j/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test
2 |
3 | test:
4 | mvn test
5 |
6 | test-debug:
7 | mvn -X test
8 |
--------------------------------------------------------------------------------
/integration/mysql-connector-j/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 | com.mysql_mimic.integration
8 | integration
9 | 1.0-SNAPSHOT
10 |
11 | integration
12 |
13 | http://www.example.com
14 |
15 |
16 | UTF-8
17 | 1.7
18 | 1.7
19 |
20 |
21 |
22 |
23 | junit
24 | junit
25 | 4.11
26 | test
27 |
28 |
29 | com.mysql
30 | mysql-connector-j
31 | 8.0.31
32 |
33 |
34 | com.google.protobuf
35 | protobuf-java
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | maven-clean-plugin
47 | 3.1.0
48 |
49 |
50 |
51 | maven-resources-plugin
52 | 3.0.2
53 |
54 |
55 | maven-compiler-plugin
56 | 3.8.0
57 |
58 |
59 | maven-surefire-plugin
60 | 2.22.1
61 |
62 |
63 | maven-jar-plugin
64 | 3.0.2
65 |
66 |
67 | maven-install-plugin
68 | 2.5.2
69 |
70 |
71 | maven-deploy-plugin
72 | 2.8.2
73 |
74 |
75 |
76 | maven-site-plugin
77 | 3.7.1
78 |
79 |
80 | maven-project-info-reports-plugin
81 | 3.0.0
82 |
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/integration/mysql-connector-j/src/main/java/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/barakalon/mysql-mimic/570cd6cb862a52fabbbdf15d801f4855d530c11d/integration/mysql-connector-j/src/main/java/.gitkeep
--------------------------------------------------------------------------------
/integration/mysql-connector-j/src/test/java/com/mysql_mimic/integration/IntegrationTest.java:
--------------------------------------------------------------------------------
1 | package com.mysql_mimic.integration;
2 |
3 | import static org.junit.Assert.assertEquals;
4 |
5 | import org.junit.Test;
6 | import java.sql.Connection;
7 | import java.sql.DriverManager;
8 | import java.sql.Statement;
9 | import java.sql.SQLException;
10 | import java.sql.ResultSet;
11 | import java.util.ArrayList;
12 | import java.util.List;
13 |
14 | public class IntegrationTest
15 | {
16 | @Test
17 | public void test() throws Exception {
18 | Class.forName("com.mysql.cj.jdbc.Driver").newInstance();
19 |
20 | String port = System.getenv("PORT");
21 | String url = String.format("jdbc:mysql://user:password@127.0.0.1:%s", port);
22 |
23 | Connection conn = DriverManager.getConnection(url);
24 |
25 | Statement stmt = conn.createStatement();
26 | ResultSet rs = stmt.executeQuery("SELECT a FROM x ORDER BY a");
27 |
28 | List result = new ArrayList<>();
29 | while (rs.next()) {
30 | result.add(rs.getInt("a"));
31 | }
32 |
33 | List expected = new ArrayList<>();
34 | expected.add(1);
35 | expected.add(2);
36 | expected.add(3);
37 |
38 | assertEquals(expected, result);
39 | }
40 |
41 | @Test
42 | public void testKrb5() throws Exception {
43 | // System.setProperty("sun.security.krb5.debug", "True");
44 | // System.setProperty("sun.security.jgss.debug", "True");
45 | System.setProperty("java.security.krb5.conf", System.getenv("KRB5_CONFIG"));
46 | System.setProperty("java.security.auth.login.config", System.getenv("JAAS_CONFIG"));
47 |
48 | Class.forName("com.mysql.cj.jdbc.Driver").newInstance();
49 |
50 | String port = System.getenv("PORT");
51 | String url = String.format("jdbc:mysql://127.0.0.1:%s/?defaultAuthenticationPlugin=authentication_kerberos_client", port);
52 |
53 | Connection conn = DriverManager.getConnection(url);
54 |
55 | Statement stmt = conn.createStatement();
56 | ResultSet rs = stmt.executeQuery("SELECT CURRENT_USER AS CURRENT_USER");
57 |
58 | rs.next();
59 |
60 | assertEquals("krb5_user", rs.getString("CURRENT_USER"));
61 | }
62 |
63 | @Test(expected = SQLException.class)
64 | public void testKrb5IncorrectUser() throws Exception {
65 | System.setProperty("java.security.krb5.conf", System.getenv("KRB5_CONFIG"));
66 | System.setProperty("java.security.auth.login.config", System.getenv("JAAS_CONFIG"));
67 |
68 | Class.forName("com.mysql.cj.jdbc.Driver").newInstance();
69 |
70 | String user = "incorrect_user";
71 | String port = System.getenv("PORT");
72 | String url = String.format("jdbc:mysql://%s@127.0.0.1:%s/?defaultAuthenticationPlugin=authentication_kerberos_client", user, port);
73 |
74 | Connection conn = DriverManager.getConnection(url);
75 |
76 | Statement stmt = conn.createStatement();
77 | stmt.executeQuery("SELECT a FROM x ORDER BY a");
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/integration/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import asyncio
4 | import os
5 | import sys
6 | import tempfile
7 | import time
8 |
9 | import k5test
10 | from sqlglot.executor import execute
11 |
12 | from mysql_mimic import (
13 | MysqlServer,
14 | IdentityProvider,
15 | NativePasswordAuthPlugin,
16 | User,
17 | Session,
18 | )
19 | from mysql_mimic.auth import KerberosAuthPlugin
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | SCHEMA = {
24 | "test": {
25 | "x": {
26 | "a": "INT",
27 | }
28 | }
29 | }
30 |
31 | TABLES = {
32 | "test": {
33 | "x": [
34 | {"a": 1},
35 | {"a": 2},
36 | {"a": 3},
37 | ]
38 | }
39 | }
40 |
41 |
42 | class MySession(Session):
43 | async def query(self, expression, sql, attrs):
44 | result = execute(expression, schema=SCHEMA, tables=TABLES)
45 | return result.rows, result.columns
46 |
47 | async def schema(self):
48 | return SCHEMA
49 |
50 |
51 | class CustomIdentityProvider(IdentityProvider):
52 | def __init__(self, krb5_service, krb5_realm):
53 | self.users = {
54 | "user": {"auth_plugin": "mysql_native_password", "password": "password"},
55 | "krb5_user": {"auth_plugin": "authentication_kerberos"},
56 | }
57 | self.krb5_service = krb5_service
58 | self.krb5_realm = krb5_realm
59 |
60 | def get_plugins(self):
61 | return [
62 | NativePasswordAuthPlugin(),
63 | KerberosAuthPlugin(service=self.krb5_service, realm=self.krb5_realm),
64 | ]
65 |
66 | async def get_user(self, username):
67 | user = self.users.get(username)
68 | if user is not None:
69 | auth_plugin = user["auth_plugin"]
70 |
71 | if auth_plugin == "mysql_native_password":
72 | password = user.get("password")
73 | return User(
74 | name=username,
75 | auth_string=(
76 | NativePasswordAuthPlugin.create_auth_string(password)
77 | if password
78 | else None
79 | ),
80 | auth_plugin=NativePasswordAuthPlugin.name,
81 | )
82 | elif auth_plugin == "authentication_kerberos":
83 | return User(name=username, auth_plugin=KerberosAuthPlugin.name)
84 | return None
85 |
86 |
87 | async def wait_for_port(port, host="localhost", timeout=5.0):
88 | start_time = time.time()
89 | while True:
90 | try:
91 | _ = await asyncio.open_connection(host, port)
92 | break
93 | except OSError:
94 | await asyncio.sleep(0.01)
95 | if time.time() - start_time >= timeout:
96 | raise TimeoutError()
97 |
98 |
99 | def setup_krb5(krb5_user):
100 | realm = k5test.K5Realm()
101 | krb5_user_princ = f"{krb5_user}@{realm.realm}"
102 | realm.addprinc(krb5_user_princ, realm.password(krb5_user))
103 | realm.kinit(krb5_user_princ, realm.password(krb5_user))
104 | return realm
105 |
106 |
107 | def write_jaas_conf(realm, debug=False):
108 | conf = f"""
109 | MySQLConnectorJ {{
110 | com.sun.security.auth.module.Krb5LoginModule
111 | required
112 | debug={str(debug).lower()}
113 | useTicketCache=true
114 | ticketCache="{realm.env["KRB5CCNAME"]}";
115 | }};
116 | """
117 | path = tempfile.mktemp()
118 | with open(path, "w") as f:
119 | f.write(conf)
120 |
121 | return path
122 |
123 |
124 | async def main():
125 | parser = argparse.ArgumentParser()
126 | parser.add_argument("test_dir")
127 | parser.add_argument("-p", "--port", type=int, default=3308)
128 | args = parser.parse_args()
129 |
130 | logging.basicConfig(level=logging.DEBUG)
131 | realm = None
132 | jaas_conf = None
133 |
134 | try:
135 | realm = setup_krb5(krb5_user="krb5_user")
136 | os.environ.update(realm.env)
137 |
138 | jaas_conf = write_jaas_conf(realm)
139 | os.environ["JAAS_CONFIG"] = jaas_conf
140 |
141 | krb5_service = realm.host_princ[: realm.host_princ.index("@")]
142 | identity_provider = CustomIdentityProvider(
143 | krb5_service=krb5_service, krb5_realm=realm.realm
144 | )
145 | server = MysqlServer(
146 | identity_provider=identity_provider, session_factory=MySession
147 | )
148 |
149 | await server.start_server(port=args.port)
150 | await wait_for_port(port=args.port)
151 | process = await asyncio.create_subprocess_shell(
152 | "make test",
153 | env={**os.environ, "PORT": str(args.port)},
154 | cwd=args.test_dir,
155 | )
156 | return_code = await process.wait()
157 |
158 | server.close()
159 | await server.wait_closed()
160 | return return_code
161 | finally:
162 | if realm:
163 | realm.stop()
164 | if jaas_conf:
165 | os.remove(jaas_conf)
166 |
167 |
168 | if __name__ == "__main__":
169 | sys.exit(asyncio.run(main()))
170 |
--------------------------------------------------------------------------------
/mysql_mimic/__init__.py:
--------------------------------------------------------------------------------
1 | """Implementation of the mysql server wire protocol"""
2 |
3 | from mysql_mimic.auth import (
4 | User,
5 | IdentityProvider,
6 | NativePasswordAuthPlugin,
7 | NoLoginAuthPlugin,
8 | AuthPlugin,
9 | )
10 | from mysql_mimic.results import AllowedResult, ResultColumn, ResultSet
11 | from mysql_mimic.session import Session
12 | from mysql_mimic.server import MysqlServer
13 | from mysql_mimic.types import ColumnType
14 |
--------------------------------------------------------------------------------
/mysql_mimic/auth.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 | from copy import copy
5 | from hashlib import sha1
6 | import logging
7 | from dataclasses import dataclass
8 | from typing import Optional, Dict, AsyncGenerator, Union, Tuple, Sequence
9 |
10 | from mysql_mimic.types import read_str_null
11 | from mysql_mimic import utils
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | # Many authentication plugins don't need to send any sort of challenge/nonce.
17 | FILLER = b"0" * 20 + b"\x00"
18 |
19 |
20 | @dataclass
21 | class Forbidden:
22 | msg: Optional[str] = None
23 |
24 |
25 | @dataclass
26 | class Success:
27 | authenticated_as: str
28 |
29 |
30 | @dataclass
31 | class User:
32 | name: str
33 | auth_string: Optional[str] = None
34 | auth_plugin: Optional[str] = None
35 |
36 | # For plugins that support a primary and secondary password
37 | # This is helpful for zero-downtime password rotation
38 | old_auth_string: Optional[str] = None
39 |
40 |
41 | @dataclass
42 | class AuthInfo:
43 | username: str
44 | data: bytes
45 | user: User
46 | connect_attrs: Dict[str, str]
47 | client_plugin_name: Optional[str]
48 | handshake_auth_data: Optional[bytes]
49 | handshake_plugin_name: str
50 |
51 | def copy(self, data: bytes) -> AuthInfo:
52 | new = copy(self)
53 | new.data = data
54 | return new
55 |
56 |
57 | Decision = Union[Success, Forbidden, bytes]
58 | AuthState = AsyncGenerator[Decision, AuthInfo]
59 |
60 |
class AuthPlugin:
    """
    Abstract base class for authentication plugins.

    Concrete plugins implement `auth` as an async generator that converses
    with the client until it produces a terminal decision.
    """

    name = ""
    client_plugin_name: Optional[str] = None  # None means any

    async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
        """
        Create an async generator that drives the authentication lifecycle.

        Yield `bytes` to send an AuthMoreData packet to the client, or a
        `Success`/`Forbidden` instance to finish authentication.

        Args:
            auth_info: None when authentication starts from the optimistic handshake.
        """
        # Default implementation rejects everyone.
        yield Forbidden()

    async def start(
        self, auth_info: Optional[AuthInfo] = None
    ) -> Tuple[Decision, AuthState]:
        """Begin the auth conversation, returning its first step and the generator."""
        conversation = self.auth(auth_info)
        first_step = await conversation.__anext__()
        return first_step, conversation
87 |
88 |
class AbstractClearPasswordAuthPlugin(AuthPlugin):
    """
    Abstract class for implementing the server-side of the standard client plugin "mysql_clear_password".

    Subclasses override `check` to validate the cleartext password.
    """

    name = "abstract_mysql_clear_password"
    client_plugin_name = "mysql_clear_password"

    async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
        if auth_info is None:
            # No optimistic handshake data yet: the challenge is irrelevant for
            # cleartext auth, so send filler and wait for the client's reply.
            auth_info = yield FILLER

        # The client sends the password as a null-terminated string.
        reader = io.BytesIO(auth_info.data)
        password = read_str_null(reader).decode()

        authenticated_as = await self.check(auth_info.username, password)
        if authenticated_as is None:
            yield Forbidden()
        else:
            yield Success(authenticated_as)

    async def check(self, username: str, password: str) -> Optional[str]:
        """Validate the password; return the name to authenticate as, or None to reject."""
        return username
111 |
112 |
class NativePasswordAuthPlugin(AuthPlugin):
    """
    Standard plugin that uses a password hashing method.

    The client hashes its password with a server-provided nonce, so the
    cleartext password never crosses the network. Because the stored auth
    string is SHA1(SHA1(password)), knowing it is not by itself enough to
    authenticate as that user.
    """

    name = "mysql_native_password"
    client_plugin_name = "mysql_native_password"

    async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
        reuse_handshake_nonce = (
            auth_info is not None
            and auth_info.handshake_plugin_name == self.name
            and bool(auth_info.handshake_auth_data)
        )
        if reuse_handshake_nonce:
            # mysql_native_password can reuse the nonce from the initial handshake
            nonce = auth_info.handshake_auth_data.rstrip(b"\x00")
        else:
            nonce = utils.nonce(20)
            # Some clients expect a null terminating byte
            auth_info = yield nonce + b"\x00"

        user = auth_info.user
        if self.password_matches(user=user, scramble=auth_info.data, nonce=nonce):
            yield Success(user.name)
        else:
            yield Forbidden()

    def empty_password_quickpath(self, user: User, scramble: bytes) -> bool:
        """True when both the client and the stored account present no password."""
        return not (scramble or user.auth_string)

    def password_matches(self, user: User, scramble: bytes, nonce: bytes) -> bool:
        """Check the client's scramble against the current, then old, auth string."""
        if self.empty_password_quickpath(user, scramble):
            return True
        return any(
            self.verify_scramble(candidate, scramble, nonce)
            for candidate in (user.auth_string, user.old_auth_string)
        )

    def verify_scramble(
        self, auth_string: Optional[str], scramble: bytes, nonce: bytes
    ) -> bool:
        """
        Verify one auth string against the client's scramble.

        Per the protocol docs:
            scramble    = SHA1(password) XOR SHA1(nonce + SHA1(SHA1(password)))
            auth_string = hex(SHA1(SHA1(password)))
        Unmasking SHA1(password) from the scramble and hashing it once more
        must reproduce the stored hash.
        """
        try:
            stored_hash = bytes.fromhex(auth_string or "")
            mask = sha1(nonce + stored_hash).digest()
            client_sha1_password = utils.xor(scramble, mask)
            return sha1(client_sha1_password).digest() == stored_hash
        except Exception:  # pylint: disable=broad-except
            logger.info("Invalid scramble")
            return False

    @classmethod
    def create_auth_string(cls, password: str) -> str:
        """Produce the value to store in `User.auth_string` for this plugin."""
        return sha1(sha1(password.encode("utf-8")).digest()).hexdigest()
176 |
177 |
class KerberosAuthPlugin(AuthPlugin):
    """
    This plugin implements the Generic Security Service Application Program Interface (GSS-API) by way of the Kerberos
    mechanism as described in RFC1964(https://www.rfc-editor.org/rfc/rfc1964.html).
    """

    name = "authentication_kerberos"
    client_plugin_name = "authentication_kerberos_client"

    def __init__(self, service: str, realm: str) -> None:
        # Kerberos service principal components: the context accepts tokens
        # for "{service}@{realm}".
        self.service = service
        self.realm = realm

    async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
        # Imported lazily so gssapi stays an optional dependency.
        import gssapi
        from gssapi.exceptions import GSSError

        # Fast authentication not supported
        if not auth_info:
            yield b""

        # Send the service and realm, each prefixed with a 2-byte
        # little-endian length; the client builds its Kerberos token from these.
        auth_info = (
            yield len(self.service).to_bytes(2, "little")
            + self.service.encode("utf-8")
            + len(self.realm).to_bytes(2, "little")
            + self.realm.encode("utf-8")
        )

        server_creds = gssapi.Credentials(
            usage="accept", name=gssapi.Name(f"{self.service}@{self.realm}")
        )
        server_ctx = gssapi.SecurityContext(usage="accept", creds=server_creds)

        try:
            # Feed the client's Kerberos token into the acceptor context.
            server_ctx.step(auth_info.data)
        except GSSError as e:
            # NOTE(review): yielding Forbidden ends the exchange; this assumes the
            # driver never resumes the generator afterwards — TODO confirm.
            yield Forbidden(str(e))

        # Kerberos principals are "user@REALM"; MySQL usernames omit the realm.
        username = str(server_ctx.initiator_name).split("@", 1)[0]
        if auth_info.username and auth_info.username != username:
            yield Forbidden("Given username different than kerberos client")
        yield Success(username)
220 |
221 |
class NoLoginAuthPlugin(AuthPlugin):
    """
    Standard plugin that prevents all clients from direct login.

    This is useful for user accounts that can only be accessed by proxy authentication.
    """

    name = "mysql_no_login"
    client_plugin_name = None

    async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
        # When starting from the optimistic handshake, complete one round trip
        # (sending meaningless filler) before rejecting; the reply is discarded.
        if not auth_info:
            _ = yield FILLER
        # Unconditionally reject direct logins.
        yield Forbidden()
236 |
237 |
class IdentityProvider:
    """
    Abstract base class for an identity provider.

    An identity provider tells the server which authentication plugins to make
    available to clients and how to retrieve users.
    """

    def get_plugins(self) -> Sequence[AuthPlugin]:
        """Return the auth plugins this server offers, in preference order."""
        return [NativePasswordAuthPlugin(), NoLoginAuthPlugin()]

    async def get_user(self, username: str) -> Optional[User]:
        """Look up a user record by name; None means the user doesn't exist."""
        return None

    def get_default_plugin(self) -> AuthPlugin:
        """The plugin proposed in the initial handshake: first in the list."""
        return self.get_plugins()[0]

    def get_plugin(self, name: str) -> Optional[AuthPlugin]:
        """Find an offered plugin by name, or None if it isn't available."""
        for plugin in self.get_plugins():
            if plugin.name == name:
                return plugin
        return None
260 |
261 |
class SimpleIdentityProvider(IdentityProvider):
    """
    Simple identity provider implementation that naively accepts whatever username a client provides.
    """

    async def get_user(self, username: str) -> Optional[User]:
        # Every username "exists" with no auth_string, so clients presenting an
        # empty scramble pass the native plugin's empty-password quickpath.
        return User(name=username, auth_plugin=NativePasswordAuthPlugin.name)
269 |
--------------------------------------------------------------------------------
/mysql_mimic/constants.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from enum import auto, Enum
4 |
5 | from mysql_mimic.types import Capabilities
6 |
# Capability flags the server advertises in its initial handshake. The
# effective capability set for a connection is the intersection of these with
# the flags the client sends back in its handshake response.
DEFAULT_SERVER_CAPABILITIES = (
    Capabilities.CLIENT_PROTOCOL_41
    | Capabilities.CLIENT_DEPRECATE_EOF
    | Capabilities.CLIENT_CONNECT_WITH_DB
    | Capabilities.CLIENT_QUERY_ATTRIBUTES
    | Capabilities.CLIENT_CONNECT_ATTRS
    | Capabilities.CLIENT_PLUGIN_AUTH
    | Capabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
    | Capabilities.CLIENT_SECURE_CONNECTION
    | Capabilities.CLIENT_LONG_PASSWORD
    | Capabilities.CLIENT_ODBC
    | Capabilities.CLIENT_INTERACTIVE
    | Capabilities.CLIENT_IGNORE_SPACE
)
21 |
22 |
class KillKind(Enum):
    """How a KILL request should terminate its target."""

    # Terminate the statement the connection is currently executing, but leave the connection itself intact
    QUERY = auto()
    # Terminate the connection, after terminating any statement the connection is executing
    CONNECTION = auto()
28 |
29 |
30 | INFO_SCHEMA = {
31 | "information_schema": {
32 | "character_sets": {
33 | "character_set_name": "TEXT",
34 | "default_collate_name": "TEXT",
35 | "description": "TEXT",
36 | "maxlen": "INT",
37 | },
38 | "collations": {
39 | "collation_name": "TEXT",
40 | "character_set_name": "TEXT",
41 | "id": "TEXT",
42 | "is_default": "TEXT",
43 | "is_compiled": "TEXT",
44 | "sortlen": "INT",
45 | },
46 | "columns": {
47 | "table_catalog": "TEXT",
48 | "table_schema": "TEXT",
49 | "table_name": "TEXT",
50 | "column_name": "TEXT",
51 | "ordinal_position": "INT",
52 | "column_default": "TEXT",
53 | "is_nullable": "TEXT",
54 | "data_type": "TEXT",
55 | "character_maximum_length": "INT",
56 | "character_octet_length": "INT",
57 | "numeric_precision": "INT",
58 | "numeric_scale": "INT",
59 | "datetime_precision": "INT",
60 | "character_set_name": "TEXT",
61 | "collation_name": "TEXT",
62 | "column_type": "TEXT",
63 | "column_key": "TEXT",
64 | "extra": "TEXT",
65 | "privileges": "TEXT",
66 | "column_comment": "TEXT",
67 | "generation_expression": "TEXT",
68 | "srs_id": "TEXT",
69 | },
70 | "column_privileges": {
71 | "grantee": "TEXT",
72 | "table_catalog": "TEXT",
73 | "table_schema": "TEXT",
74 | "table_name": "TEXT",
75 | "column_name": "TEXT",
76 | "privilege_type": "TEXT",
77 | "is_grantable": "TEXT",
78 | },
79 | "events": {
80 | "event_catalog": "TEXT",
81 | "event_schema": "TEXT",
82 | "event_name": "TEXT",
83 | "definer": "TEXT",
84 | "time_zone": "TEXT",
85 | "event_body": "TEXT",
86 | "event_definition": "TEXT",
87 | "event_type": "TEXT",
88 | "execute_at": "TEXT",
89 | "interval_value": "INT",
90 | "interval_field": "TEXT",
91 | "sql_mode": "TEXT",
92 | "starts": "TEXT",
93 | "ends": "TEXT",
94 | "status": "TEXT",
95 | "on_completion": "TEXT",
96 | "created": "TEXT",
97 | "last_altered": "TEXT",
98 | "last_executed": "TEXT",
99 | "event_comment": "TEXT",
100 | "originator": "INT",
101 | "character_set_client": "TEXT",
102 | "collation_connection": "TEXT",
103 | "database_collation": "TEXT",
104 | },
105 | "key_column_usage": {
106 | "constraint_catalog": "TEXT",
107 | "constraint_schema": "TEXT",
108 | "constraint_name": "TEXT",
109 | "table_catalog": "TEXT",
110 | "table_schema": "TEXT",
111 | "table_name": "TEXT",
112 | "column_name": "TEXT",
113 | "ordinal_position": "INT",
114 | "position_in_unique_constraint": "INT",
115 | "referenced_table_schema": "TEXT",
116 | "referenced_table_name": "TEXT",
117 | "referenced_column_name": "TEXT",
118 | },
119 | "parameters": {
120 | "specific_catalog": "TEXT",
121 | "specific_schema": "TEXT",
122 | "specific_name": "TEXT",
123 | "ordinal_position": "INT",
124 | "parameter_mode": "TEXT",
125 | "parameter_name": "TEXT",
126 | "data_type": "TEXT",
127 | "character_maximum_length": "INT",
128 | "character_octet_length": "INT",
129 | "numeric_precision": "INT",
130 | "numeric_scale": "INT",
131 | "datetime_precision": "INT",
132 | "character_set_name": "TEXT",
133 | "collation_name": "TEXT",
134 | "dtd_identifier": "TEXT",
135 | "routine_type": "TEXT",
136 | },
137 | "partitions": {
138 | "table_catalog": "TEXT",
139 | "table_schema": "TEXT",
140 | "table_name": "TEXT",
141 | "partition_name": "TEXT",
142 | "subpartition_name": "TEXT",
143 | "partition_ordinal_position": "INT",
144 | "subpartition_ordinal_position": "INT",
145 | "partition_method": "TEXT",
146 | "subpartition_method": "TEXT",
147 | "partition_expression": "TEXT",
148 | "subpartition_expression": "TEXT",
149 | "partition_description": "TEXT",
150 | "table_rows": "INT",
151 | "avg_row_length": "INT",
152 | "data_length": "INT",
153 | "max_data_length": "INT",
154 | "index_length": "INT",
155 | "data_free": "INT",
156 | "create_time": "TEXT",
157 | "update_time": "TEXT",
158 | "check_time": "TEXT",
159 | "checksum": "TEXT",
160 | "partition_comment": "TEXT",
161 | "nodegroup": "TEXT",
162 | "tablespace_name": "TEXT",
163 | },
164 | "referential_constraints": {
165 | "constraint_catalog": "TEXT",
166 | "constraint_schema": "TEXT",
167 | "constraint_name": "TEXT",
168 | "unique_constraint_catalog": "TEXT",
169 | "unique_constraint_schema": "TEXT",
170 | "unique_constraint_name": "TEXT",
171 | "match_option": "TEXT",
172 | "update_rule": "TEXT",
173 | "delete_rule": "TEXT",
174 | "table_name": "TEXT",
175 | "referenced_table_name": "TEXT",
176 | },
177 | "routines": {
178 | "specific_name": "TEXT",
179 | "routine_catalog": "TEXT",
180 | "routine_schema": "TEXT",
181 | "routine_name": "TEXT",
182 | "routine_type": "TEXT",
183 | "data_type": "TEXT",
184 | "character_maximum_length": "INT",
185 | "character_octet_length": "INT",
186 | "numeric_precision": "INT",
187 | "numeric_scale": "INT",
188 | "datetime_precision": "INT",
189 | "character_set_name": "TEXT",
190 | "collation_name": "TEXT",
191 | "dtd_identifier": "TEXT",
192 | "routine_body": "TEXT",
193 | "routine_definition": "TEXT",
194 | "external_name": "TEXT",
195 | "external_language": "TEXT",
196 | "parameter_style": "TEXT",
197 | "is_deterministic": "TEXT",
198 | "sql_data_access": "TEXT",
199 | "sql_path": "TEXT",
200 | "security_type": "TEXT",
201 | "created": "TEXT",
202 | "last_altered": "TEXT",
203 | "sql_mode": "TEXT",
204 | "routine_comment": "TEXT",
205 | "definer": "TEXT",
206 | "character_set_client": "TEXT",
207 | "collation_connection": "TEXT",
208 | "database_collation": "TEXT",
209 | },
210 | "schema_privileges": {
211 | "grantee": "TEXT",
212 | "table_catalog": "TEXT",
213 | "table_schema": "TEXT",
214 | "privilege_type": "TEXT",
215 | "is_grantable": "TEXT",
216 | },
217 | "schemata": {
218 | "catalog_name": "TEXT",
219 | "schema_name": "TEXT",
220 | "default_character_set_name": "TEXT",
221 | "default_collation_name": "TEXT",
222 | "sql_path": "TEXT",
223 | },
224 | "statistics": {
225 | "table_catalog": "TEXT",
226 | "table_schema": "TEXT",
227 | "table_name": "TEXT",
228 | "non_unique": "INT",
229 | "index_schema": "TEXT",
230 | "index_name": "TEXT",
231 | "seq_in_index": "INT",
232 | "column_name": "TEXT",
233 | "collation": "TEXT",
234 | "cardinality": "INT",
235 | "sub_part": "TEXT",
236 | "packed": "TEXT",
237 | "nullable": "TEXT",
238 | "index_type": "TEXT",
239 | "comment": "TEXT",
240 | "index_comment": "TEXT",
241 | "is_visible": "TEXT",
242 | "expression": "TEXT",
243 | },
244 | "tables": {
245 | "table_catalog": "TEXT",
246 | "table_schema": "TEXT",
247 | "table_name": "TEXT",
248 | "table_type": "TEXT",
249 | "engine": "TEXT",
250 | "version": "TEXT",
251 | "row_format": "TEXT",
252 | "table_rows": "INT",
253 | "avg_row_length": "INT",
254 | "data_length": "INT",
255 | "max_data_length": "INT",
256 | "index_length": "INT",
257 | "data_free": "INT",
258 | "auto_increment": "INT",
259 | "create_time": "TEXT",
260 | "update_time": "TEXT",
261 | "check_time": "TEXT",
262 | "table_collation": "TEXT",
263 | "checksum": "TEXT",
264 | "create_options": "TEXT",
265 | "table_comment": "TEXT",
266 | },
267 | "table_constraints": {
268 | "constraint_name": "TEXT",
269 | "constraint_catalog": "TEXT",
270 | "constraint_schema": "TEXT",
271 | "table_schema": "TEXT",
272 | "table_name": "TEXT",
273 | "constraint_type": "TEXT",
274 | "enforced": "TEXT",
275 | },
276 | "table_privileges": {
277 | "grantee": "TEXT",
278 | "table_catalog": "TEXT",
279 | "table_schema": "TEXT",
280 | "table_name": "TEXT",
281 | "privilege_type": "TEXT",
282 | "is_grantable": "TEXT",
283 | },
284 | "triggers": {
285 | "trigger_catalog": "TEXT",
286 | "trigger_schema": "TEXT",
287 | "trigger_name": "TEXT",
288 | "event_manipulation": "TEXT",
289 | "event_object_catalog": "TEXT",
290 | "event_object_schema": "TEXT",
291 | "event_object_table": "TEXT",
292 | "action_order": "INT",
293 | "action_condition": "TEXT",
294 | "action_statement": "TEXT",
295 | "action_orientation": "TEXT",
296 | "action_timing": "TEXT",
297 | "action_reference_old_table": "TEXT",
298 | "action_reference_new_table": "TEXT",
299 | "action_reference_old_row": "TEXT",
300 | "action_reference_new_row": "TEXT",
301 | "created": "TEXT",
302 | "sql_mode": "TEXT",
303 | "definer": "TEXT",
304 | "character_set_client": "TEXT",
305 | "collation_connection": "TEXT",
306 | "database_collation": "TEXT",
307 | },
308 | "user_privileges": {
309 | "grantee": "TEXT",
310 | "table_catalog": "TEXT",
311 | "privilege_type": "TEXT",
312 | "is_grantable": "TEXT",
313 | },
314 | "views": {
315 | "table_catalog": "TEXT",
316 | "table_schema": "TEXT",
317 | "table_name": "TEXT",
318 | "view_definition": "TEXT",
319 | "check_option": "TEXT",
320 | "is_updatable": "TEXT",
321 | "definer": "TEXT",
322 | "security_type": "TEXT",
323 | "character_set_client": "TEXT",
324 | "collation_connection": "TEXT",
325 | },
326 | },
327 | "mysql": {
328 | "procs_priv": {
329 | "host": "TEXT",
330 | "db": "TEXT",
331 | "user": "TEXT",
332 | "routine_name": "TEXT",
333 | "routine_type": "TEXT",
334 | "proc_priv": "TEXT",
335 | "timestamp": "TEXT",
336 | "grantor": "TEXT",
337 | },
338 | "role_edges": {
339 | "from_host": "TEXT",
340 | "from_user": "TEXT",
341 | "to_host": "TEXT",
342 | "to_user": "TEXT",
343 | "with_admin_option": "TEXT",
344 | },
345 | "user": {
346 | "host": "TEXT",
347 | "user": "TEXT",
348 | "select_priv": "TEXT",
349 | "insert_priv": "TEXT",
350 | "update_priv": "TEXT",
351 | "delete_priv": "TEXT",
352 | "index_priv": "TEXT",
353 | "alter_priv": "TEXT",
354 | "create_priv": "TEXT",
355 | "drop_priv": "TEXT",
356 | "grant_priv": "TEXT",
357 | "create_view_priv": "TEXT",
358 | "show_view_priv": "TEXT",
359 | "create_routine_priv": "TEXT",
360 | "alter_routine_priv": "TEXT",
361 | "execute_priv": "TEXT",
362 | "trigger_priv": "TEXT",
363 | "event_priv": "TEXT",
364 | "create_tmp_table_priv": "TEXT",
365 | "lock_tables_priv": "TEXT",
366 | "references_priv": "TEXT",
367 | "reload_priv": "TEXT",
368 | "shutdown_priv": "TEXT",
369 | "process_priv": "TEXT",
370 | "file_priv": "TEXT",
371 | "show_db_priv": "TEXT",
372 | "super_priv": "TEXT",
373 | "repl_slave_priv": "TEXT",
374 | "repl_client_priv": "TEXT",
375 | "create_user_priv": "TEXT",
376 | "create_tablespace_priv": "TEXT",
377 | "create_role_priv": "TEXT",
378 | "drop_role_priv": "TEXT",
379 | "ssl_type": "TEXT",
380 | "ssl_cipher": "TEXT",
381 | "x509_issuer": "TEXT",
382 | "x509_subject": "TEXT",
383 | "plugin": "TEXT",
384 | "authentication_string": "TEXT",
385 | "password_expired": "TEXT",
386 | "password_last_changed": "TEXT",
387 | "password_lifetime": "TEXT",
388 | "account_locked": "TEXT",
389 | "password_reuse_history": "TEXT",
390 | "password_reuse_time": "TEXT",
391 | "password_require_current": "TEXT",
392 | "user_attributes": "TEXT",
393 | "max_questions": "INT",
394 | "max_updates": "INT",
395 | "max_connections": "INT",
396 | "max_user_connections": "INT",
397 | },
398 | },
399 | }
400 |
--------------------------------------------------------------------------------
/mysql_mimic/context.py:
--------------------------------------------------------------------------------
from contextvars import ContextVar

# Context-local ID of the MySQL connection currently being served, so code deep
# in the call stack can retrieve it without threading it through every call.
connection_id: ContextVar[int] = ContextVar("connection_id")
4 |
--------------------------------------------------------------------------------
/mysql_mimic/control.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import random
4 | from typing import Dict, Optional, TYPE_CHECKING
5 |
6 | from mysql_mimic.constants import KillKind
7 | from mysql_mimic.utils import seq
8 |
9 | if TYPE_CHECKING:
10 | from mysql_mimic.connection import Connection
11 |
12 |
class TooManyConnections(Exception):
    """Raised when no free connection ID remains (the 16-bit sequence space is full)."""

    pass
15 |
16 |
class Control:
    """
    Base class for controlling server state.

    This is intended to encapsulate operations that might span deployment of multiple
    MySQL mimic server instances.
    """

    async def add(self, connection: Connection) -> int:
        """
        Add a new connection

        Args:
            connection: the connection to register

        Returns:
            connection id
        """
        raise NotImplementedError()

    async def remove(self, connection_id: int) -> None:
        """Remove an existing connection"""
        raise NotImplementedError()

    async def kill(
        self, connection_id: int, kind: KillKind = KillKind.CONNECTION
    ) -> None:
        """Request termination of an existing connection"""
        raise NotImplementedError()
43 |
44 |
class LocalControl(Control):
    """
    Simple Control implementation that handles everything in-memory.

    Args:
        server_id: set a unique server ID. This is used to generate globally unique
            connection IDs. This should be an integer between 0 and 65535.
            If left as None, a random server ID will be generated.
    """

    _CONNECTION_ID_BITS = 16
    _MAX_CONNECTION_SEQ = 2**_CONNECTION_ID_BITS
    _MAX_SERVER_ID = 2**16

    def __init__(self, server_id: Optional[int] = None):
        self.server_id = server_id or random.randint(0, self._MAX_SERVER_ID - 1)
        self._connections: Dict[int, Connection] = {}
        self._connection_seq = seq(self._MAX_CONNECTION_SEQ)

    async def add(self, connection: Connection) -> int:
        """Register a connection and return its newly minted ID."""
        new_id = self._new_connection_id()
        self._connections[new_id] = connection
        return new_id

    async def remove(self, connection_id: int) -> None:
        """Forget a connection; unknown IDs are silently ignored."""
        self._connections.pop(connection_id, None)

    async def kill(
        self, connection_id: int, kind: KillKind = KillKind.CONNECTION
    ) -> None:
        """Forward a kill request to the targeted connection, if it exists."""
        target = self._connections.get(connection_id)
        if target:
            target.kill(kind)

    def _new_connection_id(self) -> int:
        """
        Generate a connection ID.

        MySQL connection IDs are 4 bytes. Since multiple MySQL-Mimic instances
        may be deployed, the top two bytes are this server's ID and the bottom
        two come from an in-process sequence:

        |<---server ID--->|<-conn sequence->|
        00000000 00000000 00000000 00000000

        Unless server IDs are assigned manually, uniqueness is probabilistic,
        but collisions should be highly unlikely.

        Raises:
            TooManyConnections: if every sequence slot is occupied.
        """
        if len(self._connections) >= self._MAX_CONNECTION_SEQ:
            raise TooManyConnections()

        prefix = (self.server_id % self._MAX_SERVER_ID) << self._CONNECTION_ID_BITS

        # Skip IDs still held by live connections, in the unlikely case that a
        # connection outlives a full wrap-around of the sequence.
        while True:
            candidate = prefix + next(self._connection_seq)
            if candidate not in self._connections:
                return candidate
111 |
--------------------------------------------------------------------------------
/mysql_mimic/errors.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 |
3 |
class ErrorCode(IntEnum):
    """https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html"""

    # Values intentionally mirror MySQL's numeric server error codes.
    CON_COUNT_ERROR = 1040
    HANDSHAKE_ERROR = 1043
    ACCESS_DENIED_ERROR = 1045
    NO_DB_ERROR = 1046
    PARSE_ERROR = 1064
    EMPTY_QUERY = 1065
    UNKNOWN_PROCEDURE = 1106
    UNKNOWN_SYSTEM_VARIABLE = 1193
    UNKNOWN_COM_ERROR = 1047
    UNKNOWN_ERROR = 1105
    WRONG_VALUE_FOR_VAR = 1231
    NOT_SUPPORTED_YET = 1235
    MALFORMED_PACKET = 1835
    USER_DOES_NOT_EXIST = 3162
    SESSION_WAS_KILLED = 3169
    PLUGIN_REQUIRES_REGISTRATION = 4055
23 |
24 |
# For more info, see https://dev.mysql.com/doc/refman/8.0/en/error-message-elements.html
# Codes not listed here fall back to the generic b"HY000" state (see get_sqlstate).
SQLSTATES = {
    ErrorCode.CON_COUNT_ERROR: b"08004",
    ErrorCode.ACCESS_DENIED_ERROR: b"28000",
    ErrorCode.HANDSHAKE_ERROR: b"08S01",
    ErrorCode.NO_DB_ERROR: b"3D000",
    ErrorCode.PARSE_ERROR: b"42000",
    ErrorCode.EMPTY_QUERY: b"42000",
    ErrorCode.UNKNOWN_PROCEDURE: b"42000",
    ErrorCode.UNKNOWN_COM_ERROR: b"08S01",
    ErrorCode.WRONG_VALUE_FOR_VAR: b"42000",
    ErrorCode.NOT_SUPPORTED_YET: b"42000",
}
38 |
39 |
def get_sqlstate(code: ErrorCode) -> bytes:
    """Return the SQLSTATE for an error code, defaulting to the generic "HY000"."""
    try:
        return SQLSTATES[code]
    except KeyError:
        return b"HY000"
42 |
43 |
class MysqlError(Exception):
    """Server-side error carrying a MySQL error code alongside the message."""

    def __init__(self, msg: str, code: ErrorCode = ErrorCode.UNKNOWN_ERROR):
        # Render the numeric code into the exception text for logs/tracebacks.
        super().__init__(f"{code}: {msg}")
        self.msg = msg
        self.code = code
49 |
--------------------------------------------------------------------------------
/mysql_mimic/functions.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping, Callable
2 | from datetime import datetime
3 |
# A mapping of SQL function name to a zero-argument callable producing its value.
Functions = Mapping[str, Callable[[], Any]]


def mysql_datetime_function_mapping(timestamp: datetime) -> Functions:
    """
    Build MySQL's date/time functions, all pinned to the given timestamp.

    Aliases (CURRENT_TIMESTAMP, LOCALTIME, ...) share the exact same callable
    as their canonical name, so NOW() and CURRENT_TIMESTAMP() always agree.
    """

    def now() -> str:
        return timestamp.strftime("%Y-%m-%d %H:%M:%S")

    def curdate() -> str:
        return timestamp.strftime("%Y-%m-%d")

    def curtime() -> str:
        return timestamp.strftime("%H:%M:%S")

    return {
        "NOW": now,
        "CURDATE": curdate,
        "CURTIME": curtime,
        "CURRENT_TIMESTAMP": now,
        "LOCALTIME": now,
        "LOCALTIMESTAMP": now,
        "CURRENT_DATE": curdate,
        "CURRENT_TIME": curtime,
    }
23 |
--------------------------------------------------------------------------------
/mysql_mimic/intercept.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from sqlglot import expressions as exp
6 |
7 | from mysql_mimic.errors import MysqlError, ErrorCode
8 | from mysql_mimic.variables import DEFAULT
9 |
10 |
11 | # Mapping of transaction characteristic from SET TRANSACTION statements to their corresponding system variable
12 | TRANSACTION_CHARACTERISTICS = {
13 | "ISOLATION LEVEL REPEATABLE READ": ("transaction_isolation", "REPEATABLE-READ"),
14 | "ISOLATION LEVEL READ COMMITTED": ("transaction_isolation", "READ-COMMITTED"),
15 | "ISOLATION LEVEL READ UNCOMMITTED": ("transaction_isolation", "READ-UNCOMMITTED"),
16 | "ISOLATION LEVEL SERIALIZABLE": ("transaction_isolation", "SERIALIZABLE"),
17 | "READ WRITE": ("transaction_read_only", False),
18 | "READ ONLY": ("transaction_read_only", True),
19 | }
20 |
21 |
def setitem_kind(setitem: exp.SetItem) -> str:
    """
    Classify a SET statement item (e.g. VARIABLE, NAMES, CHARACTER SET).

    Scope modifiers (GLOBAL, SESSION, ...) are not distinct kinds — they all
    denote a system-variable assignment, as does an absent kind.
    """
    scope_modifiers = ("GLOBAL", "PERSIST", "PERSIST_ONLY", "SESSION", "LOCAL")
    kind = setitem.text("kind")
    if not kind or kind in scope_modifiers:
        return "VARIABLE"
    return kind
31 |
32 |
def value_to_expression(value: Any) -> exp.Expression:
    """
    Convert a plain Python value into an equivalent sqlglot expression node.

    Booleans are handled before the numeric branch since bool subclasses int.
    """
    if isinstance(value, bool):
        return exp.true() if value else exp.false()
    if value is None:
        return exp.null()
    if isinstance(value, (int, float)):
        return exp.Literal.number(value)
    # Anything else is rendered as a string literal.
    return exp.Literal.string(str(value))
43 |
44 |
def expression_to_value(expression: exp.Expression) -> Any:
    """
    Convert a sqlglot expression back into a plain Python value.

    Inverse of `value_to_expression`, plus the SET-specific keywords
    DEFAULT/ON/OFF. Raises MysqlError for anything more complex.
    """
    if expression == exp.true():
        return True
    if expression == exp.false():
        return False
    if expression == exp.null():
        return None
    if isinstance(expression, exp.Literal):
        if not expression.args.get("is_string"):
            # Numeric literal: prefer int, fall back to float.
            try:
                return int(expression.this)
            except ValueError:
                return float(expression.this)
        return expression.name
    keyword = expression.name
    if keyword == "DEFAULT":
        return DEFAULT
    if keyword == "ON":
        return True
    if keyword == "OFF":
        return False
    raise MysqlError(
        "Complex expressions in variables not supported yet",
        code=ErrorCode.NOT_SUPPORTED_YET,
    )
69 |
--------------------------------------------------------------------------------
/mysql_mimic/packets.py:
--------------------------------------------------------------------------------
1 | import io
2 | from dataclasses import dataclass, field
3 | from typing import Optional, Dict, Any, Sequence, Callable, Tuple, List, Union
4 |
5 | from mysql_mimic.charset import Collation, CharacterSet
6 | from mysql_mimic.constants import DEFAULT_SERVER_CAPABILITIES
7 | from mysql_mimic.errors import ErrorCode, get_sqlstate, MysqlError
8 | from mysql_mimic.prepared import PreparedStatement, REGEX_PARAM
9 | from mysql_mimic.results import NullBitmap, ResultColumn
10 | from mysql_mimic.types import (
11 | Capabilities,
12 | uint_2,
13 | uint_1,
14 | str_null,
15 | uint_4,
16 | str_fixed,
17 | read_uint_4,
18 | read_str_null,
19 | read_str_fixed,
20 | read_uint_1,
21 | read_str_len,
22 | read_uint_len,
23 | str_len,
24 | str_rest,
25 | uint_len,
26 | ColumnType,
27 | read_int_1,
28 | read_int_2,
29 | read_uint_2,
30 | read_int_4,
31 | read_uint_8,
32 | read_int_8,
33 | read_float,
34 | read_double,
35 | ResultsetMetadata,
36 | ColumnDefinition,
37 | ComStmtExecuteFlags,
38 | peek,
39 | ServerStatus,
40 | read_str_rest,
41 | )
42 |
43 |
@dataclass
class SSLRequest:
    """Truncated handshake response sent by a client asking to upgrade to TLS."""

    max_packet_size: int
    capabilities: Capabilities
    client_charset: CharacterSet
49 |
50 |
@dataclass
class HandshakeResponse41:
    """Decoded Protocol::HandshakeResponse41 packet from the client."""

    max_packet_size: int
    # Intersection of server- and client-advertised capability flags.
    capabilities: Capabilities
    client_charset: CharacterSet
    username: str
    # Opaque reply to the auth plugin's challenge.
    auth_response: bytes
    connect_attrs: Dict[str, str] = field(default_factory=dict)
    database: Optional[str] = None
    client_plugin: Optional[str] = None
    zstd_compression_level: int = 0
62 |
63 |
@dataclass
class ComChangeUser:
    """Decoded COM_CHANGE_USER command: re-authenticate on an open connection."""

    username: str
    auth_response: bytes
    database: str
    client_charset: Optional[CharacterSet] = None
    client_plugin: Optional[str] = None
    connect_attrs: Dict[str, str] = field(default_factory=dict)
72 |
73 |
@dataclass
class ComQuery:
    """Decoded COM_QUERY command: execute a SQL string."""

    sql: str
    query_attrs: Dict[str, str]
78 |
79 |
@dataclass
class ComStmtSendLongData:
    """Decoded COM_STMT_SEND_LONG_DATA: stream a parameter value for a prepared statement."""

    stmt_id: int
    param_id: int
    data: bytes
85 |
86 |
@dataclass
class ComStmtExecute:
    """Decoded COM_STMT_EXECUTE command: run a prepared statement."""

    sql: str
    query_attrs: Dict[str, str]
    stmt: PreparedStatement
    # When True, rows are fetched later via COM_STMT_FETCH instead of returned directly.
    use_cursor: bool
93 |
94 |
@dataclass
class ComStmtFetch:
    """Decoded COM_STMT_FETCH command: read rows from a statement's open cursor."""

    stmt_id: int
    num_rows: int
99 |
100 |
@dataclass
class ComStmtReset:
    """Decoded COM_STMT_RESET command: discard a prepared statement's accumulated data."""

    stmt_id: int
104 |
105 |
@dataclass
class ComStmtClose:
    """Decoded COM_STMT_CLOSE command: deallocate a prepared statement."""

    stmt_id: int
109 |
110 |
@dataclass
class ComFieldList:
    """Decoded (legacy) COM_FIELD_LIST command: list a table's columns."""

    table: str
    wildcard: str
115 |
116 |
def make_ok(
    capabilities: Capabilities,
    status_flags: ServerStatus,
    eof: bool = False,
    affected_rows: int = 0,
    last_insert_id: int = 0,
    warnings: int = 0,
    flags: int = 0,
) -> bytes:
    """
    Build the payload of an OK packet (or an OK-styled EOF when `eof` is set,
    whose header byte is 0xFE instead of 0x00).

    Status/warning fields are appended only when the negotiated capabilities
    call for them.
    """
    header = uint_1(0xFE) if eof else uint_1(0)
    parts = [header, uint_len(affected_rows), uint_len(last_insert_id)]

    if Capabilities.CLIENT_PROTOCOL_41 in capabilities:
        parts.extend((uint_2(status_flags | flags), uint_2(warnings)))
    elif Capabilities.CLIENT_TRANSACTIONS in capabilities:
        parts.append(uint_2(status_flags | flags))

    return _concat(*parts)
139 |
140 |
def make_eof(
    capabilities: Capabilities,
    status_flags: ServerStatus,
    warnings: int = 0,
    flags: int = 0,
) -> bytes:
    """
    Build the payload of an EOF packet (header byte 0xFE).

    With CLIENT_PROTOCOL_41, the warning count and status flags follow the header.
    """
    if Capabilities.CLIENT_PROTOCOL_41 not in capabilities:
        return _concat(uint_1(0xFE))
    return _concat(uint_1(0xFE), uint_2(warnings), uint_2(status_flags | flags))
154 |
155 |
def make_error(
    capabilities: Capabilities = DEFAULT_SERVER_CAPABILITIES,
    server_charset: CharacterSet = CharacterSet.utf8mb4,
    msg: Any = "",
    code: ErrorCode = ErrorCode.UNKNOWN_ERROR,
) -> bytes:
    """
    Build the payload of an ERR packet (header byte 0xFF).

    With CLIENT_PROTOCOL_41, a '#'-prefixed 5-byte SQLSTATE precedes the
    human-readable message.
    """
    parts = [uint_1(0xFF), uint_2(code)]

    if Capabilities.CLIENT_PROTOCOL_41 in capabilities:
        parts.extend((str_fixed(1, b"#"), str_fixed(5, get_sqlstate(code))))

    # The message occupies the remainder of the packet, encoded in the
    # server's charset.
    parts.append(str_rest(server_charset.encode(str(msg))))
    return _concat(*parts)
171 |
172 |
def make_handshake_v10(
    capabilities: Capabilities,
    server_charset: CharacterSet,
    server_version: str,
    connection_id: int,
    auth_data: bytes,
    status_flags: ServerStatus,
    auth_plugin_name: str,
) -> bytes:
    """Build the initial HandshakeV10 packet sent to a connecting client.

    The first 8 bytes of ``auth_data`` fill the fixed plugin-data field;
    the remainder goes in the trailing variable-length field.
    """
    # The plugin data length is only advertised when plugin auth is enabled.
    auth_plugin_data_len = (
        len(auth_data) if Capabilities.CLIENT_PLUGIN_AUTH in capabilities else 0
    )

    parts = [
        uint_1(10),  # protocol version
        str_null(server_charset.encode(server_version)),  # server version
        uint_4(connection_id),  # connection ID
        str_null(auth_data[:8]),  # plugin data
        uint_2(capabilities & 0xFFFF),  # lower capabilities flag
        uint_1(server_charset),  # lower character set
        uint_2(status_flags),  # server status flag
        uint_2(capabilities >> 16),  # higher capabilities flag
        uint_1(auth_plugin_data_len),
        str_fixed(10, bytes(10)),  # reserved
        # Remainder of the auth data, padded out to at least 13 bytes.
        str_fixed(max(13, auth_plugin_data_len - 8), auth_data[8:]),
    ]
    if Capabilities.CLIENT_PLUGIN_AUTH in capabilities:
        parts.append(str_null(server_charset.encode(auth_plugin_name)))
    return _concat(*parts)
202 |
203 |
def parse_handshake_response(
    capabilities: Capabilities, data: bytes
) -> Union[HandshakeResponse41, SSLRequest]:
    """Parse the client's reply to the server handshake.

    An SSLRequest is the truncated prefix of a HandshakeResponse41: when the
    payload ends right after the 32-byte fixed header, the client is asking
    to upgrade to TLS before resending the full response.
    """
    r = io.BytesIO(data)

    client_capabilities = Capabilities(read_uint_4(r))

    # Only keep capabilities both sides support.
    capabilities = capabilities & client_capabilities

    max_packet_size = read_uint_4(r)
    client_charset = Collation(read_uint_1(r)).charset
    read_str_fixed(r, 23)  # reserved filler bytes

    if not peek(r):
        return SSLRequest(
            max_packet_size=max_packet_size,
            capabilities=capabilities,
            client_charset=client_charset,
        )

    username = client_charset.decode(read_str_null(r))

    if Capabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA in capabilities:
        auth_response = read_str_len(r)
    else:
        l_auth_response = read_uint_1(r)
        auth_response = read_str_fixed(r, l_auth_response)

    response = HandshakeResponse41(
        max_packet_size=max_packet_size,
        capabilities=capabilities,
        client_charset=client_charset,
        username=username,
        auth_response=auth_response,
    )

    # Optional trailing fields, gated by the negotiated capabilities.
    if Capabilities.CLIENT_CONNECT_WITH_DB in capabilities:
        response.database = client_charset.decode(read_str_null(r))

    if Capabilities.CLIENT_PLUGIN_AUTH in capabilities:
        response.client_plugin = client_charset.decode(read_str_null(r))

    if Capabilities.CLIENT_CONNECT_ATTRS in capabilities:
        response.connect_attrs = _read_connect_attrs(r, client_charset)

    if Capabilities.CLIENT_ZSTD_COMPRESSION_ALGORITHM in capabilities:
        response.zstd_compression_level = read_uint_1(r)

    return response
253 |
254 |
def parse_handshake_response_41(
    capabilities: Capabilities, data: bytes
) -> HandshakeResponse41:
    """Parse a handshake response that must be a full HandshakeResponse41
    (i.e. not an SSLRequest)."""
    result = parse_handshake_response(capabilities, data)
    assert isinstance(result, HandshakeResponse41)
    return result
261 |
262 |
def make_auth_more_data(data: bytes) -> bytes:
    """Build an AuthMoreData packet: 0x01 status tag followed by payload."""
    status_tag = uint_1(1)
    return _concat(status_tag, str_rest(data))
265 |
266 |
def make_auth_switch_request(
    server_charset: CharacterSet, plugin_name: str, plugin_provided_data: bytes
) -> bytes:
    """Build an AuthSwitchRequest packet (status tag 0xFE)."""
    parts = [
        uint_1(254),  # status tag
        str_null(server_charset.encode(plugin_name)),
        str_rest(plugin_provided_data),
    ]
    return _concat(*parts)
275 |
276 |
def parse_com_change_user(
    capabilities: Capabilities, client_charset: CharacterSet, data: bytes
) -> ComChangeUser:
    """Parse COM_CHANGE_USER: re-authenticate as a different user.

    The trailing fields (charset, auth plugin, connect attrs) are only
    present when the payload has bytes left after the database name.
    """
    r = io.BytesIO(data)
    username = client_charset.decode(read_str_null(r))
    if Capabilities.CLIENT_SECURE_CONNECTION in capabilities:
        # Length-prefixed auth response rather than NUL-terminated.
        l_auth_response = read_uint_1(r)
        auth_response = read_str_fixed(r, l_auth_response)
    else:
        auth_response = read_str_null(r)
    database = client_charset.decode(read_str_null(r))

    response = ComChangeUser(
        username=username, auth_response=auth_response, database=database
    )

    if peek(r):  # more data available
        if Capabilities.CLIENT_PROTOCOL_41 in capabilities:
            # The client may switch character sets here; the new charset is
            # used to decode the remaining fields.
            client_charset = Collation(read_uint_2(r)).charset
            response.client_charset = client_charset
        if Capabilities.CLIENT_PLUGIN_AUTH in capabilities:
            response.client_plugin = client_charset.decode(read_str_null(r))
        if Capabilities.CLIENT_CONNECT_ATTRS in capabilities:
            response.connect_attrs = _read_connect_attrs(r, client_charset)

    return response
303 |
304 |
def parse_com_query(
    capabilities: Capabilities, client_charset: CharacterSet, data: bytes
) -> ComQuery:
    """Parse a COM_QUERY payload into SQL text plus query attributes."""
    reader = io.BytesIO(data)

    query_attrs: Dict[str, str] = {}
    if Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities:
        parameter_count = read_uint_len(reader)
        read_uint_len(reader)  # parameter_set_count. Always 1.
        for name, value in _read_params(
            capabilities, client_charset, reader, parameter_count
        ):
            if name is not None:
                query_attrs[name] = value

    # Everything after the (optional) attribute section is the SQL text.
    sql = reader.read().decode(client_charset.codec)

    return ComQuery(sql=sql, query_attrs=query_attrs)
327 |
328 |
def make_column_count(capabilities: Capabilities, column_count: int) -> bytes:
    """Build the column-count packet that precedes column definitions."""
    if Capabilities.CLIENT_OPTIONAL_RESULTSET_METADATA in capabilities:
        # Advertise that full metadata follows.
        return _concat(
            uint_1(ResultsetMetadata.RESULTSET_METADATA_FULL),
            uint_len(column_count),
        )
    return uint_len(column_count)
338 |
339 |
# pylint: disable=too-many-arguments
def make_column_definition_41(
    server_charset: CharacterSet,
    schema: Optional[str] = None,
    table: Optional[str] = None,
    org_table: Optional[str] = None,
    name: Optional[str] = None,
    org_name: Optional[str] = None,
    character_set: CharacterSet = CharacterSet.utf8mb4,
    column_length: int = 256,
    column_type: ColumnType = ColumnType.VARCHAR,
    flags: ColumnDefinition = ColumnDefinition(0),
    decimals: int = 0,
    is_com_field_list: bool = False,
    default: Optional[str] = None,
) -> bytes:
    """Build a protocol-4.1 ColumnDefinition41 packet.

    Args:
        server_charset: charset used to encode the identifier strings.
        schema/table/name: display identifiers; ``org_table``/``org_name``
            fall back to their display counterparts when omitted.
        character_set: charset advertised for the column's values.
        is_com_field_list: COM_FIELD_LIST responses append a default-value
            field after the fixed-length section.
    """
    schema = schema or ""
    table = table or ""
    org_table = org_table or table
    name = name or ""
    org_name = org_name or name
    parts = [
        str_len(b"def"),  # catalog; always "def"
        str_len(server_charset.encode(schema)),
        str_len(server_charset.encode(table)),
        str_len(server_charset.encode(org_table)),
        str_len(server_charset.encode(name)),
        str_len(server_charset.encode(org_name)),
        uint_len(0x0C),  # Length of the following fields
        uint_2(character_set),
        uint_4(column_length),
        uint_1(column_type),
        uint_2(flags),
        uint_1(decimals),
        uint_2(0),  # filler
    ]
    if is_com_field_list:
        if default is None:
            parts.append(uint_len(0))
        else:
            default_values = server_charset.encode(default)
            parts.extend([uint_len(len(default_values)), str_len(default_values)])
    return _concat(*parts)
383 |
384 |
def make_text_resultset_row(
    row: Sequence[Any], columns: Sequence[ResultColumn]
) -> bytes:
    """Encode one row for the text protocol; NULLs become the 0xFB marker."""
    cells = [
        b"\xfb" if value is None else str_len(column.text_encode(value))
        for value, column in zip(row, columns)
    ]
    return _concat(*cells)
398 |
399 |
def make_com_stmt_prepare_ok(statement: PreparedStatement) -> bytes:
    """Build a COM_STMT_PREPARE_OK response for a freshly prepared statement."""
    parts = [
        uint_1(0),  # OK
        uint_4(statement.stmt_id),
        uint_2(0),  # number of columns
        uint_2(statement.num_params),
        uint_1(0),  # filler
        uint_2(0),  # number of warnings
    ]
    return _concat(*parts)
409 |
410 |
def parse_com_stmt_send_long_data(data: bytes) -> ComStmtSendLongData:
    """Parse COM_STMT_SEND_LONG_DATA: statement id, parameter id, raw chunk."""
    reader = io.BytesIO(data)
    stmt_id = read_uint_4(reader)
    param_id = read_uint_2(reader)
    return ComStmtSendLongData(stmt_id=stmt_id, param_id=param_id, data=reader.read())
418 |
419 |
def parse_com_stmt_execute(
    capabilities: Capabilities,
    client_charset: CharacterSet,
    data: bytes,
    get_stmt: Callable[[int], PreparedStatement],
) -> ComStmtExecute:
    """Parse COM_STMT_EXECUTE and interpolate bound params into the SQL.

    Args:
        get_stmt: lookup from statement id to the prepared statement.
    """
    r = io.BytesIO(data)
    stmt_id = read_uint_4(r)
    stmt = get_stmt(stmt_id)
    use_cursor, param_count_available = _read_cursor_flags(r)
    read_uint_4(r)  # iteration count. Always 1.
    sql, query_attrs = _interpolate_params(
        capabilities, client_charset, r, stmt, param_count_available
    )
    return ComStmtExecute(
        sql=sql,
        query_attrs=query_attrs,
        stmt=stmt,
        use_cursor=use_cursor,
    )
440 |
441 |
def make_binary_resultrow(row: Sequence[Any], columns: Sequence[ResultColumn]) -> bytes:
    """Encode one row for the binary protocol, including the NULL bitmap."""
    null_bitmap = NullBitmap.new(len(row), offset=2)

    encoded = []
    for i, (value, column) in enumerate(zip(row, columns)):
        if value is None:
            null_bitmap.flip(i)
        else:
            encoded.append(column.binary_encode(value))

    payload = b"".join(encoded)
    bitmap_bytes = bytes(null_bitmap)

    return b"".join(
        [
            uint_1(0),  # packet header
            str_fixed(len(bitmap_bytes), bitmap_bytes),
            str_fixed(len(payload), payload),
        ]
    )
465 |
466 |
def parse_handle_stmt_fetch(data: bytes) -> ComStmtFetch:
    """Parse COM_STMT_FETCH: statement id and requested row count."""
    reader = io.BytesIO(data)
    stmt_id = read_uint_4(reader)
    num_rows = read_uint_4(reader)
    return ComStmtFetch(stmt_id=stmt_id, num_rows=num_rows)
473 |
474 |
def parse_com_stmt_reset(data: bytes) -> ComStmtReset:
    """Parse COM_STMT_RESET: the payload is just the statement id."""
    return ComStmtReset(stmt_id=read_uint_4(io.BytesIO(data)))
478 |
479 |
def parse_com_stmt_close(data: bytes) -> ComStmtClose:
    """Parse COM_STMT_CLOSE: the payload is just the statement id."""
    return ComStmtClose(stmt_id=read_uint_4(io.BytesIO(data)))
483 |
484 |
def parse_com_init_db(client_charset: CharacterSet, data: bytes) -> str:
    """Parse COM_INIT_DB: the whole payload is the database name."""
    database_name = client_charset.decode(data)
    return database_name
487 |
488 |
def parse_com_field_list(client_charset: CharacterSet, data: bytes) -> ComFieldList:
    """Parse COM_FIELD_LIST: a table name followed by a wildcard pattern."""
    reader = io.BytesIO(data)
    table = client_charset.decode(read_str_null(reader))
    wildcard = client_charset.decode(read_str_rest(reader))
    return ComFieldList(table=table, wildcard=wildcard)
495 |
496 |
def _read_cursor_flags(reader: io.BytesIO) -> Tuple[bool, bool]:
    """Read COM_STMT_EXECUTE flags -> (use_cursor, param_count_available)."""
    flags = ComStmtExecuteFlags(read_uint_1(reader))
    param_count_available = ComStmtExecuteFlags.PARAMETER_COUNT_AVAILABLE in flags

    if ComStmtExecuteFlags.CURSOR_TYPE_READ_ONLY in flags:
        use_cursor = True
    elif ComStmtExecuteFlags.CURSOR_TYPE_NO_CURSOR in flags:
        use_cursor = False
    else:
        raise MysqlError(
            f"Unsupported cursor flags: {flags}", ErrorCode.NOT_SUPPORTED_YET
        )
    return use_cursor, param_count_available
506 |
507 |
def _interpolate_params(
    capabilities: Capabilities,
    client_charset: CharacterSet,
    reader: io.BytesIO,
    stmt: PreparedStatement,
    param_count_available: bool,
) -> Tuple[str, Dict[str, str]]:
    """Substitute bound parameters into a prepared statement's SQL.

    Returns the SQL with each ``?`` placeholder replaced by an encoded
    literal, plus any query attributes sent alongside the parameters.
    """
    sql = stmt.sql
    query_attrs = {}
    parameter_count = stmt.num_params

    if stmt.num_params > 0 or (
        Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities and param_count_available
    ):
        if Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities:
            # The client sends the combined parameter + attribute count.
            parameter_count = read_uint_len(reader)

        if parameter_count > 0:
            # When there are query attributes, they are combined with statement parameters.
            # The statement parameters will be first, query attributes second.
            params = _read_params(
                capabilities, client_charset, reader, parameter_count, stmt.param_buffers
            )

            # Replace one "?" per statement parameter, in order.
            # NOTE(review): re.sub treats the replacement as a template, so
            # backslash sequences in encoded values are re-interpreted here —
            # confirm values containing backslashes round-trip as intended.
            for _, value in params[: stmt.num_params]:
                sql = REGEX_PARAM.sub(_encode_param_as_sql(value), sql, 1)

            query_attrs = {k: v for k, v in params[stmt.num_params :] if k is not None}

    return sql, query_attrs
538 |
539 |
540 | def _encode_param_as_sql(param: Any) -> str:
541 | if isinstance(param, str):
542 | return f"'{param}'"
543 | if param is None:
544 | return "NULL"
545 | if param is True:
546 | return "TRUE"
547 | if param is False:
548 | return "FALSE"
549 | return str(param)
550 |
551 |
def _read_params(
    capabilities: Capabilities,
    client_charset: CharacterSet,
    reader: io.BytesIO,
    parameter_count: int,
    buffers: Optional[Dict[int, bytearray]] = None,
) -> Sequence[Tuple[Optional[str], Any]]:
    """
    Read parameters from a stream.

    This is intended for reading query attributes from COM_QUERY and parameters from COM_STMT_EXECUTE.

    Args:
        capabilities: negotiated client/server capabilities.
        client_charset: charset used to decode names and string values.
        reader: stream positioned at the start of the parameter section.
        parameter_count: number of parameters to read.
        buffers: data already sent via COM_STMT_SEND_LONG_DATA, keyed by
            parameter index; those parameters take their value from the
            buffer instead of the stream.

    Returns:
        Name/value pairs. Name is None if there is no name, which is the case for stmt_execute, which
        combines statement parameters and query attributes.
    """
    params: List[Tuple[Optional[str], Any]] = []

    if parameter_count:
        # One bit per parameter; a set bit means the parameter is NULL.
        null_bitmap = NullBitmap.from_buffer(reader, parameter_count)
        new_params_bound_flag = read_uint_1(reader)

        if not new_params_bound_flag:
            raise MysqlError(
                "Server requires the new-params-bound-flag to be set",
                ErrorCode.NOT_SUPPORTED_YET,
            )

        param_types = []

        # First pass: read all (name, type, unsigned) descriptors.
        for i in range(parameter_count):
            param_type, unsigned = _read_param_type(reader)

            if Capabilities.CLIENT_QUERY_ATTRIBUTES in capabilities:
                # Only query attributes have names
                # Statement parameters will have an empty name, e.g. b"\x00"
                param_name = client_charset.decode(read_str_len(reader))
            else:
                param_name = ""

            param_types.append((param_name, param_type, unsigned))

        # Second pass: read the values in the same order.
        for i, (param_name, param_type, unsigned) in enumerate(param_types):
            if null_bitmap.is_flipped(i):
                params.append((param_name, None))
            elif buffers and i in buffers:
                params.append((param_name, client_charset.decode(buffers[i])))
            else:
                params.append(
                    (
                        param_name,
                        _read_param_value(client_charset, reader, param_type, unsigned),
                    )
                )

    return params
608 |
609 |
def _read_param_type(reader: io.BytesIO) -> Tuple[ColumnType, bool]:
    """Read a parameter's type byte and flags byte (0x80 flags unsigned)."""
    type_byte = read_uint_1(reader)
    flag_byte = read_uint_1(reader)
    return ColumnType(type_byte), bool(flag_byte & 0x80)
614 |
615 |
def _read_param_value(
    client_charset: CharacterSet,
    reader: io.BytesIO,
    param_type: ColumnType,
    unsigned: bool,
) -> Any:
    """Read one binary-protocol parameter value of the given type.

    Raises:
        MysqlError: NOT_SUPPORTED_YET for types without a decoder here.
    """
    # String/blob types: length-encoded bytes decoded with the client charset.
    if param_type in {
        ColumnType.VARCHAR,
        ColumnType.VAR_STRING,
        ColumnType.STRING,
        ColumnType.BLOB,
        ColumnType.TINY_BLOB,
        ColumnType.MEDIUM_BLOB,
        ColumnType.LONG_BLOB,
    }:
        val = read_str_len(reader)
        return client_charset.decode(val)

    # Fixed-width integer types: width depends on type, signedness on flag.
    if param_type == ColumnType.TINY:
        return (read_uint_1 if unsigned else read_int_1)(reader)

    if param_type == ColumnType.BOOL:
        return read_uint_1(reader)

    if param_type in {ColumnType.SHORT, ColumnType.YEAR}:
        return (read_uint_2 if unsigned else read_int_2)(reader)

    if param_type in {ColumnType.LONG, ColumnType.INT24}:
        return (read_uint_4 if unsigned else read_int_4)(reader)

    if param_type == ColumnType.LONGLONG:
        return (read_uint_8 if unsigned else read_int_8)(reader)

    if param_type == ColumnType.FLOAT:
        return read_float(reader)

    if param_type == ColumnType.DOUBLE:
        return read_double(reader)

    # NULL carries no payload bytes.
    if param_type == ColumnType.NULL:
        return None

    raise MysqlError(
        f"Unsupported parameter type: {param_type}", ErrorCode.NOT_SUPPORTED_YET
    )
661 |
662 |
def _read_connect_attrs(
    reader: io.BytesIO, client_charset: CharacterSet
) -> Dict[str, str]:
    """Read the length-prefixed key/value connection-attribute blob."""
    attrs: Dict[str, str] = {}
    remaining = read_uint_len(reader)

    while remaining > 0:
        key = read_str_len(reader)
        value = read_str_len(reader)
        attrs[client_charset.decode(key)] = client_charset.decode(value)

        # Bytes consumed include each string's length-encoded prefix.
        remaining -= len(str_len(key)) + len(str_len(value))
    return attrs
677 |
678 |
679 | def _concat(*parts: bytes) -> bytes:
680 | return b"".join(parts)
681 |
--------------------------------------------------------------------------------
/mysql_mimic/prepared.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 | from typing import Optional, Dict, AsyncIterable
4 |
5 | # Borrowed from mysql-connector-python
6 | REGEX_PARAM = re.compile(r"""\?(?=(?:[^"'`]*["'`][^"'`]*["'`])*[^"'`]*$)""")
7 |
8 |
@dataclass
class PreparedStatement:
    """Server-side state for a statement created by COM_STMT_PREPARE."""

    stmt_id: int
    # Original SQL text with "?" placeholders for parameters.
    sql: str
    num_params: int
    # Long data sent via COM_STMT_SEND_LONG_DATA, keyed by parameter index.
    param_buffers: Optional[Dict[int, bytearray]] = None
    # Open cursor of encoded result packets for COM_STMT_FETCH, if any.
    cursor: Optional[AsyncIterable[bytes]] = None
16 |
--------------------------------------------------------------------------------
/mysql_mimic/results.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 | import struct
5 | from dataclasses import dataclass
6 | from datetime import datetime, date, timedelta
7 | from typing import (
8 | Iterable,
9 | Sequence,
10 | Optional,
11 | Callable,
12 | Any,
13 | Union,
14 | Tuple,
15 | Dict,
16 | AsyncIterable,
17 | cast,
18 | )
19 |
20 | from mysql_mimic.errors import MysqlError
21 | from mysql_mimic.types import ColumnType, str_len, uint_1, uint_2, uint_4
22 | from mysql_mimic.charset import CharacterSet
23 | from mysql_mimic.utils import aiterate
24 |
# Signature for value encoders. Encoders are invoked as encoder(column, value)
# — see ResultColumn.text_encode/binary_encode — so the column comes first.
# (The previous alias had the argument order reversed.)
Encoder = Callable[["ResultColumn", Any], bytes]
26 |
27 |
class ResultColumn:
    """
    Column data for a result set

    Args:
        name: column name
        type: column type
        character_set: column character set. Only relevant for string columns.
        text_encoder: Optionally override the function used to encode values for MySQL's text protocol
        binary_encoder: Optionally override the function used to encode values for MySQL's binary protocol
    """

    def __init__(
        self,
        name: str,
        type: ColumnType,  # pylint: disable=redefined-builtin
        character_set: CharacterSet = CharacterSet.utf8mb4,
        text_encoder: Optional[Encoder] = None,
        binary_encoder: Optional[Encoder] = None,
    ):
        self.name = name
        self.type = type
        self.character_set = character_set
        # Fall back to the per-type default encoders, then to _unsupported.
        default_text = _TEXT_ENCODERS.get(type)
        default_binary = _BINARY_ENCODERS.get(type)
        self.text_encoder = text_encoder or default_text or _unsupported
        self.binary_encoder = binary_encoder or default_binary or _unsupported

    def text_encode(self, val: Any) -> bytes:
        """Encode ``val`` for MySQL's text protocol."""
        return self.text_encoder(self, val)

    def binary_encode(self, val: Any) -> bytes:
        """Encode ``val`` for MySQL's binary protocol."""
        return self.binary_encoder(self, val)

    def __repr__(self) -> str:
        return f"ResultColumn({self.name} {self.type.name})"
64 |
65 |
@dataclass
class ResultSet:
    """Rows plus column metadata returned from a query."""

    # Row data: any sync or async iterable of row sequences.
    rows: Iterable[Sequence] | AsyncIterable[Sequence]
    columns: Sequence[ResultColumn]

    def __bool__(self) -> bool:
        # A result set is truthy iff it has columns, even with zero rows.
        return bool(self.columns)
73 |
74 |
# A column may be given as a full ResultColumn or just a name (type inferred
# from row values by _ensure_result_cols).
AllowedColumn = Union[ResultColumn, str]
# Result shapes a session may return: a ResultSet, a (rows, columns) tuple
# with sync or async rows, or None for statements with no result set.
AllowedResult = Union[
    ResultSet,
    Tuple[Sequence[Sequence[Any]], Sequence[AllowedColumn]],
    Tuple[AsyncIterable[Sequence[Any]], Sequence[AllowedColumn]],
    None,
]
82 |
83 |
async def ensure_result_set(result: AllowedResult) -> ResultSet:
    """Normalize any allowed result shape into a ResultSet."""
    if result is None:
        return ResultSet([], [])

    if isinstance(result, ResultSet):
        return result

    if isinstance(result, tuple):
        if len(result) != 2:
            raise MysqlError(
                f"Result tuple should be of size 2. Received: {len(result)}"
            )
        rows, columns = result
        return await _ensure_result_cols(rows, columns)

    raise MysqlError(f"Unexpected result set type: {type(result)}")
100 |
101 |
async def _ensure_result_cols(
    rows: Sequence[Sequence[Any]] | AsyncIterable[Sequence[Any]],
    columns: Sequence[AllowedColumn],
) -> ResultSet:
    """Resolve plain-name columns to ResultColumns by inspecting row values.

    Rows consumed while searching for non-null samples are buffered and
    replayed, so the caller still sees every row exactly once.
    """
    # Which columns need to be inferred?
    remaining = {
        col: i for i, col in enumerate(columns) if not isinstance(col, ResultColumn)
    }

    if not remaining:
        return ResultSet(
            rows=rows,
            columns=cast(Sequence[ResultColumn], columns),
        )

    # Copy the columns
    columns = list(columns)

    arows = aiterate(rows)

    # Keep track of rows we've consumed from the iterator so we can add them back
    peeks = []

    # Find the first non-null value for each column
    while remaining:
        try:
            peek = await arows.__anext__()
        except StopAsyncIteration:
            break

        peeks.append(peek)

        inferred = []
        for name, i in remaining.items():
            value = peek[i]
            if value is not None:
                type_ = infer_type(value)
                columns[i] = ResultColumn(
                    name=str(name),
                    type=type_,
                )
                inferred.append(name)

        for name in inferred:
            remaining.pop(name)

    # If we failed to find a non-null value, set the type to NULL
    for name, i in remaining.items():
        columns[i] = ResultColumn(
            name=str(name),
            type=ColumnType.NULL,
        )

    # Add the consumed rows back in to the iterator
    async def gen_rows() -> AsyncIterable[Sequence[Any]]:
        for row in peeks:
            yield row

        async for row in arows:
            yield row

    assert all(isinstance(col, ResultColumn) for col in columns)
    return ResultSet(rows=gen_rows(), columns=cast(Sequence[ResultColumn], columns))
165 |
166 |
def _binary_encode_tiny(col: ResultColumn, val: Any) -> bytes:
    """Encode TINY/BOOL as a single byte: 1 for truthy values, else 0."""
    return uint_1(1 if val else 0)
169 |
170 |
def _binary_encode_str(col: ResultColumn, val: Any) -> bytes:
    """Encode a value as a length-prefixed string in the column's charset."""
    if isinstance(val, bytes):
        return str_len(val)
    return str_len(str(val).encode(col.character_set.codec))
179 |
180 |
def _binary_encode_date(col: ResultColumn, val: Any) -> bytes:
    """Encode a date/datetime (or numeric timestamp) for the binary protocol.

    Uses the shortest wire form: a length byte of 0, 4, 7, or 11 depending
    on which date/time components are non-zero.
    """
    if isinstance(val, (float, int)):
        # NOTE(review): fromtimestamp uses the local timezone — confirm this
        # matches the intended session timezone semantics.
        val = datetime.fromtimestamp(val)

    year = val.year
    month = val.month
    day = val.day

    if isinstance(val, datetime):
        hour = val.hour
        minute = val.minute
        second = val.second
        microsecond = val.microsecond
    else:
        # Plain dates carry no time-of-day component.
        hour = minute = second = microsecond = 0

    if microsecond == 0:
        if hour == minute == second == 0:
            if year == month == day == 0:
                return uint_1(0)
            return b"".join([uint_1(4), uint_2(year), uint_1(month), uint_1(day)])
        return b"".join(
            [
                uint_1(7),
                uint_2(year),
                uint_1(month),
                uint_1(day),
                uint_1(hour),
                uint_1(minute),
                uint_1(second),
            ]
        )
    return b"".join(
        [
            uint_1(11),
            uint_2(year),
            uint_1(month),
            uint_1(day),
            uint_1(hour),
            uint_1(minute),
            uint_1(second),
            uint_4(microsecond),
        ]
    )
225 |
226 |
227 | def _binary_encode_short(col: ResultColumn, val: Any) -> bytes:
228 | return struct.pack(" bytes:
232 | return struct.pack(" bytes:
236 | return struct.pack(" bytes:
240 | return struct.pack(" bytes:
244 | return struct.pack(" bytes:
248 | return struct.pack(" bytes:
252 | days = abs(val.days)
253 | hours, remainder = divmod(abs(val.seconds), 3600)
254 | minutes, seconds = divmod(remainder, 60)
255 | microseconds = val.microseconds
256 | is_negative = val.total_seconds() < 0
257 |
258 | if microseconds == 0:
259 | if days == hours == minutes == seconds == 0:
260 | return uint_1(0)
261 | return b"".join(
262 | [
263 | uint_1(8),
264 | uint_1(is_negative),
265 | uint_4(days),
266 | uint_1(hours),
267 | uint_1(minutes),
268 | uint_1(seconds),
269 | ]
270 | )
271 | return b"".join(
272 | [
273 | uint_1(12),
274 | uint_1(is_negative),
275 | uint_4(days),
276 | uint_1(hours),
277 | uint_1(minutes),
278 | uint_1(seconds),
279 | uint_4(microseconds),
280 | ]
281 | )
282 |
283 |
284 | def _text_encode_str(col: ResultColumn, val: Any) -> bytes:
285 | if not isinstance(val, bytes):
286 | val = str(val)
287 |
288 | if isinstance(val, str):
289 | val = val.encode(col.character_set.codec)
290 |
291 | return val
292 |
293 |
294 | def _text_encode_tiny(col: ResultColumn, val: Any) -> bytes:
295 | return _text_encode_str(col, int(val))
296 |
297 |
def _unsupported(col: ResultColumn, val: Any) -> bytes:
    # Fallback encoder for column types with no text/binary implementation.
    raise MysqlError(f"Unsupported column type: {col.type}")
300 |
301 |
def _noop(col: ResultColumn, val: Any) -> bytes:
    # Identity encoder: passes the value through unchanged (assumes the
    # caller already produced wire-ready bytes).
    return val
304 |
305 |
# Order matters
# bool is a subclass of int
# datetime is a subclass of date
# (infer_type checks isinstance in insertion order, so subclasses must come
# before their base classes.)
_PY_TO_MYSQL_TYPE = {
    bool: ColumnType.TINY,
    datetime: ColumnType.DATETIME,
    str: ColumnType.STRING,
    bytes: ColumnType.BLOB,
    int: ColumnType.LONGLONG,
    float: ColumnType.DOUBLE,
    date: ColumnType.DATE,
    timedelta: ColumnType.TIME,
}
319 |
320 |
# Default encoders for the text protocol, keyed by column type.
_TEXT_ENCODERS: Dict[ColumnType, Encoder] = {
    ColumnType.DECIMAL: _text_encode_str,
    ColumnType.TINY: _text_encode_tiny,
    ColumnType.SHORT: _text_encode_str,
    ColumnType.LONG: _text_encode_str,
    ColumnType.FLOAT: _text_encode_str,
    ColumnType.DOUBLE: _text_encode_str,
    ColumnType.NULL: _unsupported,
    ColumnType.TIMESTAMP: _text_encode_str,
    ColumnType.LONGLONG: _text_encode_str,
    ColumnType.INT24: _text_encode_str,
    ColumnType.DATE: _text_encode_str,
    ColumnType.TIME: _text_encode_str,
    ColumnType.DATETIME: _text_encode_str,
    ColumnType.YEAR: _text_encode_str,
    ColumnType.NEWDATE: _text_encode_str,
    ColumnType.VARCHAR: _text_encode_str,
    ColumnType.BIT: _text_encode_str,
    ColumnType.TIMESTAMP2: _unsupported,
    ColumnType.DATETIME2: _unsupported,
    ColumnType.TIME2: _unsupported,
    ColumnType.TYPED_ARRAY: _unsupported,
    ColumnType.INVALID: _unsupported,
    # Fix: BOOL previously mapped to _binary_encode_tiny here, which emitted
    # a raw 0x00/0x01 byte; the text protocol needs ASCII digits like every
    # other numeric entry in this table.
    ColumnType.BOOL: _text_encode_tiny,
    ColumnType.JSON: _text_encode_str,
    ColumnType.NEWDECIMAL: _text_encode_str,
    ColumnType.ENUM: _text_encode_str,
    ColumnType.SET: _text_encode_str,
    ColumnType.TINY_BLOB: _text_encode_str,
    ColumnType.MEDIUM_BLOB: _text_encode_str,
    ColumnType.LONG_BLOB: _text_encode_str,
    ColumnType.BLOB: _text_encode_str,
    ColumnType.VAR_STRING: _text_encode_str,
    ColumnType.STRING: _text_encode_str,
    ColumnType.GEOMETRY: _text_encode_str,
}
357 |
# Default encoders for the binary (prepared-statement) protocol, keyed by
# column type.
_BINARY_ENCODERS: Dict[ColumnType, Encoder] = {
    ColumnType.DECIMAL: _binary_encode_str,
    ColumnType.TINY: _binary_encode_tiny,
    ColumnType.SHORT: _binary_encode_short,
    ColumnType.LONG: _binary_encode_long,
    ColumnType.FLOAT: _binary_encode_float,
    ColumnType.DOUBLE: _binary_encode_double,
    ColumnType.NULL: _unsupported,
    ColumnType.TIMESTAMP: _binary_encode_date,
    ColumnType.LONGLONG: _binary_encode_longlong,
    ColumnType.INT24: _binary_encode_long,
    ColumnType.DATE: _binary_encode_date,
    ColumnType.TIME: _binary_encode_timedelta,
    ColumnType.DATETIME: _binary_encode_date,
    ColumnType.YEAR: _binary_encode_short,
    ColumnType.NEWDATE: _unsupported,
    ColumnType.VARCHAR: _binary_encode_str,
    ColumnType.BIT: _binary_encode_str,
    ColumnType.TIMESTAMP2: _unsupported,
    ColumnType.DATETIME2: _unsupported,
    ColumnType.TIME2: _unsupported,
    ColumnType.TYPED_ARRAY: _unsupported,
    ColumnType.INVALID: _unsupported,
    ColumnType.BOOL: _binary_encode_tiny,
    ColumnType.JSON: _binary_encode_str,
    ColumnType.NEWDECIMAL: _binary_encode_str,
    ColumnType.ENUM: _binary_encode_str,
    ColumnType.SET: _binary_encode_str,
    ColumnType.TINY_BLOB: _binary_encode_str,
    ColumnType.MEDIUM_BLOB: _binary_encode_str,
    ColumnType.LONG_BLOB: _binary_encode_str,
    ColumnType.BLOB: _binary_encode_str,
    ColumnType.VAR_STRING: _binary_encode_str,
    ColumnType.STRING: _binary_encode_str,
    ColumnType.GEOMETRY: _binary_encode_str,
}
394 |
395 |
def infer_type(val: Any) -> ColumnType:
    """Infer a MySQL column type from a Python value (VARCHAR fallback)."""
    return next(
        (
            mysql_type
            for py_type, mysql_type in _PY_TO_MYSQL_TYPE.items()
            if isinstance(val, py_type)
        ),
        ColumnType.VARCHAR,
    )
401 |
402 |
class NullBitmap:
    """NULL bitmap (https://dev.mysql.com/doc/internals/en/null-bitmap.html).

    Bit ``i`` (shifted by ``offset`` bits) marks whether value ``i`` is NULL.
    """

    __slots__ = ("offset", "bitmap")

    def __init__(self, bitmap: bytearray, offset: int = 0):
        self.offset = offset
        self.bitmap = bitmap

    @classmethod
    def new(cls, num_bits: int, offset: int = 0) -> NullBitmap:
        """Create an all-zero bitmap able to hold ``num_bits`` bits."""
        return cls(bytearray(cls._num_bytes(num_bits, offset)), offset)

    @classmethod
    def from_buffer(
        cls, buffer: io.BytesIO, num_bits: int, offset: int = 0
    ) -> NullBitmap:
        """Consume a bitmap for ``num_bits`` bits from ``buffer``."""
        raw = buffer.read(cls._num_bytes(num_bits, offset))
        return cls(bytearray(raw), offset)

    @classmethod
    def _num_bytes(cls, num_bits: int, offset: int) -> int:
        # Round (num_bits + offset) up to a whole number of bytes.
        return (num_bits + 7 + offset) // 8

    def flip(self, i: int) -> None:
        """Set bit ``i`` (mark value ``i`` as NULL)."""
        byte_i, bit_i = self._pos(i)
        self.bitmap[byte_i] |= 1 << bit_i

    def is_flipped(self, i: int) -> bool:
        """Return True if bit ``i`` is set."""
        byte_i, bit_i = self._pos(i)
        return bool(self.bitmap[byte_i] & (1 << bit_i))

    def _pos(self, i: int) -> Tuple[int, int]:
        # Map a logical bit index to (byte index, bit-within-byte).
        return divmod(i + self.offset, 8)

    def __bytes__(self) -> bytes:
        return bytes(self.bitmap)

    def __repr__(self) -> str:
        return "".join(format(b, "08b") for b in self.bitmap)
446 |
--------------------------------------------------------------------------------
/mysql_mimic/schema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 | from collections import defaultdict
5 | from itertools import chain
6 | from dataclasses import dataclass
7 | from typing import Any, Optional, List, Dict, Iterable
8 |
9 | from sqlglot.executor import Table, execute
10 | from sqlglot import expressions as exp
11 |
12 | from mysql_mimic.constants import INFO_SCHEMA
13 | from mysql_mimic.results import AllowedResult
14 | from mysql_mimic.errors import MysqlError, ErrorCode
15 | from mysql_mimic.packets import ComFieldList
16 | from mysql_mimic.utils import dict_depth
17 |
18 |
@dataclass
class Column:
    """Metadata for one column, as exposed through INFORMATION_SCHEMA."""

    name: str
    type: str
    table: str
    is_nullable: bool = True
    default: Optional[str] = None
    comment: Optional[str] = None
    schema: Optional[str] = None
    # "def" is MySQL's default catalog name.
    catalog: Optional[str] = "def"
29 |
30 |
def mapping_to_columns(schema: dict) -> List[Column]:
    """Convert a schema mapping into a list of Column instances.

    Accepts mappings of depth 2 ({table: {col: type}}), 3 ({db: {table:
    {col: type}}}), or 4 ({catalog: ...}); missing outer levels are filled
    with "" for the schema and "def" for the catalog.

    Example schema defining columns by type:
        {
            "customer_table" : {
                "first_name" : "TEXT"
                "last_name" : "TEXT"
                "id" : "INT"
            },
            "sales_table" : {
                "ds" : "DATE"
                "customer_id" : "INT"
                "amount" : "DOUBLE"
            }
        }
    """
    depth = dict_depth(schema)
    if depth < 2:
        return []
    if depth == 2:
        schema = {"": schema}
        depth += 1
    if depth == 3:
        schema = {"def": schema}  # def is the default MySQL catalog
        depth += 1
    if depth != 4:
        raise MysqlError("Invalid schema mapping")

    return [
        Column(
            name=col_name,
            type=col_type,
            table=table_name,
            schema=db_name,
            catalog=catalog_name,
        )
        for catalog_name, dbs in schema.items()
        for db_name, tables in dbs.items()
        for table_name, cols in tables.items()
        for col_name, col_type in cols.items()
    ]
78 |
79 |
def info_schema_tables(columns: Iterable[Column]) -> Dict[str, Dict[str, Table]]:
    """
    Convert a list of Column instances into a mapping of SQLGlot Tables.

    These Tables are used by SQLGlot to execute INFORMATION_SCHEMA queries.
    """
    # Running ordinal position per (catalog, schema, table) key, so columns
    # are numbered in insertion order within each table.
    ordinal_positions: dict[Any, int] = defaultdict(lambda: 0)

    # Start with an empty SQLGlot Table for every INFORMATION_SCHEMA table,
    # keyed as {db: {table_name: Table}}.
    data = {
        db: {k: Table(tuple(v.keys())) for k, v in tables.items()}
        for db, tables in INFO_SCHEMA.items()
    }

    tables = set()
    dbs = set()
    # NOTE(review): `catalogs` is populated below but never read afterwards.
    catalogs = set()

    # The INFORMATION_SCHEMA tables themselves are listed alongside user tables.
    info_schema_cols = mapping_to_columns(INFO_SCHEMA)

    for column in chain(columns, info_schema_cols):
        tables.add((column.catalog, column.schema, column.table))
        dbs.add((column.catalog, column.schema))
        catalogs.add(column.catalog)
        key = (column.catalog, column.schema, column.table)
        ordinal_position = ordinal_positions[key]
        ordinal_positions[key] += 1
        # Tuple element order must match the column order declared for
        # information_schema.columns in INFO_SCHEMA.
        data["information_schema"]["columns"].append(
            (
                column.catalog,  # table_catalog
                column.schema,  # table_schema
                column.table,  # table_name
                column.name,  # column_name
                ordinal_position,  # ordinal_position
                column.default,  # column_default
                "YES" if column.is_nullable else "NO",  # is_nullable
                column.type,  # data_type
                None,  # character_maximum_length
                None,  # character_octet_length
                None,  # numeric_precision
                None,  # numeric_scale
                None,  # datetime_precision
                "NULL",  # character_set_name
                "NULL",  # collation_name
                column.type,  # column_type
                None,  # column_key
                None,  # extra
                None,  # privileges
                column.comment,  # column_comment
                None,  # generation_expression
                None,  # srs_id
            )
        )

    # Sorting gives deterministic row order for tables/schemata results.
    for catalog, db, table in sorted(tables):
        data["information_schema"]["tables"].append(
            (
                catalog,  # table_catalog
                db,  # table_schema
                table,  # table_name
                "SYSTEM TABLE" if db in INFO_SCHEMA else "BASE TABLE",  # table_type
                "MinervaSQL",  # engine
                "1.0",  # version
                None,  # row_format
                None,  # table_rows
                None,  # avg_row_length
                None,  # data_length
                None,  # max_data_length
                None,  # index_length
                None,  # data_free
                None,  # auto_increment
                None,  # create_time
                None,  # update_time
                None,  # check_time
                # NOTE(review): trailing space preserved from original - confirm intended.
                "utf8mb4_general_ci ",  # table_collation
                None,  # checksum
                None,  # create_options
                None,  # table_comment
            )
        )

    for catalog, db in sorted(dbs):
        data["information_schema"]["schemata"].append(
            (
                catalog,  # catalog_name
                db,  # schema_name
                "utf8mb4",  # default_character_set_name
                "utf8mb4_general_ci",  # default_collation_name
                None,  # sql_path
            )
        )

    return data
172 |
173 |
def show_statement_to_info_schema_query(
    show: exp.Show, database: Optional[str] = None
) -> exp.Select:
    """
    Translate a SHOW statement into an equivalent SELECT over INFORMATION_SCHEMA.

    Supports SHOW COLUMNS, SHOW TABLES, SHOW DATABASES, and SHOW INDEX.

    Args:
        show: parsed SHOW expression
        database: current default database, used when the statement does not
            name one explicitly
    Returns:
        A SQLGlot SELECT expression against the matching INFORMATION_SCHEMA table.
    Raises:
        MysqlError: if a required table/database name is missing or the SHOW
            variant is not supported.

    NOTE(review): table/db/LIKE values are interpolated directly into WHERE
    clauses via f-strings without escaping; they originate from parsed SQL,
    but confirm quoting is handled upstream.
    """
    kind = show.name.upper()
    if kind == "COLUMNS":
        outputs = [
            'column_name AS "Field"',
            'data_type AS "Type"',
            'is_nullable AS "Null"',
            'column_key AS "Key"',
            'column_default AS "Default"',
            'extra AS "Extra"',
        ]
        # SHOW FULL COLUMNS adds collation/privilege/comment columns.
        if show.args.get("full"):
            outputs.extend(
                [
                    'collation_name AS "Collation"',
                    'privileges AS "Privileges"',
                    'column_comment AS "Comment"',
                ]
            )
        table = show.text("target")
        if not table:
            raise MysqlError(
                "You have an error in your SQL syntax. Table name is missing.",
                code=ErrorCode.PARSE_ERROR,
            )
        select = (
            exp.select(*outputs)
            .from_("information_schema.columns")
            .where(f"table_name = '{table}'")
        )
        db = show.text("db") or database
        if db:
            select = select.where(f"table_schema = '{db}'")
        like = show.text("like")
        if like:
            select = select.where(f"column_name LIKE '{like}'")
    elif kind == "TABLES":
        outputs = ['table_name AS "Table_name"']
        if show.args.get("full"):
            outputs.extend(['table_type AS "Table_type"'])

        select = exp.select(*outputs).from_("information_schema.tables")
        db = show.text("db") or database
        # Unlike SHOW COLUMNS, SHOW TABLES requires a database in scope.
        if not db:
            raise MysqlError("No database selected.", code=ErrorCode.NO_DB_ERROR)
        select = select.where(f"table_schema = '{db}'")
        like = show.text("like")
        if like:
            select = select.where(f"table_name LIKE '{like}'")
    elif kind == "DATABASES":
        select = exp.select('schema_name AS "Database"').from_(
            "information_schema.schemata"
        )
        like = show.text("like")
        if like:
            select = select.where(f"schema_name LIKE '{like}'")
    elif kind == "INDEX":
        # NOTE(review): quoting here is inverted relative to the COLUMNS branch
        # (source column quoted, alias bare) - verify this parses as intended.
        outputs = [
            '"table_name" AS Table',
            '"non_unique" AS Non_unique',
            '"index_name" AS Key_name',
            '"seq_in_index" AS Seq_in_index',
            '"column_name" AS Column_name',
            '"collation" AS Collation',
            '"cardinality" AS Cardinality',
            '"sub_part" AS Sub_part',
            '"packed" AS Packed',
            '"nullable" AS Null',
            '"index_type" AS Index_type',
            '"comment" AS Comment',
            '"index_comment" AS Index_comment',
            '"is_visible" AS Visible',
            '"expression" AS Expression',
        ]
        table = show.text("target")
        if not table:
            raise MysqlError(
                "You have an error in your SQL syntax. Table name is missing.",
                code=ErrorCode.PARSE_ERROR,
            )
        select = (
            exp.select(*outputs)
            .from_("information_schema.statistics")
            .where(f"table_name = '{table}'")
        )
        db = show.text("db") or database
        if db:
            select = select.where(f"table_schema = '{db}'")
    else:
        raise MysqlError(
            f"Unsupported SHOW command: {kind}", code=ErrorCode.NOT_SUPPORTED_YET
        )

    return select
270 |
271 |
def com_field_list_to_show_statement(com_field_list: ComFieldList) -> str:
    """Translate a COM_FIELD_LIST command into an equivalent SHOW COLUMNS statement."""
    statement = exp.Show(
        this="COLUMNS",
        target=exp.to_identifier(com_field_list.table),
    )
    wildcard = com_field_list.wildcard
    if wildcard:
        statement.set("like", exp.Literal.string(wildcard))
    return statement.sql(dialect="mysql")
280 |
281 |
def like_to_regex(like: str) -> re.Pattern:
    """
    Convert a SQL LIKE pattern to a compiled regular expression.

    Regex metacharacters in the pattern (".", "+", parentheses, etc.) are
    escaped so they match literally, as they do in SQL LIKE. The SQL wildcards
    are then translated: "%" matches any run of characters and "_" matches
    exactly one character. The pattern is anchored at the end with \\Z so
    that, used with `re.match`, it matches the whole string (LIKE semantics)
    rather than just a prefix.
    """
    # re.escape (Python 3.7+) leaves "%" and "_" untouched, so the wildcard
    # substitutions below see them intact.
    pattern = re.escape(like)
    pattern = pattern.replace("%", ".*?")
    pattern = pattern.replace("_", ".")
    return re.compile(pattern + r"\Z")
286 |
287 |
class BaseInfoSchema:
    """
    Base InfoSchema interface used by the `Session` class.
    """

    async def query(self, expression: exp.Expression) -> AllowedResult:
        """Execute a query against INFORMATION_SCHEMA tables."""
        ...
294 |
295 |
class InfoSchema(BaseInfoSchema):
    """
    InfoSchema implementation that uses SQLGlot to execute queries.
    """

    def __init__(self, tables: Dict[str, Dict[str, Table]]):
        # {db: {table_name: sqlglot Table}} backing the in-memory executor.
        self.tables = tables

    async def query(self, expression: exp.Expression) -> AllowedResult:
        """Execute an expression against the in-memory INFORMATION_SCHEMA tables."""
        expression = self._preprocess(expression)
        result = execute(expression, schema=INFO_SCHEMA, tables=self.tables)
        return result.rows, result.columns

    @classmethod
    def from_mapping(cls, mapping: dict) -> InfoSchema:
        """Build an InfoSchema from a plain schema mapping (see `mapping_to_columns`)."""
        columns = mapping_to_columns(mapping)
        return cls(info_schema_tables(columns))

    @classmethod
    def from_columns(cls, columns: List[Column]) -> InfoSchema:
        """Build an InfoSchema from an explicit list of Column instances."""
        return cls(info_schema_tables(columns))

    def _preprocess(self, expression: exp.Expression) -> exp.Expression:
        # SQLGlot's executor can't handle COLLATE; strip it before executing.
        return expression.transform(_remove_collate)
320 |
321 |
def ensure_info_schema(schema: dict | BaseInfoSchema) -> BaseInfoSchema:
    """Coerce a plain schema mapping into an InfoSchema; pass through anything
    that already implements BaseInfoSchema."""
    return (
        schema
        if isinstance(schema, BaseInfoSchema)
        else InfoSchema.from_mapping(schema)
    )
326 |
327 |
def _remove_collate(node: exp.Expression) -> exp.Expression:
    """
    SQLGlot's executor doesn't support collation specifiers.

    Just remove them for now.
    """
    # Replace `expr COLLATE c` with just `expr`; leave everything else alone.
    if not isinstance(node, exp.Collate):
        return node
    return node.left
337 |
--------------------------------------------------------------------------------
/mysql_mimic/server.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | import inspect
5 | import logging
6 | from ssl import SSLContext
7 | from socket import socket
8 | from typing import Callable, Any, Optional, Sequence, Awaitable
9 |
10 | from mysql_mimic import packets
11 | from mysql_mimic.auth import IdentityProvider, SimpleIdentityProvider
12 | from mysql_mimic.connection import Connection
13 | from mysql_mimic.control import Control, LocalControl, TooManyConnections
14 | from mysql_mimic.errors import ErrorCode
15 | from mysql_mimic.session import Session, BaseSession
16 | from mysql_mimic.constants import DEFAULT_SERVER_CAPABILITIES
17 | from mysql_mimic.stream import MysqlStream
18 | from mysql_mimic.types import Capabilities
19 |
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
class MaxConnectionsExceeded(Exception):
    # NOTE(review): not raised in this module (TooManyConnections from
    # mysql_mimic.control is used instead); kept as it may be part of the
    # public API - confirm before removing.
    pass
26 |
27 |
class MysqlServer:
    """
    MySQL mimic server.

    Args:
        session_factory: Callable that takes no arguments and returns a session
        capabilities: server capability flags
        control: Control instance to use. Defaults to a LocalControl instance.
        identity_provider: Authentication plugins to register. Defaults to `SimpleIdentityProvider`,
            which just blindly accepts whatever `username` is given by the client.
        ssl: SSLContext instance if this server should enable TLS over connections

        **kwargs: extra keyword args passed to the asyncio start server command
    """

    def __init__(
        self,
        session_factory: Callable[[], BaseSession | Awaitable[BaseSession]] = Session,
        capabilities: Capabilities = DEFAULT_SERVER_CAPABILITIES,
        control: Control | None = None,
        identity_provider: IdentityProvider | None = None,
        ssl: SSLContext | None = None,
        **serve_kwargs: Any,
    ):
        self.session_factory = session_factory
        self.capabilities = capabilities
        self.identity_provider = identity_provider or SimpleIdentityProvider()
        self.ssl = ssl

        self.control = control or LocalControl()
        self._serve_kwargs = serve_kwargs
        self._server: Optional[asyncio.base_events.Server] = None

    async def _client_connected_cb(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """
        Handle one client socket: build a Connection, register it with the
        control, then run the connection until it finishes.

        Every exit path closes the writer so error cases don't leak sockets.
        """
        stream = MysqlStream(reader, writer)

        try:
            # The session factory may be a plain callable or a coroutine function.
            if inspect.iscoroutinefunction(self.session_factory):
                session = await self.session_factory()
            else:
                session = self.session_factory()

            connection = Connection(
                stream=stream,
                session=session,
                control=self.control,
                server_capabilities=self.capabilities,
                identity_provider=self.identity_provider,
                ssl=self.ssl,
            )

        except Exception:  # pylint: disable=broad-except
            logger.exception("Failed to create connection")
            await stream.write(
                # Return an error so clients don't freeze
                packets.make_error(capabilities=self.capabilities)
            )
            # Close the socket; previously it was left to garbage collection.
            writer.close()
            return

        try:
            connection_id = await self.control.add(connection)
            connection.connection_id = connection_id
        except TooManyConnections:
            await stream.write(
                connection.error(
                    msg="Too many connections",
                    code=ErrorCode.CON_COUNT_ERROR,
                )
            )
            writer.close()
            return
        except Exception:  # pylint: disable=broad-except
            logger.exception("Failed to register connection")
            await stream.write(connection.error(msg="Failed to register connection"))
            writer.close()
            return

        try:
            return await connection.start()
        finally:
            writer.close()
            # Always deregister, even if the connection crashed.
            await self.control.remove(connection_id)

    async def start_server(self, **kwargs: Any) -> None:
        """
        Start an asyncio socket server.

        Args:
            **kwargs: keyword args passed to `asyncio.start_server`
        """
        # Constructor kwargs act as defaults; call-site kwargs win.
        kw = {}
        kw.update(self._serve_kwargs)
        kw.update(kwargs)
        if "port" not in kw:
            kw["port"] = 3306  # standard MySQL port
        self._server = await asyncio.start_server(self._client_connected_cb, **kw)

    async def start_unix_server(self, **kwargs: Any) -> None:
        """
        Start an asyncio unix socket server.

        Args:
            **kwargs: keyword args passed to `asyncio.start_unix_server`
        """
        kw = {}
        kw.update(self._serve_kwargs)
        kw.update(kwargs)
        self._server = await asyncio.start_unix_server(self._client_connected_cb, **kw)

    async def serve_forever(self, **kwargs: Any) -> None:
        """
        Start accepting connections until the coroutine is cancelled.

        If a server isn't running, this starts a server with `start_server`.

        Args:
            **kwargs: keyword args passed to `start_server`
        """
        if not self._server:
            await self.start_server(**kwargs)
        assert self._server is not None
        await self._server.serve_forever()

    def close(self) -> None:
        """
        Stop serving.

        The server is closed asynchronously -
        use the `wait_closed` coroutine to wait until the server is closed.
        """
        if self._server:
            self._server.close()

    async def wait_closed(self) -> None:
        """Wait until the `close` method completes."""
        if self._server:
            await self._server.wait_closed()

    def sockets(self) -> Sequence[socket]:
        """Get sockets the server is listening on."""
        if self._server:
            return self._server.sockets
        return ()
171 |
--------------------------------------------------------------------------------
/mysql_mimic/session.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from datetime import datetime, timezone as timezone_
5 | from typing import (
6 | Dict,
7 | List,
8 | TYPE_CHECKING,
9 | Optional,
10 | Callable,
11 | Awaitable,
12 | Type,
13 | Any,
14 | )
15 |
16 | from sqlglot import Dialect
17 | from sqlglot.dialects import MySQL
18 | from sqlglot import expressions as exp
19 | from sqlglot.executor import execute
20 |
21 | from mysql_mimic.charset import CharacterSet
22 | from mysql_mimic.errors import ErrorCode, MysqlError
23 | from mysql_mimic.functions import Functions, mysql_datetime_function_mapping
24 | from mysql_mimic.intercept import (
25 | setitem_kind,
26 | expression_to_value,
27 | TRANSACTION_CHARACTERISTICS,
28 | )
29 | from mysql_mimic.schema import (
30 | show_statement_to_info_schema_query,
31 | like_to_regex,
32 | BaseInfoSchema,
33 | ensure_info_schema,
34 | )
35 | from mysql_mimic.constants import INFO_SCHEMA, KillKind
36 | from mysql_mimic.variable_processor import VariableProcessor
37 | from mysql_mimic.utils import find_dbs
38 | from mysql_mimic.variables import (
39 | Variables,
40 | SessionVariables,
41 | GlobalVariables,
42 | DEFAULT,
43 | parse_timezone,
44 | )
45 | from mysql_mimic.results import AllowedResult
46 |
47 | if TYPE_CHECKING:
48 | from mysql_mimic.connection import Connection
49 |
50 |
51 | Middleware = Callable[["Query"], Awaitable[AllowedResult]]
52 |
53 |
def mysql_function_mapping(session: Session) -> Functions:
    """
    Build the mapping of MySQL information functions to value providers.

    These will be replaced in the AST with their corresponding values.
    """
    # Start from the date/time functions bound to the query start time,
    # then layer on the session-specific information functions.
    functions: Functions = dict(mysql_datetime_function_mapping(session.timestamp))
    functions["CONNECTION_ID"] = lambda: session.connection.connection_id
    functions["USER"] = lambda: session.variables.get("external_user")
    functions["CURRENT_USER"] = lambda: session.username
    functions["VERSION"] = lambda: session.variables.get("version")
    functions["DATABASE"] = lambda: session.database
    # Synonyms
    functions["SYSTEM_USER"] = functions["USER"]
    functions["SESSION_USER"] = functions["USER"]
    functions["SCHEMA"] = functions["DATABASE"]
    return functions
74 |
75 |
@dataclass
class Query:
    """
    Convenience class that wraps the parameters to middleware.

    Args:
        expression: current query expression
        sql: the original SQL sent by the client
        attrs: query attributes
        _middlewares: subsequent middleware functions
        _query: the ultimate query method
    """

    expression: exp.Expression
    sql: str
    attrs: Dict[str, str]
    _middlewares: list[Middleware]
    _query: Callable[[exp.Expression, str, dict[str, str]], Awaitable[AllowedResult]]

    async def next(self) -> AllowedResult:
        """
        Call the next middleware in the chain of middlewares.

        Returns:
            The final query result.
        """
        if not self._middlewares:
            # End of the chain - run the actual query method.
            return await self._query(self.expression, self.sql, self.attrs)
        # Hand a copy with the remaining middlewares to the next one in line.
        q = Query(
            expression=self.expression,
            sql=self.sql,
            attrs=self.attrs,
            _middlewares=self._middlewares[1:],
            _query=self._query,
        )
        return await self._middlewares[0](q)

    async def start(self) -> AllowedResult:
        """
        Start the middleware chain.

        This should only be called by the framework code.

        Falls through to the query method directly when no middlewares are
        registered (previously this raised IndexError on an empty list).
        """
        if not self._middlewares:
            return await self._query(self.expression, self.sql, self.attrs)
        return await self._middlewares[0](self)
120 |
121 |
class BaseSession:
    """
    Session interface.

    This defines what the Connection object depends on.

    Most applications should implement the abstract `Session`, not this class.
    """

    # Session/system variables for this connection
    variables: Variables
    # Current authenticated user, if any
    username: Optional[str]
    # Current default database, if any
    database: Optional[str]

    async def handle_query(self, sql: str, attrs: Dict[str, str]) -> AllowedResult:
        """
        Main entrypoint for queries.

        Args:
            sql: SQL statement
            attrs: Mapping of query attributes
        Returns:
            One of:
            - tuple(rows, column_names), where "rows" is a sequence of sequences
              and "column_names" is a sequence of strings with the same length
              as every row in "rows"
            - tuple(rows, result_columns), where "rows" is the same
              as above, and "result_columns" is a sequence of mysql_mimic.ResultColumn
              instances.
            - mysql_mimic.ResultSet instance
        """

    async def init(self, connection: Connection) -> None:
        """
        Called when connection phase is complete.
        """

    async def close(self) -> None:
        """
        Called when the client closes the connection.
        """

    async def reset(self) -> None:
        """
        Called when a client resets the connection and after a COM_CHANGE_USER command.
        """

    async def use(self, database: str) -> None:
        """
        Use a new default database.

        Called when a USE database_name command is received.

        Args:
            database: database name
        """
177 |
178 |
class Session(BaseSession):
    """
    Abstract session.

    This automatically handles lots of behavior that many clients expect,
    e.g. session variables, SHOW commands, queries to INFORMATION_SCHEMA, and more
    """

    # SQL dialect used to parse incoming statements.
    dialect: Type[Dialect] = MySQL

    def __init__(self, variables: Variables | None = None):
        self.variables = variables or SessionVariables(GlobalVariables())

        # Query middlewares.
        # These allow queries to be intercepted or wrapped.
        # Order matters: variable substitution and SET handling run first,
        # and the INFORMATION_SCHEMA fallback runs last.
        self.middlewares: list[Middleware] = [
            self._set_var_middleware,
            self._set_middleware,
            self._static_query_middleware,
            self._use_middleware,
            self._kill_middleware,
            self._show_middleware,
            self._describe_middleware,
            self._begin_middleware,
            self._commit_middleware,
            self._rollback_middleware,
            self._info_schema_middleware,
        ]

        # Current database
        self.database: Optional[str] = None

        # Current authenticated user
        self.username: Optional[str] = None

        # Time when query started
        # NOTE(review): initialized naive here but replaced with a
        # timezone-aware value in handle_query - confirm the two are never compared.
        self.timestamp: datetime = datetime.now()

        self._connection: Optional[Connection] = None

    async def query(
        self, expression: exp.Expression, sql: str, attrs: Dict[str, str]
    ) -> AllowedResult:
        """
        Process a SQL query.

        Args:
            expression: parsed AST of the statement from client
            sql: original SQL statement from client
            attrs: arbitrary query attributes set by client
        Returns:
            One of:
            - tuple(rows, column_names), where "rows" is a sequence of sequences
              and "column_names" is a sequence of strings with the same length
              as every row in "rows"
            - tuple(rows, result_columns), where "rows" is the same
              as above, and "result_columns" is a sequence of mysql_mimic.ResultColumn
              instances.
            - mysql_mimic.ResultSet instance
        """
        return [], []

    async def schema(self) -> dict | BaseInfoSchema:
        """
        Provide the database schema.

        This is used to serve INFORMATION_SCHEMA and SHOW queries.

        Returns:
            One of:
            - Mapping of:
                {table: {column: column_type}} or
                {db: {table: {column: column_type}}} or
                {catalog: {db: {table: {column: column_type}}}}
            - Instance of `BaseInfoSchema`
        """
        return {}

    @property
    def connection(self) -> Connection:
        """
        Get the connection associated with this session.

        Raises:
            AttributeError: if `init` hasn't been called yet.
        """
        if self._connection is None:
            raise AttributeError("Session is not yet bound")
        return self._connection

    async def init(self, connection: Connection) -> None:
        """
        Called when connection phase is complete.
        """
        self._connection = connection

    async def close(self) -> None:
        """
        Called when the client closes the connection.
        """
        self._connection = None

    async def handle_query(self, sql: str, attrs: Dict[str, str]) -> AllowedResult:
        """Parse the SQL and run each statement through the middleware chain."""
        # Snapshot the query start time in the session's configured time zone.
        self.timestamp = datetime.now(tz=self.timezone())
        result = None
        for expression in self._parse(sql):
            if not expression:
                continue
            q = Query(
                expression=expression,
                sql=sql,
                attrs=attrs,
                _middlewares=self.middlewares,
                _query=self.query,
            )
            result = await q.start()
        # Only the result of the last statement is returned.
        return result

    async def use(self, database: str) -> None:
        """Switch the session's default database."""
        self.database = database

    def _parse(self, sql: str) -> List[exp.Expression]:
        # A single COM_QUERY payload may contain multiple statements.
        return [e for e in self.dialect().parse(sql) if e]

    async def _query_info_schema(self, expression: exp.Expression) -> AllowedResult:
        # Lazily build (or pass through) the info schema each time it is needed.
        return await ensure_info_schema(await self.schema()).query(expression)

    async def _set_var_middleware(self, q: Query) -> AllowedResult:
        """Handles SET_VAR hints and replaces functions defined in the _functions mapping with their mapped values."""
        with VariableProcessor(
            mysql_function_mapping(self), self.variables, q.expression
        ).set_variables():
            return await q.next()

    async def _use_middleware(self, q: Query) -> AllowedResult:
        """Intercept USE statements"""
        if isinstance(q.expression, exp.Use):
            await self.use(q.expression.this.name)
            return [], []
        return await q.next()

    async def _kill_middleware(self, q: Query) -> AllowedResult:
        """Intercept KILL statements"""
        if isinstance(q.expression, exp.Kill):
            control = self.connection.control
            if control:
                # KILL defaults to killing the whole connection, not just the query.
                kind = KillKind[q.expression.text("kind").upper() or "CONNECTION"]
                this = q.expression.this.name

                try:
                    connection_id = int(this)
                except ValueError as e:
                    raise MysqlError(
                        f"Invalid KILL connection ID: {this}",
                        code=ErrorCode.PARSE_ERROR,
                    ) from e

                await control.kill(connection_id, kind)
            return [], []
        return await q.next()

    async def _show_middleware(self, q: Query) -> AllowedResult:
        """Intercept SHOW statements"""
        if isinstance(q.expression, exp.Show):
            return await self._show(q.expression)
        return await q.next()

    async def _show(self, expression: exp.Show) -> AllowedResult:
        """Dispatch a SHOW statement to a local handler or INFORMATION_SCHEMA."""
        kind = expression.name.upper()
        if kind == "VARIABLES":
            return self._show_variables(expression)
        if kind == "STATUS":
            return self._show_status(expression)
        if kind == "WARNINGS":
            return self._show_warnings(expression)
        if kind == "ERRORS":
            return self._show_errors(expression)
        # Anything else is rewritten as an INFORMATION_SCHEMA query.
        select = show_statement_to_info_schema_query(expression, self.database)
        return await self._query_info_schema(select)

    async def _describe_middleware(self, q: Query) -> AllowedResult:
        """Intercept DESCRIBE statements"""
        if isinstance(q.expression, exp.Describe):
            if isinstance(q.expression.this, exp.Select):
                # Mysql parse treats EXPLAIN SELECT as a DESCRIBE SELECT statement
                return await q.next()
            # DESCRIBE <table> is equivalent to SHOW COLUMNS FROM <table>.
            name = q.expression.this.name
            show = self.dialect().parse(f"SHOW COLUMNS FROM {name}")[0]
            return await self._show(show) if isinstance(show, exp.Show) else None
        return await q.next()

    async def _rollback_middleware(self, q: Query) -> AllowedResult:
        """Intercept ROLLBACK statements"""
        if isinstance(q.expression, exp.Rollback):
            # Transactions are no-ops; acknowledge with an empty result.
            return [], []
        return await q.next()

    async def _commit_middleware(self, q: Query) -> AllowedResult:
        """Intercept COMMIT statements"""
        if isinstance(q.expression, exp.Commit):
            return [], []
        return await q.next()

    async def _begin_middleware(self, q: Query) -> AllowedResult:
        """Intercept BEGIN statements"""
        if isinstance(q.expression, exp.Transaction):
            return [], []
        return await q.next()

    async def _static_query_middleware(self, q: Query) -> AllowedResult:
        """
        Handle static queries (e.g. SELECT 1).

        These are very common, as many clients execute commands like SELECT DATABASE() when connecting.
        """
        # A SELECT is "static" if it has no clauses beyond its projection
        # (plus LIMIT/hints), i.e. it reads from no tables.
        if isinstance(q.expression, exp.Select) and not any(
            q.expression.args.get(a)
            for a in set(exp.Select.arg_types) - {"expressions", "limit", "hint"}
        ):
            result = execute(q.expression)
            return result.rows, result.columns
        return await q.next()

    async def _set_middleware(self, q: Query) -> AllowedResult:
        """Intercept SET statements"""
        if isinstance(q.expression, exp.Set):
            expressions = q.expression.expressions
            for item in expressions:
                assert isinstance(item, exp.SetItem)

                kind = setitem_kind(item)

                if kind == "VARIABLE":
                    self._set_variable(item)
                elif kind == "CHARACTER SET":
                    self._set_charset(item)
                elif kind == "NAMES":
                    self._set_names(item)
                elif kind == "TRANSACTION":
                    self._set_transaction(item)
                else:
                    raise MysqlError(
                        f"Unsupported SET statement: {kind}",
                        code=ErrorCode.NOT_SUPPORTED_YET,
                    )

            return [], []
        return await q.next()

    async def _info_schema_middleware(self, q: Query) -> AllowedResult:
        """Intercept queries to INFORMATION_SCHEMA tables"""
        dbs = find_dbs(q.expression)
        # Route to the info schema when either the default database or every
        # database referenced by the query is an information-schema database.
        if (self.database and self.database.lower() in INFO_SCHEMA) or (
            dbs and all(db.lower() in INFO_SCHEMA for db in dbs)
        ):
            return await self._query_info_schema(q.expression)
        return await q.next()

    def _set_variable(self, setitem: exp.SetItem) -> None:
        """Apply a single SET assignment to the session variables."""
        assignment = setitem.this
        left = assignment.left

        if isinstance(left, exp.SessionParameter):
            # e.g. SET @@SESSION.sql_mode = ...
            scope = left.text("kind") or "SESSION"
            name = left.name
        elif isinstance(left, exp.Parameter):
            # e.g. SET @user_var = ...
            raise MysqlError(
                "User-defined variables not supported yet",
                code=ErrorCode.NOT_SUPPORTED_YET,
            )
        else:
            # e.g. SET [SESSION|GLOBAL] sql_mode = ...
            scope = setitem.text("kind") or "SESSION"
            name = left.name

        scope = scope.upper()
        value = expression_to_value(assignment.right)

        if scope in {"SESSION", "LOCAL"}:
            self.variables.set(name, value)
        else:
            # GLOBAL (and any other scope) isn't supported.
            raise MysqlError(
                f"Cannot SET variable {name} with scope {scope}",
                code=ErrorCode.NOT_SUPPORTED_YET,
            )

    def _set_charset(self, item: exp.SetItem) -> None:
        """Handle SET CHARACTER SET."""
        charset_name: Any
        charset_conn: Any
        if item.name == "DEFAULT":
            charset_name = DEFAULT
            charset_conn = DEFAULT
        else:
            charset_name = item.name
            # Per MySQL, the connection charset reverts to the database charset.
            charset_conn = self.variables.get("character_set_database")
        self.variables.set("character_set_client", charset_name)
        self.variables.set("character_set_results", charset_name)
        self.variables.set("character_set_connection", charset_conn)

    def _set_names(self, item: exp.SetItem) -> None:
        """Handle SET NAMES [charset [COLLATE collation]]."""
        charset_name: Any
        collation_name: Any
        if item.name == "DEFAULT":
            charset_name = DEFAULT
            collation_name = DEFAULT
        else:
            charset_name = item.name
            # Fall back to the charset's default collation if none was given.
            collation_name = (
                item.text("collate")
                or CharacterSet[charset_name].default_collation.name
            )
        self.variables.set("character_set_client", charset_name)
        self.variables.set("character_set_connection", charset_name)
        self.variables.set("character_set_results", charset_name)
        self.variables.set("collation_connection", collation_name)

    def _set_transaction(self, item: exp.SetItem) -> None:
        """Handle SET TRANSACTION by mapping characteristics to variables."""
        characteristics = [e.name.upper() for e in item.expressions]
        for characteristic in characteristics:
            variable, value = TRANSACTION_CHARACTERISTICS[characteristic]
            self.variables.set(variable, value)

    def _show_variables(self, show: exp.Show) -> AllowedResult:
        """Serve SHOW VARIABLES [LIKE ...] from the session variables."""
        rows = [(k, None if v is None else str(v)) for k, v in self.variables.list()]
        like = show.text("like")
        if like:
            rows = [(k, v) for k, v in rows if like_to_regex(like).match(k)]
        return rows, ["Variable_name", "Value"]

    def _show_status(self, show: exp.Show) -> AllowedResult:
        # No status values are tracked; return an empty result set.
        return [], ["Variable_name", "Value"]

    def _show_warnings(self, show: exp.Show) -> AllowedResult:
        # No warnings are tracked; return an empty result set.
        return [], ["Level", "Code", "Message"]

    def _show_errors(self, show: exp.Show) -> AllowedResult:
        # No errors are tracked; return an empty result set.
        return [], ["Level", "Code", "Message"]

    def timezone(self) -> timezone_:
        """Return the session's time zone, parsed from the time_zone variable."""
        tz = self.variables.get("time_zone")
        return parse_timezone(tz)
516 |
--------------------------------------------------------------------------------
/mysql_mimic/stream.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import struct
3 | from ssl import SSLContext
4 |
5 | from mysql_mimic.errors import MysqlError, ErrorCode
6 | from mysql_mimic.types import uint_3, uint_1
7 | from mysql_mimic.utils import seq
8 |
9 |
class ConnectionClosed(Exception):
    # Raised by MysqlStream.read when the peer has closed the connection.
    pass
12 |
13 |
class MysqlStream:
    """
    Wrapper around an asyncio stream pair that speaks MySQL packet framing.

    Each packet is a 4 byte header - a 3 byte little-endian payload length
    plus a 1 byte sequence id - followed by the payload. Payloads of exactly
    0xFFFFFF bytes signal that the logical payload continues in the next packet.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        buffer_size: int = 2**15,
    ):
        self.reader = reader
        self.writer = writer
        # Wrapping sequence id generator (0-255), shared by reads and writes.
        self.seq = seq(256)
        # Outgoing bytes staged until drain() flushes them to the writer.
        self._buffer = bytearray()
        self._buffer_size = buffer_size

    async def read(self) -> bytes:
        """
        Read one logical payload, reassembling multi-packet payloads.

        Returns:
            The complete payload bytes (possibly empty).
        Raises:
            ConnectionClosed: if the peer closed the connection before a header.
            MysqlError: if a packet arrives with an unexpected sequence id.
        """
        data = b""
        while True:
            header = await self.reader.read(4)

            if not header:
                raise ConnectionClosed()

            # Header is a packed uint32: low 3 bytes are the payload length,
            # the high byte is the sequence id.
            # (Reconstructed: this line was corrupted in the source dump.)
            i = struct.unpack("<I", header)[0]
            payload_length = i & 0x00FFFFFF
            sequence_id = i >> 24

            expected = next(self.seq)
            if sequence_id != expected:
                raise MysqlError(
                    f"Expected seq({expected}) got seq({sequence_id})",
                    ErrorCode.MALFORMED_PACKET,
                )

            if payload_length == 0:
                return data

            data += await self.reader.readexactly(payload_length)

            # A payload shorter than the maximum means the message is complete;
            # exactly 0xFFFFFF means a continuation packet follows.
            if payload_length < 0xFFFFFF:
                return data

    async def write(self, data: bytes, drain: bool = True) -> None:
        """
        Write one logical payload, splitting it into packets as needed.

        Args:
            data: payload bytes to send
            drain: flush the internal buffer immediately; when False, bytes are
                buffered until the buffer exceeds `buffer_size` or `drain()` is called.
        """
        while True:
            # Grab first 0xFFFFFF bytes to send
            payload = data[:0xFFFFFF]
            data = data[0xFFFFFF:]

            payload_length = uint_3(len(payload))
            sequence_id = uint_1(next(self.seq))
            packet = payload_length + sequence_id + payload

            self._buffer.extend(packet)
            if drain or len(self._buffer) >= self._buffer_size:
                await self.drain()

            # We are done unless len(send) == 0xFFFFFF
            if len(payload) != 0xFFFFFF:
                return

    async def drain(self) -> None:
        """Flush any buffered bytes to the underlying writer."""
        if self._buffer:
            self.writer.write(self._buffer)
            self._buffer.clear()
        await self.writer.drain()

    def reset_seq(self) -> None:
        """Reset the packet sequence id, as required at each command boundary."""
        self.seq.reset()

    async def start_tls(self, ssl: SSLContext) -> None:
        """Upgrade the existing transport to TLS in place (server side)."""
        transport = self.writer.transport
        protocol = transport.get_protocol()
        loop = asyncio.get_event_loop()
        new_transport = await loop.start_tls(
            transport=transport,
            protocol=protocol,
            sslcontext=ssl,
            server_side=True,
        )

        # This seems to be the easiest way to wrap the socket created by asyncio
        self.writer._transport = new_transport  # type: ignore # pylint: disable=protected-access
        self.reader._transport = new_transport  # type: ignore # pylint: disable=protected-access
95 |
--------------------------------------------------------------------------------
/mysql_mimic/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import io
3 | import struct
4 |
5 | from enum import IntEnum, IntFlag, auto
6 |
7 |
class ColumnType(IntEnum):
    """Wire-protocol column type codes sent in column definition packets."""

    DECIMAL = 0x00
    TINY = 0x01
    SHORT = 0x02
    LONG = 0x03
    FLOAT = 0x04
    DOUBLE = 0x05
    NULL = 0x06
    TIMESTAMP = 0x07
    LONGLONG = 0x08
    INT24 = 0x09
    DATE = 0x0A
    TIME = 0x0B
    DATETIME = 0x0C
    YEAR = 0x0D
    NEWDATE = 0x0E
    VARCHAR = 0x0F
    BIT = 0x10
    TIMESTAMP2 = 0x11
    DATETIME2 = 0x12
    TIME2 = 0x13
    TYPED_ARRAY = 0x14
    INVALID = 0xF3
    BOOL = 0xF4
    JSON = 0xF5
    NEWDECIMAL = 0xF6
    ENUM = 0xF7
    SET = 0xF8
    TINY_BLOB = 0xF9
    MEDIUM_BLOB = 0xFA
    LONG_BLOB = 0xFB
    BLOB = 0xFC
    VAR_STRING = 0xFD
    STRING = 0xFE
    GEOMETRY = 0xFF
43 |
44 |
class Commands(IntEnum):
    """Client command codes: the first payload byte of each command packet."""

    COM_SLEEP = 0x00
    COM_QUIT = 0x01
    COM_INIT_DB = 0x02
    COM_QUERY = 0x03
    COM_FIELD_LIST = 0x04
    COM_CREATE_DB = 0x05
    COM_DROP_DB = 0x06
    COM_REFRESH = 0x07
    COM_SHUTDOWN = 0x08
    COM_STATISTICS = 0x09
    COM_PROCESS_INFO = 0x0A
    COM_CONNECT = 0x0B
    COM_PROCESS_KILL = 0x0C
    COM_DEBUG = 0x0D
    COM_PING = 0x0E
    COM_TIME = 0x0F
    COM_DELAYED_INSERT = 0x10
    COM_CHANGE_USER = 0x11
    COM_BINLOG_DUMP = 0x12
    COM_TABLE_DUMP = 0x13
    COM_CONNECT_OUT = 0x14
    COM_REGISTER_SLAVE = 0x15
    COM_STMT_PREPARE = 0x16
    COM_STMT_EXECUTE = 0x17
    COM_STMT_SEND_LONG_DATA = 0x18
    COM_STMT_CLOSE = 0x19
    COM_STMT_RESET = 0x1A
    COM_SET_OPTION = 0x1B
    COM_STMT_FETCH = 0x1C
    COM_DAEMON = 0x1D
    COM_BINLOG_DUMP_GTID = 0x1E
    COM_RESET_CONNECTION = 0x1F
78 |
79 |
class ColumnDefinition(IntFlag):
    """Column definition flags included in column metadata.

    Declared with auto() so each member occupies the next bit position,
    starting at bit 0.
    """

    NOT_NULL_FLAG = auto()
    PRI_KEY_FLAG = auto()
    UNIQUE_KEY_FLAG = auto()
    MULTIPLE_KEY_FLAG = auto()
    BLOB_FLAG = auto()
    UNSIGNED_FLAG = auto()
    ZEROFILL_FLAG = auto()
    BINARY_FLAG = auto()
    ENUM_FLAG = auto()
    AUTO_INCREMENT_FLAG = auto()
    TIMESTAMP_FLAG = auto()
    SET_FLAG = auto()
    NO_DEFAULT_VALUE_FLAG = auto()
    ON_UPDATE_NOW_FLAG = auto()
    NUM_FLAG = auto()
    PART_KEY_FLAG = auto()
    GROUP_FLAG = auto()
    UNIQUE_FLAG = auto()
    BINCMP_FLAG = auto()
    GET_FIXED_FIELDS_FLAG = auto()
    FIELD_IN_PART_FUNC_FLAG = auto()
    FIELD_IN_ADD_INDEX = auto()
    FIELD_IS_RENAMED = auto()
    FIELD_FLAGS_STORAGE_MEDIA = auto()
    FIELD_FLAGS_STORAGE_MEDIA_MASK = auto()
    FIELD_FLAGS_COLUMN_FORMAT = auto()
    FIELD_FLAGS_COLUMN_FORMAT_MASK = auto()
    FIELD_IS_DROPPED = auto()
    EXPLICIT_NULL_FLAG = auto()
    NOT_SECONDARY_FLAG = auto()
    FIELD_IS_INVISIBLE = auto()
112 |
113 |
class Capabilities(IntFlag):
    """Client/server capability flags exchanged during the handshake.

    Declared with auto() so each member occupies the next bit position,
    starting at bit 0.
    """

    CLIENT_LONG_PASSWORD = auto()
    CLIENT_FOUND_ROWS = auto()
    CLIENT_LONG_FLAG = auto()
    CLIENT_CONNECT_WITH_DB = auto()
    CLIENT_NO_SCHEMA = auto()
    CLIENT_COMPRESS = auto()
    CLIENT_ODBC = auto()
    CLIENT_LOCAL_FILES = auto()
    CLIENT_IGNORE_SPACE = auto()
    CLIENT_PROTOCOL_41 = auto()
    CLIENT_INTERACTIVE = auto()
    CLIENT_SSL = auto()
    CLIENT_IGNORE_SIGPIPE = auto()
    CLIENT_TRANSACTIONS = auto()
    CLIENT_RESERVED = auto()
    CLIENT_SECURE_CONNECTION = auto()
    CLIENT_MULTI_STATEMENTS = auto()
    CLIENT_MULTI_RESULTS = auto()
    CLIENT_PS_MULTI_RESULTS = auto()
    CLIENT_PLUGIN_AUTH = auto()
    CLIENT_CONNECT_ATTRS = auto()
    CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = auto()
    CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = auto()
    CLIENT_SESSION_TRACK = auto()
    CLIENT_DEPRECATE_EOF = auto()
    CLIENT_OPTIONAL_RESULTSET_METADATA = auto()
    CLIENT_ZSTD_COMPRESSION_ALGORITHM = auto()
    CLIENT_QUERY_ATTRIBUTES = auto()
    MULTI_FACTOR_AUTHENTICATION = auto()
    CLIENT_CAPABILITY_EXTENSION = auto()
    CLIENT_SSL_VERIFY_SERVER_CERT = auto()
    CLIENT_REMEMBER_OPTIONS = auto()
147 |
148 |
class ServerStatus(IntFlag):
    """Server status flags carried in OK/EOF packets."""

    SERVER_STATUS_IN_TRANS = 0x0001
    SERVER_STATUS_AUTOCOMMIT = 0x0002
    SERVER_MORE_RESULTS_EXISTS = 0x0008
    SERVER_STATUS_NO_GOOD_INDEX_USED = 0x0010
    SERVER_STATUS_NO_INDEX_USED = 0x0020
    SERVER_STATUS_CURSOR_EXISTS = 0x0040
    SERVER_STATUS_LAST_ROW_SENT = 0x0080
    SERVER_STATUS_DB_DROPPED = 0x0100
    SERVER_STATUS_NO_BACKSLASH_ESCAPES = 0x0200
    SERVER_STATUS_METADATA_CHANGED = 0x0400
    SERVER_QUERY_WAS_SLOW = 0x0800
    SERVER_PS_OUT_PARAMS = 0x1000
    SERVER_STATUS_IN_TRANS_READONLY = 0x2000
    SERVER_SESSION_STATE_CHANGED = 0x4000
164 |
165 |
class ResultsetMetadata(IntEnum):
    """Whether full column metadata follows a resultset header."""

    RESULTSET_METADATA_NONE = 0
    RESULTSET_METADATA_FULL = 1
169 |
170 |
class ComStmtExecuteFlags(IntFlag):
    """Cursor flags sent with COM_STMT_EXECUTE."""

    CURSOR_TYPE_NO_CURSOR = 0x00
    CURSOR_TYPE_READ_ONLY = 0x01
    CURSOR_TYPE_FOR_UPDATE = 0x02
    CURSOR_TYPE_SCROLLABLE = 0x04
    PARAMETER_COUNT_AVAILABLE = 0x08
177 |
178 |
def uint_len(i: int) -> bytes:
    """Encode an integer as a MySQL length-encoded integer.

    Values < 251 take one byte; larger values are prefixed with 0xFC (2 byte),
    0xFD (3 byte), or 0xFE (8 byte) markers.
    """
    if i < 251:
        return struct.pack("<B", i)
    if i < 2**16:
        return struct.pack("<BH", 0xFC, i)
    if i < 2**24:
        return struct.pack("<BHB", 0xFD, i & 0xFFFF, i >> 16)
    return struct.pack("<BQ", 0xFE, i)


def uint_1(i: int) -> bytes:
    """Encode an integer as a 1 byte unsigned int."""
    return struct.pack("<B", i)


def uint_2(i: int) -> bytes:
    """Encode an integer as a 2 byte little-endian unsigned int."""
    return struct.pack("<H", i)


def uint_3(i: int) -> bytes:
    """Encode an integer as a 3 byte little-endian unsigned int."""
    # Low 2 bytes followed by the high byte
    return struct.pack("<HB", i & 0xFFFF, i >> 16)


def uint_4(i: int) -> bytes:
    """Encode an integer as a 4 byte little-endian unsigned int."""
    return struct.pack("<I", i)


def uint_6(i: int) -> bytes:
    """Encode an integer as a 6 byte little-endian unsigned int."""
    # Low 4 bytes followed by the high 2 bytes
    return struct.pack("<IH", i & 0xFFFFFFFF, i >> 32)


def uint_8(i: int) -> bytes:
    """Encode an integer as an 8 byte little-endian unsigned int."""
    return struct.pack("<Q", i)


def str_fixed(l: int, s: bytes) -> bytes:
    """Encode bytes as a fixed-length string of length `l` (zero padded)."""
    return struct.pack(f"<{l}s", s)


def str_null(s: bytes) -> bytes:
    """Encode bytes as a null-terminated string."""
    l = len(s)
    return struct.pack(f"<{l}sB", s, 0)


def str_len(s: bytes) -> bytes:
    """Encode bytes as a length-encoded string (length prefix + payload)."""
    l = len(s)
    return uint_len(l) + str_fixed(l, s)


def str_rest(s: bytes) -> bytes:
    """Encode bytes as a string that consumes the rest of the packet."""
    l = len(s)
    return str_fixed(l, s)
231 |
232 |
def read_int_1(reader: io.BytesIO) -> int:
    """Read a 1 byte signed int."""
    data = reader.read(1)
    return struct.unpack("<b", data)[0]


def read_uint_1(reader: io.BytesIO) -> int:
    """Read a 1 byte unsigned int."""
    data = reader.read(1)
    return struct.unpack("<B", data)[0]


def read_int_2(reader: io.BytesIO) -> int:
    """Read a 2 byte little-endian signed int."""
    data = reader.read(2)
    return struct.unpack("<h", data)[0]


def read_uint_2(reader: io.BytesIO) -> int:
    """Read a 2 byte little-endian unsigned int."""
    data = reader.read(2)
    return struct.unpack("<H", data)[0]


def read_uint_3(reader: io.BytesIO) -> int:
    """Read a 3 byte little-endian unsigned int."""
    data = reader.read(3)
    # Low 2 bytes, then the high byte
    t = struct.unpack("<HB", data)
    return t[0] + (t[1] << 16)


def read_int_4(reader: io.BytesIO) -> int:
    """Read a 4 byte little-endian signed int."""
    data = reader.read(4)
    return struct.unpack("<i", data)[0]


def read_uint_4(reader: io.BytesIO) -> int:
    """Read a 4 byte little-endian unsigned int."""
    data = reader.read(4)
    return struct.unpack("<I", data)[0]


def read_uint_6(reader: io.BytesIO) -> int:
    """Read a 6 byte little-endian unsigned int."""
    data = reader.read(6)
    # Low 4 bytes, then the high 2 bytes
    t = struct.unpack("<IH", data)
    return t[0] + (t[1] << 32)


def read_int_8(reader: io.BytesIO) -> int:
    """Read an 8 byte little-endian signed int."""
    data = reader.read(8)
    return struct.unpack("<q", data)[0]


def read_uint_8(reader: io.BytesIO) -> int:
    """Read an 8 byte little-endian unsigned int."""
    data = reader.read(8)
    return struct.unpack("<Q", data)[0]


def read_float(reader: io.BytesIO) -> float:
    """Read a 4 byte little-endian float."""
    data = reader.read(4)
    return struct.unpack("<f", data)[0]


def read_double(reader: io.BytesIO) -> float:
    """Read an 8 byte little-endian double."""
    data = reader.read(8)
    return struct.unpack("<d", data)[0]


def read_uint_len(reader: io.BytesIO) -> int:
    """Read a MySQL length-encoded integer (1, 3, 4, or 9 bytes total)."""
    i = read_uint_1(reader)

    if i == 0xFE:
        return read_uint_8(reader)

    if i == 0xFD:
        return read_uint_3(reader)

    if i == 0xFC:
        return read_uint_2(reader)

    return i
308 |
309 |
def read_str_fixed(reader: io.BytesIO, l: int) -> bytes:
    """Read exactly `l` bytes from the reader."""
    chunk = reader.read(l)
    return chunk
312 |
313 |
def read_str_null(reader: io.BytesIO) -> bytes:
    """Read bytes up to (but not including) a NUL terminator.

    Uses a bytearray accumulator to avoid quadratic bytes concatenation.

    Raises:
        EOFError: if the stream ends before a NUL byte is found (the
            original implementation looped forever in this case, since
            read(1) keeps returning b"" at EOF).
    """
    data = bytearray()
    while True:
        b = reader.read(1)
        if not b:
            raise EOFError("Unterminated null-terminated string")
        if b == b"\x00":
            return bytes(data)
        data += b
321 |
322 |
def read_str_len(reader: io.BytesIO) -> bytes:
    """Read a length-encoded string: a length-encoded int, then that many bytes."""
    length = read_uint_len(reader)
    return read_str_fixed(reader, length)
326 |
327 |
def read_str_rest(reader: io.BytesIO) -> bytes:
    """Consume and return all remaining bytes in the reader."""
    remaining = reader.read()
    return remaining
330 |
331 |
def peek(reader: io.BytesIO, num_bytes: int = 1) -> bytes:
    """Return the next `num_bytes` without advancing the stream position."""
    start = reader.tell()
    try:
        return reader.read(num_bytes)
    finally:
        reader.seek(start)
337 |
--------------------------------------------------------------------------------
/mysql_mimic/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 | import inspect
5 | import sys
6 | from collections.abc import Iterator # pylint: disable=import-error
7 | import random
8 | from typing import List, TypeVar, AsyncIterable, Iterable, AsyncIterator, cast
9 | import string
10 |
11 | from sqlglot import expressions as exp
12 | from sqlglot.optimizer.scope import traverse_scope
13 |
14 |
15 | T = TypeVar("T")
16 |
17 | # MySQL Connector/J uses ASCII to decode nonce
18 | SAFE_NONCE_CHARS = (string.ascii_letters + string.digits).encode()
19 |
20 |
class seq(Iterator):
    """Auto-incrementing sequence with an optional maximum size.

    When `size` is given, values wrap back to 0 after size - 1.
    """

    def __init__(self, size: int | None = None):
        self.size = size
        self.value = 0

    def __next__(self) -> int:
        current = self.value
        nxt = current + 1
        self.value = nxt % self.size if self.size else nxt
        return current

    def reset(self) -> None:
        """Restart the sequence from 0."""
        self.value = 0
37 |
38 |
def xor(a: bytes, b: bytes) -> bytes:
    """XOR two byte strings, truncating both to the shorter length.

    Works on whole integers rather than per-byte, which is the fast
    approach in pure Python.
    """
    n = min(len(a), len(b))
    a, b = a[:n], b[:n]
    combined = int.from_bytes(a, sys.byteorder) ^ int.from_bytes(b, sys.byteorder)
    return combined.to_bytes(n, sys.byteorder)
46 |
47 |
def nonce(nbytes: int) -> bytes:
    """Generate `nbytes` of cryptographically random, ASCII-safe bytes.

    Only letters and digits are used because MySQL Connector/J decodes the
    nonce as ASCII (see SAFE_NONCE_CHARS).
    """
    # Construct the RNG once, not once per byte as the original did.
    rng = random.SystemRandom()
    return bytes(rng.choice(SAFE_NONCE_CHARS) for _ in range(nbytes))
54 |
55 |
def find_tables(expression: exp.Expression) -> List[exp.Table]:
    """Find all tables in an expression."""
    selectable = (exp.Select, exp.Subquery, exp.Union, exp.Except, exp.Intersect)
    if not isinstance(expression, selectable):
        return []
    tables = []
    for scope in traverse_scope(expression):
        for source in scope.sources.values():
            if isinstance(source, exp.Table):
                tables.append(source)
    return tables
68 |
69 |
def find_dbs(expression: exp.Expression) -> List[str]:
    """Find all database names in an expression."""
    names = []
    for table in find_tables(expression):
        names.append(table.text("db"))
    return names
73 |
74 |
def dict_depth(d: dict) -> int:
    """Return the nesting depth of a dictionary.

    For example:
    >>> dict_depth(None)
    0
    >>> dict_depth({})
    1
    >>> dict_depth({"a": "b"})
    1
    >>> dict_depth({"a": {}})
    2
    >>> dict_depth({"a": {"b": {}}})
    3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not dict-like: no "values" attribute
        return 0
    except StopIteration:
        # Empty dict: depth 1 by convention
        return 1
    return 1 + dict_depth(first_value)
102 |
103 |
async def aiterate(iterable: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T]:
    """Iterate either an async iterable or a regular iterable.

    Uses an AsyncIterable isinstance check rather than inspect.isasyncgen:
    the latter only recognizes async *generators*, so any other
    AsyncIterable (e.g. a class implementing __aiter__) fell through to the
    synchronous branch and raised TypeError.
    """
    if isinstance(iterable, AsyncIterable):
        async for item in iterable:
            yield item
    else:
        for item in cast(Iterable, iterable):
            yield item
112 |
113 |
async def cooperative_iterate(
    iterable: AsyncIterable[T], batch_size: int = 10_000
) -> AsyncIterator[T]:
    """
    Iterate an async iterable cooperatively: yield control back to the event
    loop (via a zero-second sleep) once every `batch_size` items so other
    tasks get a chance to run.
    """
    count = 0
    async for item in iterable:
        if count and count % batch_size == 0:
            await asyncio.sleep(0)
        yield item
        count += 1
126 |
--------------------------------------------------------------------------------
/mysql_mimic/variable_processor.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from typing import Dict, Mapping, Generator, MutableMapping, Any
3 |
4 | from sqlglot import expressions as exp
5 |
6 | from mysql_mimic.intercept import value_to_expression, expression_to_value
7 |
8 | variable_constants = {
9 | "CURRENT_USER",
10 | "CURRENT_TIME",
11 | "CURRENT_TIMESTAMP",
12 | "CURRENT_DATE",
13 | }
14 |
15 |
def _get_var_assignments(expression: exp.Expression) -> Dict[str, str]:
    """Returns a dictionary of session variables to replace, as indicated by SET_VAR hints.

    Matched SET_VAR expressions (and emptied hints) are removed from the
    expression tree as a side effect.
    """
    # find_all returns a generator, which is always truthy; materialize it so
    # the emptiness check below actually short-circuits.
    hints = list(expression.find_all(exp.Hint))
    if not hints:
        return {}

    assignments = {}

    # Iterate in reverse order so higher SET_VAR hints get priority
    for hint in reversed(hints):
        set_var_hint = None

        for e in hint.expressions:
            if isinstance(e, exp.Func) and e.name == "SET_VAR":
                set_var_hint = e
                for eq in e.expressions:
                    assignments[eq.left.name] = expression_to_value(eq.right)

        if set_var_hint:
            set_var_hint.pop()

        # Remove the hint entirely if SET_VAR was the only expression
        if not hint.expressions:
            hint.pop()

    return assignments
42 |
43 |
class VariableProcessor:
    """
    Modifies a query in two ways:

    1. Applies SET_VAR hints to the session variables for the duration of the
       query, restoring the original values afterwards.
    2. Replaces functions and session parameters in the query with literal
       values taken from the `functions` mapping and the session variables.
    """

    def __init__(
        self, functions: Mapping, variables: MutableMapping, expression: exp.Expression
    ):
        self._functions = functions
        self._variables = variables
        self._expression = expression

        # Stores the original system variable values.
        self._orig: Dict[str, Any] = {}

    @contextmanager
    def set_variables(self) -> Generator[exp.Expression, None, None]:
        """Apply SET_VAR assignments, yield the rewritten expression, restore."""
        assignments = _get_var_assignments(self._expression)
        self._orig = {k: self._variables.get(k) for k in assignments}
        for k, v in assignments.items():
            self._variables[k] = v

        self._replace_variables()

        try:
            yield self._expression
        finally:
            # Restore even if query execution raises, so SET_VAR hints can't
            # leak into subsequent statements.
            for k, v in self._orig.items():
                self._variables[k] = v

    def _replace_variables(self) -> None:
        """Replaces certain functions in the query with literals provided from the mapping in _functions,
        and session parameters with the values of the session variables.
        """
        if isinstance(self._expression, exp.Set):
            for setitem in self._expression.expressions:
                if isinstance(setitem.this, exp.Binary):
                    # In the case of statements like: SET @@foo = @@bar
                    # We only want to replace variables on the right
                    setitem.this.set(
                        "expression",
                        setitem.this.expression.transform(self._transform, copy=True),
                    )
        else:
            self._expression.transform(self._transform, copy=False)

    def _transform(self, node: exp.Expression) -> exp.Expression:
        """Transform callback: swap supported functions/parameters for literals."""
        new_node = None

        if isinstance(node, exp.Func):
            if isinstance(node, exp.Anonymous):
                func_name = node.name.upper()
            else:
                func_name = node.sql_name()
            func = self._functions.get(func_name)
            if func:
                value = func()
                new_node = value_to_expression(value)
        elif isinstance(node, exp.Column) and node.sql() in variable_constants:
            # Constants like CURRENT_USER parse as columns, not functions
            value = self._functions[node.sql()]()
            new_node = value_to_expression(value)
        elif isinstance(node, exp.SessionParameter):
            value = self._variables.get(node.name)
            new_node = value_to_expression(value)

        if (
            new_node
            and isinstance(node.parent, exp.Select)
            and node.arg_key == "expressions"
        ):
            # Preserve the original text as the output column name
            new_node = exp.alias_(new_node, exp.to_identifier(node.sql()))

        return new_node or node
119 |
--------------------------------------------------------------------------------
/mysql_mimic/variables.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | import re
5 | from datetime import timezone, timedelta
6 | from functools import lru_cache
7 | from typing import Any, Callable, Tuple, Iterator, MutableMapping
8 |
9 | from mysql_mimic.charset import CharacterSet, Collation
10 | from mysql_mimic.errors import MysqlError, ErrorCode
11 |
12 |
13 | class Default: ...
14 |
15 |
16 | VariableType = Callable[[Any], Any]
17 | VariableSchema = Tuple[VariableType, Any, bool]
18 |
19 |
20 | DEFAULT = Default()
21 |
22 | SYSTEM_VARIABLES: dict[str, VariableSchema] = {
23 | # name: (type, default, dynamic)
24 | "auto_increment_increment": (int, 1, True),
25 | "autocommit": (bool, True, True),
26 | "character_set_client": (str, CharacterSet.utf8mb4.name, True),
27 | "character_set_connection": (str, CharacterSet.utf8mb4.name, True),
28 | "character_set_database": (str, CharacterSet.utf8mb4.name, True),
29 | "character_set_results": (str, CharacterSet.utf8mb4.name, True),
30 | "character_set_server": (str, CharacterSet.utf8mb4.name, True),
31 | "collation_connection": (str, Collation.utf8mb4_general_ci.name, True),
32 | "collation_database": (str, Collation.utf8mb4_general_ci.name, True),
33 | "collation_server": (str, Collation.utf8mb4_general_ci.name, True),
34 | "external_user": (str, "", False),
35 | "init_connect": (str, "", True),
36 | "interactive_timeout": (int, 28800, True),
37 | "license": (str, "MIT", False),
38 | "lower_case_table_names": (int, 0, True),
39 | "max_allowed_packet": (int, 67108864, True),
40 | "max_execution_time": (int, 0, True),
41 | "net_buffer_length": (int, 16384, True),
42 | "net_write_timeout": (int, 28800, True),
43 | "performance_schema": (bool, False, False),
44 | "sql_auto_is_null": (bool, False, True),
45 | "sql_mode": (str, "ANSI", True),
46 | "sql_select_limit": (int, None, True),
47 | "system_time_zone": (str, "UTC", False),
48 | "time_zone": (str, "UTC", True),
49 | "transaction_read_only": (bool, False, True),
50 | "transaction_isolation": (str, "READ-COMMITTED", True),
51 | "version": (str, "8.0.29", False),
52 | "version_comment": (str, "mysql-mimic", False),
53 | "wait_timeout": (int, 28800, True),
54 | "event_scheduler": (str, "OFF", True),
55 | "default_storage_engine": (str, "mysql-mimic", True),
56 | "default_tmp_storage_engine": (str, "mysql-mimic", True),
57 | }
58 |
59 |
class Variables(abc.ABC, MutableMapping[str, Any]):
    """
    Abstract class for MySQL system variables.

    Implements the MutableMapping interface over a schema of
    (type, default, dynamic) tuples supplied by subclasses.
    """

    def __init__(self) -> None:
        # Current variable values; unset variables fall back to schema defaults
        self._values: dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any | None:
        try:
            return self.get_variable(key)
        except MysqlError as e:
            # Include the key so mapping users get a meaningful KeyError
            raise KeyError(key) from e

    def __setitem__(self, key: str, value: Any) -> None:
        return self.set(key, value)

    def __delitem__(self, key: str) -> None:
        raise MysqlError(f"Cannot delete session variable {key}.")

    def __iter__(self) -> Iterator[str]:
        return self.schema.__iter__()

    def __len__(self) -> int:
        return len(self.schema)

    def get_schema(self, name: str) -> VariableSchema:
        """Look up the (type, default, dynamic) schema for a variable.

        Raises:
            MysqlError: if the variable is unknown.
        """
        schema = self.schema.get(name)
        if not schema:
            raise MysqlError(
                f"Unknown variable: {name}", code=ErrorCode.UNKNOWN_SYSTEM_VARIABLE
            )
        return schema

    def set(self, name: str, value: Any, force: bool = False) -> None:
        """Set a variable, coercing the value to the schema type.

        Args:
            name: variable name (case-insensitive).
            value: new value; DEFAULT or None resets to the schema default.
            force: allow setting non-dynamic (read-only) variables.

        Raises:
            MysqlError: if the variable is not dynamic and force is False.
        """
        name = name.lower()
        type_, default, dynamic = self.get_schema(name)

        if not dynamic and not force:
            raise MysqlError(
                f"Variable is not dynamic: {name}", code=ErrorCode.PARSE_ERROR
            )

        if value is DEFAULT or value is None:
            self._values[name] = default
        else:
            self._values[name] = type_(value)

    def get_variable(self, name: str) -> Any | None:
        """Get a variable's current value, falling back to the schema default."""
        name = name.lower()
        if name in self._values:
            return self._values[name]
        _, default, _ = self.get_schema(name)

        return default

    def list(self) -> list[tuple[str, Any]]:
        """Return sorted (name, value) pairs for every known variable."""
        return [(name, self.get(name)) for name in sorted(self.schema)]

    @property
    @abc.abstractmethod
    def schema(self) -> dict[str, VariableSchema]: ...
123 |
124 |
class GlobalVariables(Variables):
    """Server-scoped variables; the schema defaults to SYSTEM_VARIABLES."""

    def __init__(self, schema: dict[str, VariableSchema] | None = None):
        # Fall back to the built-in schema when none (or an empty one) is given
        self._schema = schema or SYSTEM_VARIABLES
        super().__init__()

    @property
    def schema(self) -> dict[str, VariableSchema]:
        return self._schema
133 |
134 |
class SessionVariables(Variables):
    """Session-scoped variables sharing the global variables' schema."""

    def __init__(self, global_variables: Variables):
        self.global_variables = global_variables
        super().__init__()

    @property
    def schema(self) -> dict[str, VariableSchema]:
        return self.global_variables.schema
143 |
144 |
145 | RE_TIMEZONE = re.compile(r"^(?P[+-])(?P\d\d):(?P\d\d)")
146 |
147 |
@lru_cache(maxsize=48)
def parse_timezone(tz: str) -> timezone:
    """Parse 'UTC' or a signed HH:MM offset (e.g. '+05:30') into a tzinfo.

    Results are cached, since sessions use a small set of timezones.

    Raises:
        MysqlError: if the string is not a recognized timezone.
    """
    if tz.lower() == "utc":
        return timezone.utc
    match = RE_TIMEZONE.match(tz)
    if match is None:
        raise MysqlError(msg=f"Invalid timezone: {tz}")
    delta = timedelta(
        hours=int(match.group("hours")), minutes=int(match.group("minutes"))
    )
    if match.group("sign") == "-":
        delta = -delta
    return timezone(delta)
161 |
--------------------------------------------------------------------------------
/mysql_mimic/version.py:
--------------------------------------------------------------------------------
1 | """mysql-mimic version information"""
2 |
3 | __version__ = "2.6.4"
4 |
5 |
def main(name: str) -> None:
    """Print the package version when the module is executed as a script."""
    if name != "__main__":
        return
    print(__version__)
9 |
10 |
11 | main(__name__)
12 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.black]
6 | target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
7 |
8 | #[tool.pytest.ini_options]
9 | #log_cli = true
10 | #log_cli_level = "DEBUG"
11 |
12 | [tool.mypy]
13 | ignore_missing_imports = true
14 | disallow_untyped_defs = true
15 | disallow_incomplete_defs = true
16 | exclude = ['build']
17 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

# Import __version__ without importing the (not-yet-installed) package.
# Context managers ensure the files are closed (the originals leaked handles).
with open("mysql_mimic/version.py", encoding="utf-8") as f:
    exec(f.read())

with open("README.md", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="mysql-mimic",
    version=__version__,  # noqa: F821  # defined by the exec above
    description="A python implementation of the mysql server protocol",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/kelsin/mysql-mimic",
    author="Christopher Giroir",
    author_email="kelsin@valefor.com",
    license="MIT",
    packages=find_packages(include=["mysql_mimic", "mysql_mimic.*"]),
    python_requires=">=3.6",
    install_requires=["sqlglot"],
    extras_require={
        "dev": [
            "aiomysql",
            "mypy",
            "mysql-connector-python",
            "black",
            "coverage",
            "freezegun",
            "gssapi",
            "k5test",
            "pylint",
            "pytest",
            "pytest-asyncio",
            "sphinx",
            "sqlalchemy",
            "twine",
            "wheel",
        ],
        "krb5": ["gssapi"],
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "Operating System :: OS Independent",
        "Programming Language :: SQL",
        "Programming Language :: Python :: 3 :: Only",
    ],
)
48 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/barakalon/mysql-mimic/570cd6cb862a52fabbbdf15d801f4855d530c11d/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import asyncio
3 | import functools
4 | from contextvars import Context, copy_context
5 | from ssl import SSLContext
6 | from typing import (
7 | Optional,
8 | List,
9 | Dict,
10 | Any,
11 | Callable,
12 | TypeVar,
13 | Awaitable,
14 | Sequence,
15 | AsyncGenerator,
16 | Type,
17 | )
18 |
19 | import aiomysql
20 | import mysql.connector
21 | from mysql.connector.abstracts import MySQLConnectionAbstract
22 | from mysql.connector.connection import (
23 | MySQLCursorPrepared,
24 | MySQLCursorDict,
25 | )
26 | from mysql.connector.cursor import MySQLCursor
27 | from sqlglot import expressions as exp
28 | from sqlglot.executor import execute
29 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
30 | import pytest
31 | import pytest_asyncio
32 |
33 | from mysql_mimic import MysqlServer, Session
34 | from mysql_mimic.auth import (
35 | User,
36 | AuthPlugin,
37 | IdentityProvider,
38 | )
39 | from mysql_mimic.connection import Connection
40 | from mysql_mimic.results import AllowedResult
41 | from mysql_mimic.schema import InfoSchema
42 |
43 |
class PreparedDictCursor(MySQLCursorPrepared):
    """Prepared-statement cursor that returns rows as dicts instead of tuples."""

    def fetchall(self) -> Any:
        rows = super().fetchall()
        if rows is None:
            return None
        return [dict(zip(self.column_names, row)) for row in rows]
52 |
53 |
class MockSession(Session):
    """Test double for Session with configurable query behavior.

    Modes (toggled by tests):
      - echo: return the raw SQL text as a single-row result
      - execute: actually evaluate the expression with sqlglot's executor
      - otherwise: return `return_value` unchanged
    """

    def __init__(self) -> None:
        super().__init__()
        self.ctx: Context | None = None
        self.return_value: Any = None
        self.execute = False
        self.echo = False
        self.last_query_attrs: Optional[Dict[str, str]] = None
        self.users: Optional[Dict[str, User]] = None

        # tests can set an Event that the session will use to wait before continuing
        self.pause: Optional[asyncio.Event] = None
        # the session will set this event while waiting on `pause`
        self.waiting = asyncio.Event()

    async def init(self, connection: Connection) -> None:
        # Capture the context so tests can inspect context variables set during init
        await super().init(connection)
        self.ctx = copy_context()

    async def query(
        self, expression: exp.Expression, sql: str, attrs: Dict[str, str]
    ) -> AllowedResult:
        """Handle a query according to the configured mode (see class docstring)."""
        if self.pause:
            self.waiting.set()
            await self.pause.wait()
            self.waiting.clear()

        self.last_query_attrs = attrs
        if self.echo:
            return [(sql,)], ["sql"]
        if self.execute:
            result = execute(expression)
            return result.rows, result.columns
        return self.return_value

    async def schema(self) -> dict | InfoSchema:
        """Static two-table schema used by schema-introspection tests."""
        return {
            "db": {
                "x": {
                    "a": "TEXT",
                    "b": "TEXT",
                },
                "y": {
                    "b": "TEXT",
                    "c": "TEXT",
                },
            }
        }
102 |
103 |
class MockIdentityProvider(IdentityProvider):
    """Identity provider backed by in-memory plugin and user collections."""

    def __init__(self, auth_plugins: List[AuthPlugin], users: Dict[str, User]):
        self.auth_plugins = auth_plugins
        self.users = users

    def get_plugins(self) -> Sequence[AuthPlugin]:
        """Return the configured auth plugins."""
        return self.auth_plugins

    async def get_user(self, username: str) -> Optional[User]:
        """Look up a user by name; None if unknown."""
        return self.users.get(username)
114 |
115 |
116 | T = TypeVar("T")
117 |
118 |
async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Run a blocking callable in the default executor and await its result."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
123 |
124 |
@pytest_asyncio.fixture
async def session() -> MockSession:
    """A fresh MockSession per test; shared via the session_factory fixture."""
    return MockSession()
128 |
129 |
@pytest.fixture
def session_factory(session: MockSession) -> Callable[[], MockSession]:
    """Factory that always returns the single per-test MockSession."""
    return lambda: session
133 |
134 |
@pytest.fixture
def auth_plugins() -> Optional[List[AuthPlugin]]:
    """Auth plugins to install; tests override this fixture to enable auth."""
    return None
138 |
139 |
@pytest.fixture
def users() -> Dict[str, User]:
    """Known users keyed by username; tests override this to add accounts."""
    return {}
143 |
144 |
@pytest.fixture
def identity_provider(
    auth_plugins: Optional[List[AuthPlugin]], users: Dict[str, User]
) -> Optional[MockIdentityProvider]:
    """Build an identity provider only when auth plugins are configured."""
    if auth_plugins:
        return MockIdentityProvider(auth_plugins, users)
    return None
152 |
153 |
@pytest.fixture
def ssl() -> Optional[SSLContext]:
    """SSL context for the server; tests override this to enable TLS."""
    return None
157 |
158 |
@pytest_asyncio.fixture
async def server(
    session_factory: Callable[[], MockSession],
    identity_provider: Optional[MockIdentityProvider],
    ssl: Optional[SSLContext],
) -> AsyncGenerator[MysqlServer, None]:
    """Start a MysqlServer on port 3307 (or an ephemeral port on conflict)."""
    # Local import keeps the module-level import block unchanged.
    import errno

    srv = MysqlServer(
        session_factory=session_factory,
        identity_provider=identity_provider,
        ssl=ssl,
    )
    try:
        await srv.start_server(port=3307)
    except OSError as e:
        # The original compared against the literal 48, which is EADDRINUSE
        # only on macOS/BSD (Linux uses 98); use the symbolic constant.
        if e.errno == errno.EADDRINUSE:
            # Port already in use.
            # This should only happen if there is a race condition between concurrent test runs.
            # Getting a new free port can be slow, so we optimistically try to use the free_port fixture first.
            await srv.start_server(port=0)
        else:
            raise
    asyncio.create_task(srv.serve_forever())
    try:
        yield srv
    finally:
        srv.close()
        await srv.wait_closed()
186 |
187 |
@pytest.fixture
def port(server: MysqlServer) -> int:
    """The actual TCP port the test server is listening on."""
    return server.sockets()[0].getsockname()[1]
191 |
192 |
193 | ConnectFixture = Callable[..., Awaitable[MySQLConnectionAbstract]]
194 |
195 |
@pytest.fixture
def connect(port: int) -> ConnectFixture:
    """Async factory for pure-python mysql-connector connections to the test server."""

    async def conn(**kwargs: Any) -> MySQLConnectionAbstract:
        # use_pure avoids the C extension so the protocol path under test is consistent
        return await to_thread(
            mysql.connector.connect, use_pure=True, port=port, **kwargs  # type: ignore
        )

    return conn
204 |
205 |
@pytest_asyncio.fixture
async def mysql_connector_conn(
    connect: ConnectFixture,
) -> AsyncGenerator[MySQLConnectionAbstract, None]:
    """A connected mysql-connector client, closed on teardown."""
    conn = await connect(user="levon_helm")
    # Different versions of mysql-connector use various default collations.
    # Execute the following query to ensure consistent test results across different versions.
    await query(conn, "SET NAMES 'utf8mb4' COLLATE 'utf8mb4_general_ci'")
    try:
        yield conn
    finally:
        conn.close()
218 |
219 |
@pytest_asyncio.fixture
async def aiomysql_conn(port: int) -> aiomysql.Connection:
    """An aiomysql connection to the test server, closed by the context manager."""
    async with aiomysql.connect(port=port, user="levon_helm") as conn:
        yield conn
224 |
225 |
@pytest_asyncio.fixture
async def sqlalchemy_engine(
    port: int,
) -> AsyncGenerator[AsyncEngine, None]:
    """An async SQLAlchemy engine pointed at the test server, disposed on teardown."""
    engine = create_async_engine(url=f"mysql+aiomysql://levon_helm@127.0.0.1:{port}")
    try:
        yield engine
    finally:
        await engine.dispose()
235 |
236 |
async def query(
    conn: MySQLConnectionAbstract,
    sql: str,
    cursor_class: Type[MySQLCursor] = MySQLCursorDict,
    params: Sequence[Any] | None = None,
    query_attributes: Dict[str, str] | None = None,
) -> Sequence[Any]:
    """Execute `sql` on `conn` in a worker thread and return all fetched rows.

    Args:
        conn: an open mysql.connector connection.
        sql: statement to execute.
        cursor_class: cursor type; defaults to dict rows.
        params: optional statement parameters.
        query_attributes: optional query attributes to attach to the cursor.

    Fix: the cursor was leaked if execute/fetchall raised; close it in a
    `finally` so every code path releases it.
    """
    cursor = await to_thread(conn.cursor, cursor_class=cursor_class)
    try:
        if query_attributes:
            for key, value in query_attributes.items():
                cursor.add_attribute(key, value)
        # Pass `params` positionally only when non-empty, matching the
        # original execute(sql) vs execute(sql, params) call shapes.
        if params:
            await to_thread(cursor.execute, sql, params)
        else:
            await to_thread(cursor.execute, sql)
        return await to_thread(cursor.fetchall)
    finally:
        await to_thread(cursor.close)
252 |
--------------------------------------------------------------------------------
/tests/fixtures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/barakalon/mysql-mimic/570cd6cb862a52fabbbdf15d801f4855d530c11d/tests/fixtures/__init__.py
--------------------------------------------------------------------------------
/tests/fixtures/certificate.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIEIzCCAwugAwIBAgIUUMROPFb3ZMssWfl9/tJVMB0Zl0owDQYJKoZIhvcNAQEL
3 | BQAwgZ8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQH
4 | DA1TYW4gRnJhbmNpc2NvMRQwEgYDVQQKDAtteXNxbC1taW1pYzEUMBIGA1UECwwL
5 | bXlzcWwtbWltaWMxFDASBgNVBAMMC215c3FsLW1pbWljMSEwHwYJKoZIhvcNAQkB
6 | FhJleGFtcGxlQGRvbWFpbi5jb20wIBcNMjIxMDI0MTgyNzQ0WhgPMjEyMjA5MzAx
7 | ODI3NDRaMIGfMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQG
8 | A1UEBwwNU2FuIEZyYW5jaXNjbzEUMBIGA1UECgwLbXlzcWwtbWltaWMxFDASBgNV
9 | BAsMC215c3FsLW1pbWljMRQwEgYDVQQDDAtteXNxbC1taW1pYzEhMB8GCSqGSIb3
10 | DQEJARYSZXhhbXBsZUBkb21haW4uY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
11 | MIIBCgKCAQEAsnmS0k1Q0f+mv7izbT5K8zqYP54KwyW000qMhYLW/Ir5fGAw2QHg
12 | Q1vMBY+1oaHcO1DxjriG4l8P8ho9AShAD63Q3PVUiy3Prxi3PaZ/jPsI5rN/8s7s
13 | TXai6Po6gD56uYtZACl4cjF2ob3Vy/qzIPitW3D7UVEL+nqDEZUSmbFnT+NAMaCv
14 | Bq+Zf94vQSpQXgUSbTNAuFqjwMLeb8VX31e5yFOvhRd1Y65MOwmeCX0hMZ+XtcPc
15 | xtgCyQ9uT2OaKcqcngE/LtGnC4UJy0u5bcI8pPgwenGWNhQ6LrqLWAUizEXoprY8
16 | PUVoE42Itm7j1ODKm71OpNPqEjXxvs924wIDAQABo1MwUTAdBgNVHQ4EFgQUSVIG
17 | PQ1FTdvJeup85cCLpp/5wGIwHwYDVR0jBBgwFoAUSVIGPQ1FTdvJeup85cCLpp/5
18 | wGIwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAJcoALjwItgWa
19 | Qpr/hJCHhtf1gS6DCbtQoc+Owvz0iJkC+zqeaFlOYwVQz/n8iQ45g03cLMxHfMEv
20 | lj1ctHfVGI+2exp2hesMmU+CT49D9qFS2vQnaKxkeUqv5ia2d+V49jE6fj3CxIPP
21 | T3yI9+K2Ojx+7/WVVxCgz+eIZJbTCugoypDksldfuY4mx48E9qefOlBNcYyBzHt4
22 | 4iymCFFrfQHfMX+PnKDGHaoh/T/oZxdZRyDxNqTLfiyZ1PtAbueeT4Jf6CmamQZh
23 | IadFPXnA1YlEqreMk6nDrF8pqgOosHIngLhhUXHAVj/Br3UaDTaUGzlrlL7rLpRL
24 | oToCkplUgg==
25 | -----END CERTIFICATE-----
26 |
--------------------------------------------------------------------------------
/tests/fixtures/key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN PRIVATE KEY-----
2 | MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCyeZLSTVDR/6a/
3 | uLNtPkrzOpg/ngrDJbTTSoyFgtb8ivl8YDDZAeBDW8wFj7Whodw7UPGOuIbiXw/y
4 | Gj0BKEAPrdDc9VSLLc+vGLc9pn+M+wjms3/yzuxNdqLo+jqAPnq5i1kAKXhyMXah
5 | vdXL+rMg+K1bcPtRUQv6eoMRlRKZsWdP40AxoK8Gr5l/3i9BKlBeBRJtM0C4WqPA
6 | wt5vxVffV7nIU6+FF3Vjrkw7CZ4JfSExn5e1w9zG2ALJD25PY5opypyeAT8u0acL
7 | hQnLS7ltwjyk+DB6cZY2FDouuotYBSLMReimtjw9RWgTjYi2buPU4MqbvU6k0+oS
8 | NfG+z3bjAgMBAAECggEBAI1wPUO+k/soUByGImPDxyAE4p0gAUVv/2KnJL+11exj
9 | sp23mV6Q1wpqmEAcCIQkQuUbG6PQZszFK1zhIFFndYU3aVuCbNKzpnAL9UO9TD4M
10 | v5wcypxBEhG9oBNkIrJ5UUbzwL+ZHePZgTtitykk74qEqNXbrr9drFF/f5mSeyAi
11 | nUkizI/aGJVp2+NVpsFvZd0M5gM1tEWoBq/5fny/b1oWM/blh1ZwWQII3bTcDABH
12 | 1h30oTuGWYyeQExUjiS4SDDJ8BH9PZxUGQhQsK1ZuJczJ5B+jh0OWicdD+P+V2EB
13 | npoWQVHKwqwfDQisx2T04TAb65QcbM7V/EwgQPpcd5ECgYEA3bQbff5bIvbgJwCo
14 | 1bHxj3MMmu1Qs+Km9nR8FbEtFaF1lbXzPLCj98VDpT3tBNw4VwpbVYn5rd4ZSs6h
15 | jlK2zEakFRvU/qgtTrUviW4oeSrEtUt8jI0QkdobUWPM8HuZih6Lgm5DTU6HwIBR
16 | dxkhojx7c1zwSm66Bojvn/CjVakCgYEAzhWHAxkNAj97XvjUe9XsV41jiksheNS2
17 | 7vAm1PASq9MGFpSazq2NeXTxuChvMK8vN9r2FngnFKW4hLGMi8wpReT+S1C32sNP
18 | ZWRq6w2JWgQvh7dNSNmn0mBc7zwUB6sFeS72q3Ge355dPJiuomRLxZmm+P3suHIV
19 | b+4Fj6q0p6sCgYEAqj5EojJwn1+97pVGEJqM6N+qvUkgoJGaLkRyiGG+Qg7y8RyA
20 | BImLz5ZuBHSSDhphNQ1h50SFMusKtvQHAPgpIKHaG898dnSEHh1pvHmXoLujw6eM
21 | o40rPSSjt5MQa1YuJ+6eqHCtQ67a9YpThEYLGr6g+YxThISUWrJKd6HceskCgYB0
22 | gWMUc0MRdEYQyOeHIsc8L+iINDU2FDtfFVE+rIJBtUkJ1vU1xpPmiCBnFiTWBxPQ
23 | pe7dgQvG9nE8QwvLtJ3Yr767YWSvPh9SmNSBEeQGibs9JHmCp9niayve673/H8Y2
24 | XkCBZ/iDPwpCyaZglAbqLRViSltbYtOPtaZbNAxxhQKBgH8Gj0yg3A5DwWEyeUKR
25 | KJBS2rPiKgrhJs8Kd8GZZUb3H5WGqzfgrRK1p9j5Ug8UXSWbB9e5jg4ymVtcAAhc
26 | tRz4rCvYNY8fHlA2TfzrOeEuuxtoMFyxd26eFjZvS/w2VdFrIQZbSvDkV4hMqhpS
27 | CCslSFIIZOMCzqgQtM+NTQJ6
28 | -----END PRIVATE KEY-----
29 |
--------------------------------------------------------------------------------
/tests/fixtures/queries.py:
--------------------------------------------------------------------------------
1 | DATA_GRIP_PARAMETERS = """
2 | SELECT
3 | SPECIFIC_SCHEMA AS PROCEDURE_CAT,
4 | NULL AS `PROCEDURE_SCHEM`,
5 | SPECIFIC_NAME AS `PROCEDURE_NAME`,
6 | IFNULL(PARAMETER_NAME, '') AS `COLUMN_NAME`,
7 | CASE
8 | WHEN PARAMETER_MODE = 'IN'
9 | THEN 1
10 | WHEN PARAMETER_MODE = 'OUT'
11 | THEN 4
12 | WHEN PARAMETER_MODE = 'INOUT'
13 | THEN 2
14 | WHEN ORDINAL_POSITION = 0
15 | THEN 5
16 | ELSE 0
17 | END AS `COLUMN_TYPE`,
18 | CASE
19 | WHEN UPPER(DATA_TYPE) = 'DECIMAL'
20 | THEN 3
21 | WHEN UPPER(DATA_TYPE) = 'DECIMAL UNSIGNED'
22 | THEN 3
23 | WHEN UPPER(DATA_TYPE) = 'TINYINT'
24 | THEN -6
25 | WHEN UPPER(DATA_TYPE) = 'TINYINT UNSIGNED'
26 | THEN -6
27 | WHEN UPPER(DATA_TYPE) = 'BOOLEAN'
28 | THEN 16
29 | WHEN UPPER(DATA_TYPE) = 'SMALLINT'
30 | THEN 5
31 | WHEN UPPER(DATA_TYPE) = 'SMALLINT UNSIGNED'
32 | THEN 5
33 | WHEN UPPER(DATA_TYPE) = 'INT'
34 | THEN 4
35 | WHEN UPPER(DATA_TYPE) = 'INT UNSIGNED'
36 | THEN 4
37 | WHEN UPPER(DATA_TYPE) = 'FLOAT'
38 | THEN 7
39 | WHEN UPPER(DATA_TYPE) = 'FLOAT UNSIGNED'
40 | THEN 7
41 | WHEN UPPER(DATA_TYPE) = 'DOUBLE'
42 | THEN 8
43 | WHEN UPPER(DATA_TYPE) = 'DOUBLE UNSIGNED'
44 | THEN 8
45 | WHEN UPPER(DATA_TYPE) = 'NULL'
46 | THEN 0
47 | WHEN UPPER(DATA_TYPE) = 'TIMESTAMP'
48 | THEN 93
49 | WHEN UPPER(DATA_TYPE) = 'BIGINT'
50 | THEN -5
51 | WHEN UPPER(DATA_TYPE) = 'BIGINT UNSIGNED'
52 | THEN -5
53 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT'
54 | THEN 4
55 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT UNSIGNED'
56 | THEN 4
57 | WHEN UPPER(DATA_TYPE) = 'DATE'
58 | THEN 91
59 | WHEN UPPER(DATA_TYPE) = 'TIME'
60 | THEN 92
61 | WHEN UPPER(DATA_TYPE) = 'DATETIME'
62 | THEN 93
63 | WHEN UPPER(DATA_TYPE) = 'YEAR'
64 | THEN 91
65 | WHEN UPPER(DATA_TYPE) = 'VARCHAR'
66 | THEN 12
67 | WHEN UPPER(DATA_TYPE) = 'VARBINARY'
68 | THEN -3
69 | WHEN UPPER(DATA_TYPE) = 'BIT'
70 | THEN -7
71 | WHEN UPPER(DATA_TYPE) = 'JSON'
72 | THEN -1
73 | WHEN UPPER(DATA_TYPE) = 'ENUM'
74 | THEN 1
75 | WHEN UPPER(DATA_TYPE) = 'SET'
76 | THEN 1
77 | WHEN UPPER(DATA_TYPE) = 'TINYBLOB'
78 | THEN -3
79 | WHEN UPPER(DATA_TYPE) = 'TINYTEXT'
80 | THEN 12
81 | WHEN UPPER(DATA_TYPE) = 'MEDIUMBLOB'
82 | THEN -4
83 | WHEN UPPER(DATA_TYPE) = 'MEDIUMTEXT'
84 | THEN -1
85 | WHEN UPPER(DATA_TYPE) = 'LONGBLOB'
86 | THEN -4
87 | WHEN UPPER(DATA_TYPE) = 'LONGTEXT'
88 | THEN -1
89 | WHEN UPPER(DATA_TYPE) = 'BLOB'
90 | THEN -4
91 | WHEN UPPER(DATA_TYPE) = 'TEXT'
92 | THEN -1
93 | WHEN UPPER(DATA_TYPE) = 'CHAR'
94 | THEN 1
95 | WHEN UPPER(DATA_TYPE) = 'BINARY'
96 | THEN -2
97 | WHEN UPPER(DATA_TYPE) = 'GEOMETRY'
98 | THEN -2
99 | WHEN UPPER(DATA_TYPE) = 'UNKNOWN'
100 | THEN 1111
101 | WHEN UPPER(DATA_TYPE) = 'POINT'
102 | THEN -2
103 | WHEN UPPER(DATA_TYPE) = 'LINESTRING'
104 | THEN -2
105 | WHEN UPPER(DATA_TYPE) = 'POLYGON'
106 | THEN -2
107 | WHEN UPPER(DATA_TYPE) = 'MULTIPOINT'
108 | THEN -2
109 | WHEN UPPER(DATA_TYPE) = 'MULTILINESTRING'
110 | THEN -2
111 | WHEN UPPER(DATA_TYPE) = 'MULTIPOLYGON'
112 | THEN -2
113 | WHEN UPPER(DATA_TYPE) = 'GEOMETRYCOLLECTION'
114 | THEN -2
115 | WHEN UPPER(DATA_TYPE) = 'GEOMCOLLECTION'
116 | THEN -2
117 | ELSE 1111
118 | END AS `DATA_TYPE`,
119 | UPPER(
120 | CASE
121 | WHEN LOCATE('UNSIGNED', UPPER(DTD_IDENTIFIER)) <> 0
122 | AND LOCATE('UNSIGNED', UPPER(DATA_TYPE)) = 0
123 | AND LOCATE('SET', UPPER(DATA_TYPE)) <> 1
124 | AND LOCATE('ENUM', UPPER(DATA_TYPE)) <> 1
125 | THEN CONCAT(DATA_TYPE, ' UNSIGNED')
126 | WHEN UPPER(DATA_TYPE) = 'POINT'
127 | THEN 'GEOMETRY'
128 | WHEN UPPER(DATA_TYPE) = 'LINESTRING'
129 | THEN 'GEOMETRY'
130 | WHEN UPPER(DATA_TYPE) = 'POLYGON'
131 | THEN 'GEOMETRY'
132 | WHEN UPPER(DATA_TYPE) = 'MULTIPOINT'
133 | THEN 'GEOMETRY'
134 | WHEN UPPER(DATA_TYPE) = 'MULTILINESTRING'
135 | THEN 'GEOMETRY'
136 | WHEN UPPER(DATA_TYPE) = 'MULTIPOLYGON'
137 | THEN 'GEOMETRY'
138 | WHEN UPPER(DATA_TYPE) = 'GEOMETRYCOLLECTION'
139 | THEN 'GEOMETRY'
140 | WHEN UPPER(DATA_TYPE) = 'GEOMCOLLECTION'
141 | THEN 'GEOMETRY'
142 | ELSE UPPER(DATA_TYPE)
143 | END
144 | ) AS TYPE_NAME,
145 | CASE
146 | WHEN LCASE(DATA_TYPE) = 'date'
147 | THEN 0
148 | WHEN LCASE(DATA_TYPE) = 'time'
149 | OR LCASE(DATA_TYPE) = 'datetime'
150 | OR LCASE(DATA_TYPE) = 'timestamp'
151 | THEN DATETIME_PRECISION
152 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT' AND LOCATE('UNSIGNED', UPPER(DTD_IDENTIFIER)) <> 0
153 | THEN 8
154 | WHEN UPPER(DATA_TYPE) = 'JSON'
155 | THEN 1073741824
156 | ELSE NUMERIC_PRECISION
157 | END AS `PRECISION`,
158 | CASE
159 | WHEN LCASE(DATA_TYPE) = 'date'
160 | THEN 10
161 | WHEN LCASE(DATA_TYPE) = 'time'
162 | THEN 8 + (
163 | CASE
164 | WHEN DATETIME_PRECISION > 0
165 | THEN DATETIME_PRECISION + 1
166 | ELSE DATETIME_PRECISION
167 | END
168 | )
169 | WHEN LCASE(DATA_TYPE) = 'datetime' OR LCASE(DATA_TYPE) = 'timestamp'
170 | THEN 19 + (
171 | CASE
172 | WHEN DATETIME_PRECISION > 0
173 | THEN DATETIME_PRECISION + 1
174 | ELSE DATETIME_PRECISION
175 | END
176 | )
177 | WHEN UPPER(DATA_TYPE) = 'MEDIUMINT' AND LOCATE('UNSIGNED', UPPER(DTD_IDENTIFIER)) <> 0
178 | THEN 8
179 | WHEN UPPER(DATA_TYPE) = 'JSON'
180 | THEN 1073741824
181 | WHEN CHARACTER_MAXIMUM_LENGTH IS NULL
182 | THEN NUMERIC_PRECISION
183 | WHEN CHARACTER_MAXIMUM_LENGTH > 2147483647
184 | THEN 2147483647
185 | ELSE CHARACTER_MAXIMUM_LENGTH
186 | END AS LENGTH,
187 | NUMERIC_SCALE AS `SCALE`,
188 | 10 AS RADIX,
189 | 1 AS `NULLABLE`,
190 | NULL AS `REMARKS`,
191 | NULL AS `COLUMN_DEF`,
192 | NULL AS `SQL_DATA_TYPE`,
193 | NULL AS `SQL_DATETIME_SUB`,
194 | CASE
195 | WHEN CHARACTER_OCTET_LENGTH > 2147483647
196 | THEN 2147483647
197 | ELSE CHARACTER_OCTET_LENGTH
198 | END AS `CHAR_OCTET_LENGTH`,
199 | ORDINAL_POSITION,
200 | 'YES' AS `IS_NULLABLE`,
201 | SPECIFIC_NAME
202 | FROM INFORMATION_SCHEMA.PARAMETERS
203 | WHERE
204 | SPECIFIC_SCHEMA = 'magic' AND SPECIFIC_NAME LIKE '%'
205 | ORDER BY
206 | SPECIFIC_SCHEMA,
207 | SPECIFIC_NAME,
208 | ROUTINE_TYPE,
209 | ORDINAL_POSITION
210 | """
211 |
# Session-variable probe sent on connect by the mysql-connector-java JDBC
# driver (as bundled by DataGrip); fetches a fixed set of @@variables in one
# round trip. Captured verbatim — do not reformat the SQL.
DATA_GRIP_VARIABLES = """
/* mysql-connector-java-8.0.25 (Revision: 08be9e9b4cba6aa115f9b27b215887af40b159e0) */SELECT
@@session.auto_increment_increment AS auto_increment_increment,
@@character_set_client AS character_set_client,
@@character_set_connection AS character_set_connection,
@@character_set_results AS character_set_results,
@@character_set_server AS character_set_server,
@@collation_server AS collation_server,
@@collation_connection AS collation_connection,
@@init_connect AS init_connect,
@@interactive_timeout AS interactive_timeout,
@@license AS license,
@@lower_case_table_names AS lower_case_table_names,
@@max_allowed_packet AS max_allowed_packet,
@@net_write_timeout AS net_write_timeout,
@@performance_schema AS performance_schema,
@@sql_mode AS sql_mode,
@@system_time_zone AS system_time_zone,
@@time_zone AS time_zone,
@@transaction_isolation AS transaction_isolation,
@@wait_timeout AS wait_timeout
"""
234 |
235 | DATA_GRIP_TABLES = """
236 | /* ApplicationName=DataGrip 2022.2.5 */
237 | SELECT
238 | TABLE_SCHEMA AS TABLE_CAT,
239 | NULL AS TABLE_SCHEM,
240 | TABLE_NAME,
241 | CASE
242 | WHEN TABLE_TYPE = 'BASE TABLE'
243 | THEN CASE
244 | WHEN TABLE_SCHEMA = 'mysql' OR TABLE_SCHEMA = 'performance_schema'
245 | THEN 'SYSTEM TABLE'
246 | ELSE 'TABLE'
247 | END
248 | WHEN TABLE_TYPE = 'TEMPORARY'
249 | THEN 'LOCAL_TEMPORARY'
250 | ELSE TABLE_TYPE
251 | END AS TABLE_TYPE,
252 | TABLE_COMMENT AS REMARKS,
253 | NULL AS TYPE_CAT,
254 | NULL AS TYPE_SCHEM,
255 | NULL AS TYPE_NAME,
256 | NULL AS SELF_REFERENCING_COL_NAME,
257 | NULL AS REF_GENERATION
258 | FROM INFORMATION_SCHEMA.TABLES
259 | WHERE
260 | TABLE_SCHEMA = 'db' AND TABLE_NAME LIKE '%'
261 | HAVING
262 | TABLE_TYPE IN ('LOCAL TEMPORARY', 'SYSTEM TABLE', 'SYSTEM VIEW', 'TABLE', 'VIEW')
263 | ORDER BY
264 | TABLE_TYPE,
265 | TABLE_SCHEMA,
266 | TABLE_NAME
267 | """
268 |
269 | TABLEAU_INDEXES_1 = """
270 | SELECT
271 | A.REFERENCED_TABLE_SCHEMA AS PKTABLE_CAT,
272 | NULL AS PKTABLE_SCHEM,
273 | A.REFERENCED_TABLE_NAME AS PKTABLE_NAME,
274 | A.REFERENCED_COLUMN_NAME AS PKCOLUMN_NAME,
275 | A.TABLE_SCHEMA AS FKTABLE_CAT,
276 | NULL AS FKTABLE_SCHEM,
277 | A.TABLE_NAME AS FKTABLE_NAME,
278 | A.COLUMN_NAME AS FKCOLUMN_NAME,
279 | A.ORDINAL_POSITION AS KEY_SEQ,
280 | CASE
281 | WHEN R.UPDATE_RULE = 'CASCADE' THEN 0
282 | WHEN R.UPDATE_RULE = 'SET NULL' THEN 2
283 | WHEN R.UPDATE_RULE = 'SET DEFAULT' THEN 4
284 | WHEN R.UPDATE_RULE = 'SET RESTRICT' THEN 1
285 | WHEN R.UPDATE_RULE = 'SET NO ACTION' THEN 3
286 | ELSE 3
287 | END AS UPDATE_RULE,
288 | CASE
289 | WHEN R.DELETE_RULE = 'CASCADE' THEN 0
290 | WHEN R.DELETE_RULE = 'SET NULL' THEN 2
291 | WHEN R.DELETE_RULE = 'SET DEFAULT' THEN 4
292 | WHEN R.DELETE_RULE = 'SET RESTRICT' THEN 1
293 | WHEN R.DELETE_RULE = 'SET NO ACTION' THEN 3
294 | ELSE 3 END AS DELETE_RULE,
295 | A.CONSTRAINT_NAME AS FK_NAME,
296 | 'PRIMARY' AS PK_NAME,
297 | 7 AS DEFERRABILITY
298 | FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE A
299 | JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE D
300 | ON (
301 | D.TABLE_SCHEMA=A.REFERENCED_TABLE_SCHEMA
302 | AND D.TABLE_NAME=A.REFERENCED_TABLE_NAME
303 | AND D.COLUMN_NAME=A.REFERENCED_COLUMN_NAME)
304 | JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS R
305 | ON (
306 | R.CONSTRAINT_NAME = A.CONSTRAINT_NAME
307 | AND R.TABLE_NAME = A.TABLE_NAME
308 | AND R.CONSTRAINT_SCHEMA = A.TABLE_SCHEMA)
309 | WHERE
310 | D.CONSTRAINT_NAME='PRIMARY'
311 | AND A.TABLE_SCHEMA = 'magic'
312 | ORDER BY FKTABLE_CAT, FKTABLE_NAME, KEY_SEQ, PKTABLE_NAME
313 | """
314 |
# Foreign-key lookup Tableau issues against KEY_COLUMN_USAGE for one table
# ('x' in schema 'db'). Captured verbatim — do not reformat the SQL.
TABLEAU_INDEXES_2 = """
SELECT COLUMN_NAME,
REFERENCED_TABLE_NAME,
REFERENCED_COLUMN_NAME,
ORDINAL_POSITION,
CONSTRAINT_NAME
FROM information_schema.KEY_COLUMN_USAGE
WHERE REFERENCED_TABLE_NAME IS NOT NULL
AND TABLE_NAME = 'x'
AND TABLE_SCHEMA = 'db'
"""
326 |
# Tableau's enum-column discovery query against information_schema.columns.
# Captured verbatim — do not reformat the SQL.
TABLEAU_COLUMNS = """
SELECT `table_name`, `column_name`
FROM `information_schema`.`columns`
WHERE `data_type`='enum' AND `table_schema`='db'
"""
332 |
# Tableau's table/view listing for schema 'db', normalizing BASE TABLE to
# TABLE. Captured verbatim — do not reformat the SQL.
TABLEAU_TABLES = """
SELECT TABLE_NAME,TABLE_COMMENT,IF(TABLE_TYPE='BASE TABLE', 'TABLE', TABLE_TYPE),TABLE_SCHEMA
FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'db' AND ( TABLE_TYPE='BASE TABLE' OR TABLE_TYPE='VIEW' )
ORDER BY TABLE_SCHEMA, TABLE_NAME
"""
338 |
# Tableau's schema enumeration query (NULL placeholder columns precede the
# schema name). Captured verbatim — do not reformat the SQL.
TABLEAU_SCHEMATA = """
SELECT NULL, NULL, NULL, SCHEMA_NAME
FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME LIKE '%'
ORDER BY SCHEMA_NAME
"""
344 |
--------------------------------------------------------------------------------
/tests/test_auth.py:
--------------------------------------------------------------------------------
1 | from contextlib import closing
2 | from typing import Optional, Tuple, Dict
3 |
4 | import pytest
5 | from mysql.connector import DatabaseError
6 | from mysql.connector.abstracts import MySQLConnectionAbstract
7 | from mysql.connector.plugins.mysql_clear_password import MySQLClearPasswordAuthPlugin
8 |
9 | from mysql_mimic import User
10 | from mysql_mimic.auth import (
11 | NativePasswordAuthPlugin,
12 | AbstractClearPasswordAuthPlugin,
13 | NoLoginAuthPlugin,
14 | )
15 | from tests.conftest import query, to_thread, ConnectFixture
16 |
17 | # mysql.connector throws an error if you try to use mysql_clear_password without SSL.
18 | # That's silly, since SSL termination doesn't have to be handled by MySQL.
19 | # But it's extra silly in tests.
20 | MySQLClearPasswordAuthPlugin.requires_ssl = False # type: ignore
21 | MySQLConnectionAbstract.is_secure = True # type: ignore
22 |
23 | SIMPLE_AUTH_USER = "levon_helm"
24 |
25 | PASSWORD_AUTH_USER = "rick_danko"
26 | PASSWORD_AUTH_PASSWORD = "nazareth"
27 | PASSWORD_AUTH_OLD_PASSWORD = "cannonball"
28 | PASSWORD_AUTH_PLUGIN = NativePasswordAuthPlugin.client_plugin_name
29 |
30 |
class TestPlugin(AbstractClearPasswordAuthPlugin):
    """Clear-password test plugin that accepts a login iff username == password."""

    name = "mysql_clear_password"

    async def check(self, username: str, password: str) -> Optional[str]:
        # Returning the username signals success; None signals rejection.
        if username == password:
            return username
        return None
36 |
37 |
38 | TEST_PLUGIN_AUTH_USER = "garth_hudson"
39 | TEST_PLUGIN_AUTH_PASSWORD = TEST_PLUGIN_AUTH_USER
40 | TEST_PLUGIN_AUTH_PLUGIN = TestPlugin.client_plugin_name
41 |
42 | NO_LOGIN_USER = "carmen_and_the_devil"
43 | NO_LOGIN_PLUGIN = NoLoginAuthPlugin.name
44 |
45 | UNKNOWN_PLUGIN_USER = "richard_manuel"
46 | NO_PLUGIN_USER = "miss_moses"
47 |
48 |
@pytest.fixture(autouse=True)
def users() -> Dict[str, User]:
    """Register the test users known to the identity provider."""
    registry: Dict[str, User] = {}
    # Logs in with no password via native password auth.
    registry[SIMPLE_AUTH_USER] = User(
        name=SIMPLE_AUTH_USER,
        auth_string=None,
        auth_plugin=NativePasswordAuthPlugin.name,
    )
    # Has both a current and an old (secondary) password.
    registry[PASSWORD_AUTH_USER] = User(
        name=PASSWORD_AUTH_USER,
        auth_string=NativePasswordAuthPlugin.create_auth_string(
            PASSWORD_AUTH_PASSWORD
        ),
        old_auth_string=NativePasswordAuthPlugin.create_auth_string(
            PASSWORD_AUTH_OLD_PASSWORD
        ),
        auth_plugin=NativePasswordAuthPlugin.name,
    )
    # Authenticates through the clear-password TestPlugin above.
    registry[TEST_PLUGIN_AUTH_USER] = User(
        name=TEST_PLUGIN_AUTH_USER,
        auth_string=TEST_PLUGIN_AUTH_PASSWORD,
        auth_plugin=TestPlugin.name,
    )
    registry[UNKNOWN_PLUGIN_USER] = User(name=UNKNOWN_PLUGIN_USER, auth_plugin="unknown")
    registry[NO_PLUGIN_USER] = User(name=NO_PLUGIN_USER, auth_plugin=NO_LOGIN_PLUGIN)
    return registry
75 |
76 |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "auth_plugins,username,password,auth_plugin",
    [
        (
            [NativePasswordAuthPlugin()],
            PASSWORD_AUTH_USER,
            PASSWORD_AUTH_PASSWORD,
            PASSWORD_AUTH_PLUGIN,
        ),
        (
            [TestPlugin()],
            TEST_PLUGIN_AUTH_USER,
            TEST_PLUGIN_AUTH_PASSWORD,
            TEST_PLUGIN_AUTH_PLUGIN,
        ),
        ([NativePasswordAuthPlugin()], SIMPLE_AUTH_USER, None, None),
        ([TestPlugin(), NativePasswordAuthPlugin()], SIMPLE_AUTH_USER, None, None),
        (None, SIMPLE_AUTH_USER, None, None),
    ],
)
async def test_auth(
    connect: ConnectFixture,
    username: str,
    password: Optional[str],
    auth_plugin: Optional[str],
) -> None:
    """Each supported plugin configuration lets its user log in; USER() reflects it."""
    credentials = {"user": username}
    if password is not None:
        credentials["password"] = password
    if auth_plugin is not None:
        credentials["auth_plugin"] = auth_plugin
    with closing(await connect(**credentials)) as conn:
        assert await query(conn=conn, sql="SELECT USER() AS a") == [{"a": username}]
108 |
109 |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "auth_plugins",
    [
        [NativePasswordAuthPlugin()],
    ],
)
async def test_auth_secondary_password(
    connect: ConnectFixture,
) -> None:
    """A user's old (secondary) password is still accepted for login."""
    conn = await connect(
        user=PASSWORD_AUTH_USER,
        password=PASSWORD_AUTH_OLD_PASSWORD,
        auth_plugin=PASSWORD_AUTH_PLUGIN,
    )
    with closing(conn):
        rows = await query(conn=conn, sql="SELECT USER() AS a")
        assert rows == [{"a": PASSWORD_AUTH_USER}]
130 |
131 |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "auth_plugins,user1,user2",
    [
        (
            [NativePasswordAuthPlugin(), TestPlugin()],
            (PASSWORD_AUTH_USER, PASSWORD_AUTH_PASSWORD, PASSWORD_AUTH_PLUGIN),
            (TEST_PLUGIN_AUTH_USER, TEST_PLUGIN_AUTH_PASSWORD, TEST_PLUGIN_AUTH_PLUGIN),
        ),
        (
            [NativePasswordAuthPlugin()],
            (PASSWORD_AUTH_USER, PASSWORD_AUTH_PASSWORD, PASSWORD_AUTH_PLUGIN),
            (PASSWORD_AUTH_USER, PASSWORD_AUTH_PASSWORD, PASSWORD_AUTH_PLUGIN),
        ),
    ],
)
async def test_change_user(
    connect: ConnectFixture,
    user1: Tuple[str, str, str],
    user2: Tuple[str, str, str],
) -> None:
    """COM_CHANGE_USER swaps the authenticated identity on an open connection."""
    first_user, first_password, first_plugin = user1
    credentials = {
        "user": first_user,
        "password": first_password,
        "auth_plugin": first_plugin,
    }
    credentials = {k: v for k, v in credentials.items() if v is not None}

    with closing(await connect(**credentials)) as conn:
        # mysql.connector doesn't have great support for COM_CHANGE_USER,
        # so manually override the auth plugin it should use for the new user.
        conn._auth_plugin = user2[2]  # pylint: disable=protected-access

        await to_thread(conn.cmd_change_user, username=user2[0], password=user2[1])
        assert await query(conn=conn, sql="SELECT USER() AS a") == [{"a": user2[0]}]
163 |
164 |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "auth_plugins,username,password,auth_plugin,msg",
    [
        ([NativePasswordAuthPlugin()], None, None, None, "User does not exist"),
        (
            [TestPlugin()],
            PASSWORD_AUTH_USER,
            PASSWORD_AUTH_PASSWORD,
            PASSWORD_AUTH_PLUGIN,
            "Access denied",
        ),
        # This test doesn't work with newer versions of mysql-connector-python.
        # ([NoLoginAuthPlugin()], NO_PLUGIN_USER, None, None, "Access denied"),
        (
            [NativePasswordAuthPlugin(), NoLoginAuthPlugin()],
            NO_PLUGIN_USER,
            None,
            None,
            "Access denied",
        ),
    ],
)
async def test_access_denied(
    connect: ConnectFixture,
    username: Optional[str],
    password: Optional[str],
    auth_plugin: Optional[str],
    msg: str,
) -> None:
    """Bad or missing credentials fail to connect with the expected error message."""
    credentials = {}
    if username is not None:
        credentials["user"] = username
    if password is not None:
        credentials["password"] = password
    if auth_plugin is not None:
        credentials["auth_plugin"] = auth_plugin

    with pytest.raises(DatabaseError) as ctx:
        await connect(**credentials)

    assert msg in str(ctx.value)
202 |
--------------------------------------------------------------------------------
/tests/test_control.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from contextlib import closing
3 | from typing import Callable
4 |
5 | import pytest
6 | from mysql.connector import DatabaseError
7 |
8 | from tests.conftest import MockSession, query, ConnectFixture
9 |
10 |
@pytest.fixture
def session_factory(session: MockSession) -> Callable[[], MockSession]:
    """First call hands out the shared `session`; later calls build fresh ones."""
    pending = [session]

    def factory() -> MockSession:
        if pending:
            return pending.pop()
        return MockSession()

    return factory
23 |
24 |
@pytest.mark.asyncio
async def test_kill_connection(
    session: MockSession,
    connect: ConnectFixture,
) -> None:
    """KILL <connection_id> aborts the running query AND invalidates the connection."""
    # Setting `pause` presumably makes the mock session block inside its query
    # handler until cleared — see MockSession in conftest (TODO confirm).
    session.pause = asyncio.Event()

    with closing(await connect()) as conn1:
        with closing(await connect()) as conn2:
            # Start a query on conn1 that will hang inside the mock session.
            q1 = asyncio.create_task(query(conn1, "SELECT a FROM x"))

            # Wait until we know the session is paused
            await session.waiting.wait()

            # Kill conn1's whole connection from conn2.
            await query(conn2, f"KILL {session.connection.connection_id}")

            with pytest.raises(DatabaseError) as ctx:
                await q1

            assert "Session was killed" in str(ctx.value.msg)

            session.pause.clear()

            # Unlike KILL QUERY, the connection itself is now dead too.
            with pytest.raises(DatabaseError) as ctx:
                await query(conn1, "SELECT 1")

            assert "Connection not available" in str(ctx.value.msg)
52 |
53 |
@pytest.mark.asyncio
async def test_kill_query(
    session: MockSession,
    connect: ConnectFixture,
) -> None:
    """KILL QUERY cancels the in-flight query but leaves the connection usable."""
    # Setting `pause` presumably makes the mock session block inside its query
    # handler until cleared — see MockSession in conftest (TODO confirm).
    session.pause = asyncio.Event()

    with closing(await connect()) as conn1:
        with closing(await connect()) as conn2:
            # Start a query on conn1 that will hang inside the mock session.
            q1 = asyncio.create_task(query(conn1, "SELECT a FROM x"))

            # Wait until we know the session is paused
            await session.waiting.wait()

            # Kill only the query (not the connection) from conn2.
            await query(conn2, f"KILL QUERY {session.connection.connection_id}")

            with pytest.raises(DatabaseError) as ctx:
                await q1

            assert "Query was killed" in str(ctx.value.msg)

            session.pause.clear()

            # The connection must still work after the query was killed.
            assert (await query(conn1, "SELECT 1")) == [{"1": 1}]
78 |
--------------------------------------------------------------------------------
/tests/test_result.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 |
5 | from mysql_mimic import ColumnType
6 | from mysql_mimic.errors import MysqlError
7 | from mysql_mimic.results import ensure_result_set
8 |
9 |
async def gen_rows() -> Any:
    """Async generator yielding three sparse rows used to exercise type inference."""
    rows = [
        (1, None, None),
        (None, "2", None),
        (None, None, None),
    ]
    for row in rows:
        yield row
14 |
15 |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "result, column_types",
    [
        (
            (gen_rows(), ["a", "b", "c"]),
            [ColumnType.LONGLONG, ColumnType.STRING, ColumnType.NULL],
        ),
    ],
)
async def test_ensure_result_set_columns(result: Any, column_types: Any) -> None:
    """Column types are inferred from the first non-NULL value in each column."""
    rs = await ensure_result_set(result)
    inferred = [column.type for column in rs.columns]
    assert inferred == column_types
29 |
30 |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "result",
    [
        [1, 2],
        ([[1, 2]], ["a", "b"], ["a", "b"]),
    ],
)
async def test_ensure_result_set__invalid(result: Any) -> None:
    """Malformed result shapes are rejected with MysqlError."""
    with pytest.raises(MysqlError):
        await ensure_result_set(result)
42 |
--------------------------------------------------------------------------------
/tests/test_schema.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from sqlglot import expressions as exp
3 |
4 | from mysql_mimic.results import ensure_result_set
5 | from mysql_mimic.schema import (
6 | Column,
7 | InfoSchema,
8 | mapping_to_columns,
9 | show_statement_to_info_schema_query,
10 | )
11 | from mysql_mimic.utils import aiterate
12 |
13 |
def test_mapping_to_columns() -> None:
    """mapping_to_columns expands a {table: {column: type}} mapping into Column records.

    Fix: removed the spurious @pytest.mark.asyncio marker — this test is a
    plain synchronous function, and pytest-asyncio warns/errors when the
    marker is applied to a non-coroutine test.
    """
    schema = {
        "table_1": {
            "col_1": "TEXT",
            "col_2": "INT",
        }
    }

    columns = mapping_to_columns(schema=schema)

    # One Column per input column, in declaration order.
    assert len(columns) == 2
    assert columns[0] == Column(
        name="col_1", type="TEXT", table="table_1", schema="", catalog="def"
    )
    assert columns[1] == Column(
        name="col_2", type="INT", table="table_1", schema="", catalog="def"
    )
31 |
32 |
@pytest.mark.asyncio
async def test_info_schema_from_columns() -> None:
    """InfoSchema built from Column records answers SHOW TABLES / SHOW COLUMNS queries."""
    source_columns = [
        Column(
            name="col_1",
            type="TEXT",
            table="table_1",
            schema="my_db",
            catalog="def",
            comment="This is a comment",
        ),
        Column(
            name="col_1", type="TEXT", table="table_2", schema="my_db", catalog="def"
        ),
    ]
    info_schema = InfoSchema.from_columns(columns=source_columns)

    # SHOW TABLES should list both tables in the schema.
    tables_query = show_statement_to_info_schema_query(exp.Show(this="TABLES"), "my_db")
    tables_result = await ensure_result_set(await info_schema.query(tables_query))
    names = [row[0] async for row in aiterate(tables_result.rows)]
    assert names == ["table_1", "table_2"]

    # SHOW FULL COLUMNS should surface the column's type and comment.
    columns_query = show_statement_to_info_schema_query(
        exp.Show(this="COLUMNS", full=True, target="table_1"), "my_db"
    )
    columns_result = await ensure_result_set(await info_schema.query(columns_query))
    rows = [row async for row in aiterate(columns_result.rows)]
    expected_row = (
        "col_1",
        "TEXT",
        "YES",
        None,
        None,
        None,
        "NULL",
        None,
        "This is a comment",
    )
    assert rows == [expected_row]
72 |
--------------------------------------------------------------------------------
/tests/test_ssl.py:
--------------------------------------------------------------------------------
1 | import os
2 | from contextlib import closing
3 | from ssl import SSLContext, PROTOCOL_TLS_SERVER
4 |
5 | import pytest
6 |
7 | from tests.conftest import MockSession, ConnectFixture, query
8 |
9 | CURRENT_DIR = os.path.dirname(__file__)
10 | SSL_CERT = os.path.join(CURRENT_DIR, "fixtures/certificate.pem")
11 | SSL_KEY = os.path.join(CURRENT_DIR, "fixtures/key.pem")
12 |
13 |
@pytest.fixture(autouse=True)
def ssl() -> SSLContext:
    """Provide a TLS server context loaded with the bundled test certificate."""
    context = SSLContext(PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=SSL_CERT, keyfile=SSL_KEY)
    return context
19 |
20 |
@pytest.mark.asyncio
@pytest.mark.skipif(
    os.getenv("CI") == "true",
    reason="""
There seems to be a sequencing issue with the asyncio server starting tls.
https://stackoverflow.com/questions/74187508/sequencing-issue-when-upgrading-python-asyncio-connection-to-tls
Didn't have time to debug this, so skipping in CI for now.
""",
)
async def test_ssl(
    session: MockSession,
    connect: ConnectFixture,
    port: int,
) -> None:
    """A client can connect over TLS and run a trivial query."""
    with closing(await connect()) as conn:
        rows = await query(conn, "SELECT 1 AS a")
        assert rows == [{"a": 1}]
38 |
--------------------------------------------------------------------------------
/tests/test_stream.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from mysql_mimic.errors import MysqlError
4 | from mysql_mimic.stream import MysqlStream
5 |
6 |
class MockReader:
    """In-memory stand-in for asyncio.StreamReader backed by a fixed bytes buffer."""

    def __init__(self, data: bytes):
        self.data = data
        self.pos = 0  # read cursor into `data`

    async def read(self, n: int) -> bytes:
        # Hand back the next n bytes (may be short at end-of-buffer) and
        # advance the cursor unconditionally, mirroring the original.
        chunk = self.data[self.pos : self.pos + n]
        self.pos += n
        return chunk

    # The stream under test may call readexactly; same semantics here.
    readexactly = read
20 |
21 |
class MockWriter:
    """Collects written bytes so tests can assert on the exact wire format."""

    def __init__(self) -> None:
        self.data = b""

    def write(self, data: bytes) -> None:
        # Append to the captured buffer.
        self.data = self.data + data

    async def drain(self) -> None:
        # Nothing is actually buffered; drain is a no-op.
        return None
31 |
32 |
def test_seq() -> None:
    """Each stream keeps its own independent, monotonically increasing sequence."""
    first = MysqlStream(reader=None, writer=None)  # type: ignore
    for expected in (0, 1, 2):
        assert next(first.seq) == expected
    second = MysqlStream(reader=None, writer=None)  # type: ignore
    for expected in (0, 1, 2):
        assert next(second.seq) == expected
42 |
43 |
def test_reset_seq() -> None:
    """reset_seq() restarts the sequence counter at zero."""
    stream = MysqlStream(reader=None, writer=None)  # type: ignore
    assert [next(stream.seq), next(stream.seq)] == [0, 1]
    stream.reset_seq()
    assert [next(stream.seq), next(stream.seq)] == [0, 1]
51 |
52 |
@pytest.mark.asyncio
async def test_bad_seq_read() -> None:
    """A packet whose sequence byte is not the expected 0 must raise MysqlError."""
    bad_packet = MockReader(b"\x00\x00\x00\x01")  # length 0, seq 1 (expected 0)
    stream = MysqlStream(reader=bad_packet, writer=None)  # type: ignore
    with pytest.raises(MysqlError):
        await stream.read()
59 |
60 |
@pytest.mark.asyncio
async def test_empty_read() -> None:
    """A zero-length packet reads as empty bytes and advances the sequence."""
    stream = MysqlStream(reader=MockReader(b"\x00\x00\x00\x00"), writer=None)  # type: ignore
    payload = await stream.read()
    assert payload == b""
    assert next(stream.seq) == 1
67 |
68 |
@pytest.mark.asyncio
async def test_small_read() -> None:
    """A one-byte packet is read back verbatim."""
    stream = MysqlStream(reader=MockReader(b"\x01\x00\x00\x00k"), writer=None)  # type: ignore
    payload = await stream.read()
    assert payload == b"k"
    assert next(stream.seq) == 1
75 |
76 |
@pytest.mark.asyncio
async def test_medium_read() -> None:
    """A 0xFFFF-byte packet (below the 3-byte length limit) reads in one go."""
    stream = MysqlStream(reader=MockReader(b"\xff\xff\x00\x00" + bytes(0xFFFF)), writer=None)  # type: ignore
    payload = await stream.read()
    assert payload == bytes(0xFFFF)
    assert next(stream.seq) == 1
83 |
84 |
@pytest.mark.asyncio
async def test_edge_read() -> None:
    """A payload of exactly 0xFFFFFF requires an empty continuation packet."""
    wire = b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x00\x00\x00\x01"
    stream = MysqlStream(reader=MockReader(wire), writer=None)  # type: ignore
    payload = await stream.read()
    assert payload == bytes(0xFFFFFF)
    assert next(stream.seq) == 2
91 |
92 |
@pytest.mark.asyncio
async def test_large_read() -> None:
    """Payloads over 0xFFFFFF bytes are reassembled from continuation packets."""
    wire = b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x06\x00\x00\x01" + b"kelsin"
    stream = MysqlStream(reader=MockReader(wire), writer=None)  # type: ignore
    payload = await stream.read()
    assert payload == bytes(0xFFFFFF) + b"kelsin"
    # Two packets were consumed, so the next sequence number is 2.
    assert next(stream.seq) == 2
101 |
102 |
@pytest.mark.asyncio
async def test_empty_write() -> None:
    """An empty payload still produces a 4-byte header (length 0, seq 0)."""
    writer = MockWriter()
    stream = MysqlStream(reader=None, writer=writer)  # type: ignore
    await stream.write(b"")
    assert writer.data == b"\x00\x00\x00\x00"
    assert next(stream.seq) == 1
110 |
111 |
@pytest.mark.asyncio
async def test_small_write() -> None:
    """A small payload is framed with its length and sequence number 0."""
    writer = MockWriter()
    stream = MysqlStream(reader=None, writer=writer)  # type: ignore
    await stream.write(b"kelsin")
    assert writer.data == b"\x06\x00\x00\x00kelsin"
    assert next(stream.seq) == 1
119 |
120 |
@pytest.mark.asyncio
async def test_medium_write() -> None:
    """A 0xFFFF-byte payload (below the split threshold) goes out as one packet."""
    writer = MockWriter()
    stream = MysqlStream(reader=None, writer=writer)  # type: ignore
    await stream.write(bytes(0xFFFF))
    assert writer.data == b"\xff\xff\x00\x00" + bytes(0xFFFF)
    assert next(stream.seq) == 1
128 |
129 |
@pytest.mark.asyncio
async def test_edge_write() -> None:
    """A payload of exactly 0xFFFFFF bytes is followed by an empty continuation."""
    writer = MockWriter()
    stream = MysqlStream(reader=None, writer=writer)  # type: ignore
    await stream.write(bytes(0xFFFFFF))
    expected = b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x00\x00\x00\x01"
    assert writer.data == expected
    # Two packets were emitted, so the next sequence number is 2.
    assert next(stream.seq) == 2
137 |
138 |
@pytest.mark.asyncio
async def test_large_write() -> None:
    """Payloads over 0xFFFFFF bytes are split into full + continuation packets."""
    writer = MockWriter()
    stream = MysqlStream(reader=None, writer=writer)  # type: ignore
    await stream.write(bytes(0xFFFFFF) + b"kelsin")
    expected = b"\xff\xff\xff\x00" + bytes(0xFFFFFF) + b"\x06\x00\x00\x01kelsin"
    assert writer.data == expected
    # Two packets were emitted, so the next sequence number is 2.
    assert next(stream.seq) == 2
148 |
--------------------------------------------------------------------------------
/tests/test_types.py:
--------------------------------------------------------------------------------
1 | import io
2 | import struct
3 |
4 | import pytest
5 |
6 | from mysql_mimic import types
7 |
8 |
def test_column_definition() -> None:
    """Spot-check ColumnDefinition flag values and bitwise combination."""
    flags = types.ColumnDefinition
    assert flags.NOT_NULL_FLAG == 1
    assert flags.PRI_KEY_FLAG == 2
    assert flags.FIELD_IS_INVISIBLE == 1 << 30
    assert flags.NOT_NULL_FLAG | flags.PRI_KEY_FLAG == 3
16 |
17 |
def test_capabilities() -> None:
    """Spot-check Capabilities flag values and bitwise combination."""
    caps = types.Capabilities
    assert caps.CLIENT_LONG_PASSWORD == 1
    assert caps.CLIENT_FOUND_ROWS == 2
    assert caps.CLIENT_PROTOCOL_41 == 512
    assert caps.CLIENT_DEPRECATE_EOF == 1 << 24
    assert caps.CLIENT_LONG_PASSWORD | caps.CLIENT_FOUND_ROWS == 3
27 |
28 |
def test_server_status() -> None:
    """Spot-check ServerStatus flag values and bitwise combination."""
    status = types.ServerStatus
    assert status.SERVER_STATUS_IN_TRANS == 1
    assert status.SERVER_STATUS_AUTOCOMMIT == 2
    assert status.SERVER_STATUS_LAST_ROW_SENT == 128
    assert status.SERVER_STATUS_IN_TRANS | status.SERVER_STATUS_AUTOCOMMIT == 3
38 |
39 |
def test_int_1() -> None:
    """uint_1 encodes one byte; values above 255 overflow."""
    for value, encoded in ((0, b"\x00"), (1, b"\x01"), (255, b"\xff")):
        assert types.uint_1(value) == encoded
    with pytest.raises(struct.error):
        types.uint_1(256)
45 |
46 |
def test_int_2() -> None:
    """uint_2 encodes two little-endian bytes; 2**16 overflows."""
    cases = (
        (0, b"\x00\x00"),
        (1, b"\x01\x00"),
        (255, b"\xff\x00"),
        (2**8, b"\x00\x01"),
        (2**16 - 1, b"\xff\xff"),
    )
    for value, encoded in cases:
        assert types.uint_2(value) == encoded
    with pytest.raises(struct.error):
        types.uint_2(2**16)
54 |
55 |
def test_int_3() -> None:
    """uint_3 encodes three little-endian bytes; 2**24 overflows."""
    cases = (
        (0, b"\x00\x00\x00"),
        (1, b"\x01\x00\x00"),
        (255, b"\xff\x00\x00"),
        (2**8, b"\x00\x01\x00"),
        (2**16, b"\x00\x00\x01"),
    )
    for value, encoded in cases:
        assert types.uint_3(value) == encoded
    with pytest.raises(struct.error):
        types.uint_3(2**24)
63 |
64 |
def test_int_4() -> None:
    """uint_4 encodes four little-endian bytes; 2**32 overflows."""
    assert types.uint_4(0) == bytes(4)
    assert types.uint_4(255) == b"\xff" + bytes(3)
    # A 1 in each byte position, from least to most significant.
    for position in range(4):
        expected = bytes(position) + b"\x01" + bytes(3 - position)
        assert types.uint_4(2 ** (8 * position)) == expected
    with pytest.raises(struct.error):
        types.uint_4(2**32)
73 |
74 |
def test_int_6() -> None:
    """uint_6 encodes six little-endian bytes; 2**48 overflows."""
    assert types.uint_6(0) == bytes(6)
    assert types.uint_6(255) == b"\xff" + bytes(5)
    # A 1 in each byte position, from least to most significant.
    for position in range(6):
        expected = bytes(position) + b"\x01" + bytes(5 - position)
        assert types.uint_6(2 ** (8 * position)) == expected
    with pytest.raises(struct.error):
        types.uint_6(2**48)
85 |
86 |
def test_int_8() -> None:
    """uint_8 encodes eight little-endian bytes; 2**64 overflows.

    Unlike the sibling uint_6 test, the original skipped byte positions
    5 (2**40) and 6 (2**48); every byte position is now exercised.
    """
    assert types.uint_8(0) == bytes(8)
    assert types.uint_8(255) == b"\xff" + bytes(7)
    # A 1 in each byte position, from least to most significant.
    for position in range(8):
        expected = bytes(position) + b"\x01" + bytes(7 - position)
        assert types.uint_8(2 ** (8 * position)) == expected
    with pytest.raises(struct.error):
        types.uint_8(2**64)
97 |
98 |
def test_int_len() -> None:
    """uint_len emits MySQL length-encoded integers; 2**64 overflows."""
    cases = (
        (0, b"\x00"),
        (1, b"\x01"),
        (250, b"\xfa"),
        # 251 is a reserved marker byte, so it takes the \xfc 2-byte form.
        (251, b"\xfc\xfb\x00"),
        (2**8, b"\xfc\x00\x01"),
        (2**16, b"\xfd\x00\x00\x01"),
        (2**24, b"\xfe\x00\x00\x00\x01\x00\x00\x00\x00"),
        (2**32, b"\xfe\x00\x00\x00\x00\x01\x00\x00\x00"),
        (2**56, b"\xfe\x00\x00\x00\x00\x00\x00\x00\x01"),
    )
    for value, encoded in cases:
        assert types.uint_len(value) == encoded
    with pytest.raises(struct.error):
        types.uint_len(2**64)
110 |
111 |
def test_str_fixed() -> None:
    """str_fixed truncates or NUL-pads to exactly the requested length."""
    expectations = {0: b"", 1: b"k", 6: b"kelsin", 8: b"kelsin\x00\x00"}
    for length, encoded in expectations.items():
        assert types.str_fixed(length, b"kelsin") == encoded
117 |
118 |
def test_str_null() -> None:
    """str_null appends a NUL terminator to the payload."""
    for payload, encoded in ((b"", b"\x00"), (b"kelsin", b"kelsin\x00")):
        assert types.str_null(payload) == encoded
122 |
123 |
def test_str_len() -> None:
    """str_len prefixes the payload with its length-encoded size."""
    assert types.str_len(b"") == b"\x00"
    assert types.str_len(b"kelsin") == b"\x06kelsin"
    # Lengths above 250 switch to the multi-byte length prefix.
    payload = bytes(256)
    assert types.str_len(payload) == b"\xfc\x00\x01" + payload
130 |
131 |
def test_str_rest() -> None:
    """str_rest passes the payload through unchanged, regardless of size."""
    for payload in (b"", b"kelsin", bytes(256)):
        assert types.str_rest(payload) == payload
138 |
139 |
def test_read_int_1() -> None:
    """read_uint_1 consumes exactly one byte per call."""
    assert types.read_uint_1(io.BytesIO(b"\x01")) == 1
    # Two consecutive reads advance through the same buffer.
    reader = io.BytesIO(b"\x00\x00")
    assert types.read_uint_1(reader) == 0
    assert types.read_uint_1(reader) == 0
146 |
147 |
def test_read_int_2() -> None:
    """read_uint_2 decodes two little-endian bytes."""
    assert types.read_uint_2(io.BytesIO(b"\x00\x01")) == 256
151 |
152 |
def test_read_int_3() -> None:
    """read_uint_3 decodes three little-endian bytes."""
    assert types.read_uint_3(io.BytesIO(b"\x00\x00\x01")) == 2**16
156 |
157 |
def test_read_int_4() -> None:
    """read_uint_4 decodes four little-endian bytes."""
    cases = (
        (b"\x01\x00\x00\x00", 1),
        (b"\x00\x00\x01\x00", 2**16),
        (b"\x00\x00\x00\x01", 2**24),
    )
    for encoded, value in cases:
        assert types.read_uint_4(io.BytesIO(encoded)) == value
165 |
166 |
def test_read_int_6() -> None:
    """read_uint_6 decodes six little-endian bytes."""
    cases = (
        (b"\x01\x00\x00\x00\x00\x00", 1),
        (b"\x00\x00\x00\x00\x01\x00", 2**32),
        (b"\x00\x00\x00\x00\x00\x01", 2**40),
    )
    for encoded, value in cases:
        assert types.read_uint_6(io.BytesIO(encoded)) == value
174 |
175 |
def test_read_int_8() -> None:
    """read_uint_8 decodes eight little-endian bytes."""
    cases = (
        (b"\x01\x00\x00\x00\x00\x00\x00\x00", 1),
        (b"\x00\x00\x00\x00\x00\x00\x01\x00", 2**48),
        (b"\x00\x00\x00\x00\x00\x00\x00\x01", 2**56),
    )
    for encoded, value in cases:
        assert types.read_uint_8(io.BytesIO(encoded)) == value
183 |
184 |
def test_read_int_len() -> None:
    """read_uint_len handles the 1-, 3-, 4-, and 9-byte length encodings."""
    cases = (
        (b"\x01", 1),
        (b"\xfa", 250),
        (b"\xfc\x00\x01", 256),
        (b"\xfd\x00\x00\x01", 2**16),
        (b"\xfe\x00\x00\x00\x00\x00\x00\x00\x01", 2**56),
    )
    for encoded, value in cases:
        assert types.read_uint_len(io.BytesIO(encoded)) == value
196 |
197 |
def test_read_str_fixed() -> None:
    """read_str_fixed consumes exactly the requested number of bytes per call."""
    reader = io.BytesIO(b"kelsin")
    for length, expected in ((1, b"k"), (2, b"el"), (3, b"sin")):
        assert types.read_str_fixed(reader, length) == expected
203 |
204 |
def test_read_str_null() -> None:
    """read_str_null reads up to (and consumes) the NUL terminator."""
    assert types.read_str_null(io.BytesIO(b"\x00")) == b""
    assert types.read_str_null(io.BytesIO(b"kelsin\x00")) == b"kelsin"
    # Bytes after the terminator are not returned.
    assert types.read_str_null(io.BytesIO(b"kelsin\x00foo")) == b"kelsin"
212 |
213 |
def test_read_str_len() -> None:
    """read_str_len reads a length-prefixed string."""
    assert types.read_str_len(io.BytesIO(b"\x01kelsin")) == b"k"
    # Lengths above 250 use the multi-byte length prefix.
    payload = bytes(256)
    assert types.read_str_len(io.BytesIO(b"\xfc\x00\x01" + payload)) == payload
221 |
222 |
def test_read_str_rest() -> None:
    """read_str_rest returns everything left in the buffer, NULs included."""
    for payload in (b"kelsin", b"kelsin\x00foo"):
        assert types.read_str_rest(io.BytesIO(payload)) == payload
228 |
--------------------------------------------------------------------------------
/tests/test_variables.py:
--------------------------------------------------------------------------------
1 | from datetime import timezone, timedelta
2 | from typing import Dict
3 |
4 | import pytest
5 |
6 | from mysql_mimic.errors import MysqlError
7 | from mysql_mimic.variables import (
8 | parse_timezone,
9 | Variables,
10 | VariableSchema,
11 | )
12 |
13 |
class TestVars(Variables):
    """Minimal Variables subclass exposing a single variable named "foo"."""

    @property
    def schema(self) -> Dict[str, VariableSchema]:
        # Presumably (type, default, settable) — confirm against VariableSchema.
        return dict(foo=(str, "bar", True))
19 |
20 |
def test_parse_timezone() -> None:
    """parse_timezone handles UTC, signed offsets, and rejects garbage."""
    utc = timezone(timedelta())
    assert parse_timezone("UTC") == utc
    assert parse_timezone("+00:00") == utc
    assert parse_timezone("+01:00") == timezone(timedelta(hours=1))
    assert parse_timezone("-01:00") == timezone(timedelta(hours=-1))

    # Looking up the same name again implicitly exercises the cache.
    assert parse_timezone("-01:00") == timezone(timedelta(hours=-1))

    with pytest.raises(MysqlError):
        parse_timezone("whoops")
32 |
33 |
def test_variable_mapping() -> None:
    """Variables supports mapping-style reads and schema-guarded writes."""
    variables = TestVars()
    assert variables
    assert len(variables) == 1

    # The schema default is visible through both accessors.
    assert variables.get_variable("foo") == "bar"
    assert variables["foo"] == "bar"

    # Writes via __setitem__ are visible through both accessors.
    variables["foo"] = "hello"
    assert variables.get_variable("foo") == "hello"
    assert variables["foo"] == "hello"

    # Unknown names: KeyError on read, MysqlError on write.
    with pytest.raises(KeyError):
        assert variables["world"]
    with pytest.raises(MysqlError):
        variables["world"] = "hello"
51 |
--------------------------------------------------------------------------------
/tests/test_version.py:
--------------------------------------------------------------------------------
1 | import io
2 | import contextlib
3 |
4 | from mysql_mimic import version
5 |
6 |
def test_version() -> None:
    """The package exposes its version as a plain string."""
    ver = version.__version__
    assert isinstance(ver, str)
9 |
10 |
def test_main() -> None:
    """version.main prints the version followed by a newline to stdout."""
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        version.main("__main__")
    assert captured.getvalue() == f"{version.__version__}\n"
17 |
--------------------------------------------------------------------------------