├── tests ├── __init__.py ├── waiting_for_stack_up.sh ├── test_mixins.py ├── test_utils.py ├── test_sql.py ├── test_database.py ├── common.py └── test_query.py ├── dev ├── trino-conf │ └── etc │ │ ├── catalog │ │ ├── jms.properties │ │ ├── tpcds.properties │ │ ├── tpch.properties │ │ ├── memory.properties │ │ └── postgresql.properties │ │ ├── node.properties │ │ ├── config.properties │ │ └── jvm.config ├── presto-conf │ └── standalone │ │ ├── catalog │ │ ├── jmx.properties │ │ ├── memory.properties │ │ ├── tpcds.properties │ │ ├── tpch.properties │ │ └── postgresql.properties │ │ ├── log.properties │ │ ├── password.db │ │ ├── node.properties │ │ ├── password-authenticator.properties │ │ ├── jvm.config │ │ ├── config.properties │ │ └── combined.pem ├── dev.env └── Dockerfile.prestosql.340 ├── sqeleton ├── __init__.py ├── abcs │ ├── __init__.py │ ├── compiler.py │ └── mixins.py ├── queries │ ├── base.py │ ├── __init__.py │ ├── extras.py │ ├── compiler.py │ └── api.py ├── databases │ ├── __init__.py │ ├── mssql.py │ ├── trino.py │ ├── mysql.py │ ├── vertica.py │ ├── postgresql.py │ ├── duckdb.py │ ├── presto.py │ ├── redshift.py │ ├── oracle.py │ ├── clickhouse.py │ ├── databricks.py │ ├── bigquery.py │ ├── snowflake.py │ └── _connect.py ├── __main__.py ├── schema.py ├── query_utils.py ├── repl.py ├── bound_exprs.py ├── conn_editor.py └── utils.py ├── .github ├── ISSUE_TEMPLATE │ ├── request-support-for-a-database.md │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── ci_full.yml │ └── ci.yml ├── docs ├── requirements.txt ├── Makefile ├── make.bat ├── index.rst ├── conn_editor.md ├── python-api.rst ├── install.md ├── supported-databases.md └── conf.py ├── LICENSE ├── pyproject.toml ├── .gitignore ├── docker-compose.yml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/jms.properties: -------------------------------------------------------------------------------- 1 | connector.name=jmx 2 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/tpcds.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpcds 2 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/tpch.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpch 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/jmx.properties: -------------------------------------------------------------------------------- 1 | connector.name=jmx 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/log.properties: -------------------------------------------------------------------------------- 1 | com.facebook.presto=INFO 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/memory.properties: -------------------------------------------------------------------------------- 1 | connector.name=memory 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/tpcds.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpcds 2 | 
-------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/tpch.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpch 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/password.db: -------------------------------------------------------------------------------- 1 | test:$2y$10$877iU3J5a26SPDjFSrrz2eFAq2DwMDsBAus92Dj0z5A5qNMNlnpHa 2 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/memory.properties: -------------------------------------------------------------------------------- 1 | connector.name=memory 2 | memory.max-data-per-node=128MB 3 | -------------------------------------------------------------------------------- /sqeleton/__init__.py: -------------------------------------------------------------------------------- 1 | from .databases import connect 2 | from .queries import table, this, SKIP, code 3 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/node.properties: -------------------------------------------------------------------------------- 1 | node.environment=production 2 | node.data-dir=/data 3 | node.id=standalone 4 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/node.properties: -------------------------------------------------------------------------------- 1 | node.environment=docker 2 | node.data-dir=/data/trino 3 | plugin.dir=/usr/lib/trino/plugin 4 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/password-authenticator.properties: -------------------------------------------------------------------------------- 1 | password-authenticator.name=file 2 | file.password-file=/opt/presto/etc/password.db 3 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/postgresql.properties: -------------------------------------------------------------------------------- 1 | connector.name=postgresql 2 | connection-url=jdbc:postgresql://postgres:5432/postgres 3 | connection-user=postgres 4 | connection-password=Password1 5 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/config.properties: -------------------------------------------------------------------------------- 1 | coordinator=true 2 | node-scheduler.include-coordinator=true 3 | http-server.http.port=8080 4 | discovery.uri=http://localhost:8080 5 | discovery-server.enabled=true 6 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/postgresql.properties: -------------------------------------------------------------------------------- 1 | connector.name=postgresql 2 | connection-url=jdbc:postgresql://postgres:5432/postgres 3 | connection-user=postgres 4 | connection-password=Password1 5 | allow-drop-table=true 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/request-support-for-a-database.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Request support for a database 3 | about: 'Request a driver to support a new database ' 4 | title: 'Add support for ' 5 | labels: new-db-driver 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | 
-------------------------------------------------------------------------------- /dev/presto-conf/standalone/jvm.config: -------------------------------------------------------------------------------- 1 | -server 2 | -Xmx16G 3 | -XX:+UseG1GC 4 | -XX:G1HeapRegionSize=32M 5 | -XX:+UseGCOverheadLimit 6 | -XX:+ExplicitGCInvokesConcurrent 7 | -XX:+HeapDumpOnOutOfMemoryError 8 | -XX:+ExitOnOutOfMemoryError 9 | -XX:OnOutOfMemoryError=kill -9 %p 10 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/guides/specifying-dependencies.html#specifying-a-requirements-file 2 | sphinx-gallery 3 | sphinx_markdown_tables 4 | sphinx-copybutton 5 | sphinx-rtd-theme 6 | recommonmark 7 | enum-tools[sphinx] 8 | 9 | sqeleton 10 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/config.properties: -------------------------------------------------------------------------------- 1 | coordinator=true 2 | node-scheduler.include-coordinator=true 3 | http-server.http.port=8080 4 | query.max-memory=5GB 5 | query.max-memory-per-node=1GB 6 | query.max-total-memory-per-node=2GB 7 | discovery-server.enabled=true 8 | discovery.uri=http://127.0.0.1:8080 9 | -------------------------------------------------------------------------------- /sqeleton/abcs/__init__.py: -------------------------------------------------------------------------------- 1 | from .database_types import ( 2 | AbstractDatabase, 3 | AbstractDialect, 4 | DbKey, 5 | DbPath, 6 | DbTime, 7 | IKey, 8 | ColType_UUID, 9 | NumericType, 10 | PrecisionType, 11 | StringType, 12 | Boolean, 13 | JSONType, 14 | ) 15 | from .compiler import AbstractCompiler, Compilable 16 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/jvm.config: -------------------------------------------------------------------------------- 1 | -server 2 | -Xmx1G 3 | -XX:-UseBiasedLocking 4 | -XX:+UseG1GC 5 | -XX:G1HeapRegionSize=32M 6 | -XX:+ExplicitGCInvokesConcurrent 7 | -XX:+HeapDumpOnOutOfMemoryError 8 | -XX:+UseGCOverheadLimit 9 | -XX:+ExitOnOutOfMemoryError 10 | -XX:ReservedCodeCacheSize=256M 11 | -Djdk.attach.allowAttachSelf=true 12 | -Djdk.nio.maxCachedBufferSize=2000000 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Describe the environment** 14 | 15 | Describe which OS you're using, which sqeleton version, and any other information that might be relevant to this bug. 16 | -------------------------------------------------------------------------------- /sqeleton/abcs/compiler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class AbstractCompiler(ABC): 6 | @abstractmethod 7 | def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: 8 | ... 
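# --- Illustrative sketch (hypothetical, not part of the library) ------------
# AbstractCompiler and Compilable (defined just below) form a small
# double-dispatch pair: the compiler's `compile(elem)` walks an expression
# tree, and each Compilable node renders itself via `compile(c)` when handed
# the compiler. A minimal, made-up pairing could look like:
#
#     class Literal(Compilable):
#         def __init__(self, value: int):
#             self.value = value
#
#         def compile(self, c: AbstractCompiler) -> str:
#             return str(self.value)
#
#     class PlainCompiler(AbstractCompiler):
#         def compile(self, elem, params=None) -> str:
#             return elem.compile(self) if isinstance(elem, Compilable) else str(elem)
#
# The real, full-featured implementation is Compiler in sqeleton/queries/compiler.py.
# -----------------------------------------------------------------------------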
9 | 10 | 11 | class Compilable(ABC): 12 | # TODO generic syntax, so we can write Compilable[T] for expressions returning a value of type T 13 | @abstractmethod 14 | def compile(self, c: AbstractCompiler) -> str: 15 | ... 16 | -------------------------------------------------------------------------------- /sqeleton/queries/base.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | from ..abcs import DbPath, DbKey 4 | from ..schema import Schema 5 | 6 | 7 | class _SKIP: 8 | def __repr__(self): 9 | return "SKIP" 10 | 11 | 12 | SKIP = _SKIP() 13 | 14 | 15 | class SqeletonError(Exception): 16 | pass 17 | 18 | 19 | def args_as_tuple(exprs): 20 | if len(exprs) == 1: 21 | (e,) = exprs 22 | if isinstance(e, Generator): 23 | return tuple(e) 24 | return exprs 25 | -------------------------------------------------------------------------------- /tests/waiting_for_stack_up.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -n "$VERTICA_URI" ] 4 | then 5 | echo "Check Vertica DB running..." 6 | while true 7 | do 8 | if docker logs dd-vertica | tail -n 100 | grep -q -i "vertica is now running" 9 | then 10 | echo "Vertica DB is ready"; 11 | break; 12 | else 13 | echo "Waiting for Vertica DB starting..."; 14 | sleep 10; 15 | fi 16 | done 17 | fi 18 | -------------------------------------------------------------------------------- /sqeleton/queries/__init__.py: -------------------------------------------------------------------------------- 1 | from .compiler import Compiler, CompileError 2 | from .api import ( 3 | this, 4 | join, 5 | outerjoin, 6 | table, 7 | SKIP, 8 | sum_, 9 | avg, 10 | min_, 11 | max_, 12 | cte, 13 | commit, 14 | when, 15 | coalesce, 16 | and_, 17 | if_, 18 | or_, 19 | leftjoin, 20 | rightjoin, 21 | current_timestamp, 22 | code, 23 | ) 24 | from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column 25 | from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString 26 | -------------------------------------------------------------------------------- /sqeleton/databases/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database 2 | from ..abcs import DbPath, DbKey, DbTime 3 | from ._connect import Connect 4 | 5 | from .postgresql import PostgreSQL 6 | from .mysql import MySQL 7 | from .oracle import Oracle 8 | from .snowflake import Snowflake 9 | from .bigquery import BigQuery 10 | from .redshift import Redshift 11 | from .presto import Presto 12 | from .databricks import Databricks 13 | from .trino import Trino 14 | from .clickhouse import Clickhouse 15 | from .vertica import Vertica 16 | from .duckdb import DuckDB 17 | 18 | connect = Connect() 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = sqeleton 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /sqeleton/__main__.py: -------------------------------------------------------------------------------- 1 | import click 2 | from .repl import repl as repl_main 3 | 4 | 5 | @click.group(no_args_is_help=True) 6 | def main(): 7 | pass 8 | 9 | 10 | @main.command(no_args_is_help=True) 11 | @click.argument("database", required=True) 12 | def repl(database): 13 | return repl_main(database) 14 | 15 | 16 | CONN_EDITOR_HELP = """CONFIG_PATH - Path to a TOML config file of db connections, new or existing.""" 17 | 18 | 19 | @main.command(no_args_is_help=True, help=CONN_EDITOR_HELP) 20 | @click.argument("config_path", required=True) 21 | def conn_editor(config_path): 22 | from .conn_editor import main 23 | 24 | return main(config_path) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /dev/dev.env: -------------------------------------------------------------------------------- 1 | POSTGRES_USER=postgres 2 | POSTGRES_PASSWORD=Password1 3 | POSTGRES_DB=postgres 4 | 5 | MYSQL_DATABASE=mysql 6 | MYSQL_USER=mysql 7 | MYSQL_PASSWORD=Password1 8 | MYSQL_ROOT_PASSWORD=RootPassword1 9 | 10 | CLICKHOUSE_USER=clickhouse 11 | CLICKHOUSE_PASSWORD=Password1 12 | CLICKHOUSE_DB=clickhouse 13 | CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT=1 14 | 15 | # Vertica credentials 16 | APP_DB_USER=vertica 17 | APP_DB_PASSWORD=Password1 18 | VERTICA_DB_NAME=vertica 19 | 20 | # To prevent generating sample demo VMart data (more about it here https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/GettingStartedGuide/IntroducingVMart/IntroducingVMart.htm), 21 | # leave VMART_DIR and VMART_ETL_SCRIPT empty. 
22 | VMART_DIR= 23 | VMART_ETL_SCRIPT= 24 | -------------------------------------------------------------------------------- /sqeleton/schema.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict 4 | from .abcs import AbstractDatabase, DbPath 5 | 6 | logger = logging.getLogger("schema") 7 | 8 | Schema = CaseAwareMapping 9 | 10 | 11 | def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: 12 | logger.debug(f"[{db.name}] Schema = {schema}") 13 | 14 | if case_sensitive: 15 | return CaseSensitiveDict(schema) 16 | 17 | if len({k.lower() for k in schema}) < len(schema): 18 | logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') 19 | logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).") 20 | return CaseInsensitiveDict(schema) 21 | -------------------------------------------------------------------------------- /dev/Dockerfile.prestosql.340: -------------------------------------------------------------------------------- 1 | FROM openjdk:11-jdk-slim-buster 2 | 3 | ENV PRESTO_VERSION=340 4 | ENV PRESTO_SERVER_URL=https://repo1.maven.org/maven2/io/prestosql/presto-server/${PRESTO_VERSION}/presto-server-${PRESTO_VERSION}.tar.gz 5 | ENV PRESTO_CLI_URL=https://repo1.maven.org/maven2/io/prestosql/presto-cli/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar 6 | ENV PRESTO_HOME=/opt/presto 7 | ENV PATH=${PRESTO_HOME}/bin:${PATH} 8 | 9 | WORKDIR $PRESTO_HOME 10 | 11 | RUN set -xe \ 12 | && apt-get update \ 13 | && apt-get install -y curl less python \ 14 | && curl -sSL $PRESTO_SERVER_URL | tar xz --strip 1 \ 15 | && curl -sSL $PRESTO_CLI_URL > ./bin/presto \ 16 | && chmod +x ./bin/presto \ 17 | && apt-get remove -y curl \ 18 | && rm -rf /var/lib/apt/lists/* 19 | 20 | VOLUME /data 21 | 22 | EXPOSE 8080 23 | 24 | ENTRYPOINT ["launcher"] 25 | CMD ["run"] 26 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=sqeleton 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. 
toctree:: 2 | :maxdepth: 2 3 | :caption: API Reference 4 | :hidden: 5 | 6 | install 7 | intro 8 | supported-databases 9 | conn_editor 10 | python-api 11 | 12 | Sqeleton 13 | --------- 14 | 15 | **Sqeleton** is a Python library for querying SQL databases. 16 | 17 | It consists of - 18 | 19 | - A fast and concise query builder, inspired by PyPika and SQLAlchemy 20 | 21 | - A modular database interface, with drivers for a long list of SQL databases. 22 | 23 | It is comparable to other libraries such as SQLAlchemy or PyPika, in terms of API and intended audience. However there are several notable ways in which it is different. 24 | 25 | For more information, `See our README `_ 26 | 27 | 28 | Resources 29 | --------- 30 | 31 | - :doc:`install` 32 | - :doc:`intro` 33 | - :doc:`supported-databases` 34 | - :doc:`conn_editor` 35 | - :doc:`python-api` 36 | - Source code (git): ``_ 37 | 38 | -------------------------------------------------------------------------------- /sqeleton/databases/mssql.py: -------------------------------------------------------------------------------- 1 | # class MsSQL(ThreadedDatabase): 2 | # "AKA sql-server" 3 | 4 | # def __init__(self, host, port, user, password, *, database, thread_count, **kw): 5 | # args = dict(server=host, port=port, database=database, user=user, password=password, **kw) 6 | # self._args = {k: v for k, v in args.items() if v is not None} 7 | 8 | # super().__init__(thread_count=thread_count) 9 | 10 | # def create_connection(self): 11 | # mssql = import_mssql() 12 | # try: 13 | # return mssql.connect(**self._args) 14 | # except mssql.Error as e: 15 | # raise ConnectError(*e.args) from e 16 | 17 | # def quote(self, s: str): 18 | # return f"[{s}]" 19 | 20 | # def md5_as_int(self, s: str) -> str: 21 | # return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" 22 | # # return f"CONVERT(bigint, (CHECKSUM({s})))" 23 | 24 | # def to_string(self, s: str): 25 | # return f"CONVERT(varchar, {s})" 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Datafold 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/docs/conn_editor.md:
--------------------------------------------------------------------------------
1 | # Connection Editor
2 | 
3 | A common complaint among new users was the difficulty in setting up the connections.
4 | 
5 | Connection URLs are admittedly confusing, and editing `.toml` files isn't always straightforward either.
6 | 
7 | To ease this initial difficulty, we added a `textual`-based TUI tool to sqeleton that allows users to edit configuration files and test the connections while editing them.
8 | 
9 | ## Install
10 | 
11 | This tool needs `textual` to run. You can install it using:
12 | 
13 | ```bash
14 | pip install 'sqeleton[tui]'
15 | ```
16 | 
17 | Make sure you also have drivers for whatever database connection you're going to edit!
18 | 
19 | ## Run
20 | 
21 | Once everything is installed, you can run the editor with the following command:
22 | 
23 | ```bash
24 | sqeleton conn-editor <config_path>
25 | ```
26 | 
27 | Example:
28 | 
29 | ```bash
30 | sqeleton conn-editor ~/dbs.toml
31 | ```
32 | 
33 | The available actions and hotkeys will be listed in the status bar.
34 | 
35 | Note: using the connection editor will delete comments and reformat the file!
36 | 
37 | We recommend backing up the configuration file before editing it.
--------------------------------------------------------------------------------
/docs/python-api.rst:
--------------------------------------------------------------------------------
1 | ********************
2 | Python API Reference
3 | ********************
4 | 
5 | .. py:module:: sqeleton
6 | 
7 | 
8 | User API
9 | ========
10 | 
11 | .. autofunction:: table
12 | 
13 | .. autofunction:: code
14 | 
15 | .. autoclass:: connect
16 |     :members:
17 | 
18 | .. autodata:: SKIP
19 | .. autodata:: this
20 | 
21 | Database
22 | --------
23 | 
24 | .. automodule:: sqeleton.databases.base
25 |     :members:
26 | 
27 | 
28 | Queries
29 | --------
30 | 
31 | .. automodule:: sqeleton.queries.api
32 |     :members:
33 | 
34 | Internals
35 | =========
36 | 
37 | This section is for developers who wish to improve sqeleton, or to extend it within their own project.
38 | 
39 | Regular users might also find it useful for debugging and understanding, especially at this early stage of the project.
40 | 
41 | 
42 | Query ASTs
43 | -----------
44 | 
45 | .. automodule:: sqeleton.queries.ast_classes
46 |     :members:
47 | 
48 | Query Compiler
49 | --------------
50 | 
51 | .. automodule:: sqeleton.queries.compiler
52 |     :members:
53 | 
54 | ABCS
55 | -----
56 | 
57 | .. automodule:: sqeleton.abcs.database_types
58 |     :members:
59 | 
60 | .. 
automodule:: sqeleton.abcs.mixins 61 | :members: 62 | 63 | -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Install / Get started 2 | 3 | Sqeleton can be installed using pip: 4 | 5 | ``` 6 | pip install sqeleton 7 | ``` 8 | 9 | ## Database drivers 10 | 11 | To ensure that the database drivers are compatible with sqeleton, we recommend installing them along with sqeleton, using pip's `[]` syntax: 12 | 13 | - `pip install 'sqeleton[mysql]'` 14 | 15 | - `pip install 'sqeleton[postgresql]'` 16 | 17 | - `pip install 'sqeleton[snowflake]'` 18 | 19 | - `pip install 'sqeleton[presto]'` 20 | 21 | - `pip install 'sqeleton[oracle]'` 22 | 23 | - `pip install 'sqeleton[trino]'` 24 | 25 | - `pip install 'sqeleton[clickhouse]'` 26 | 27 | - `pip install 'sqeleton[vertica]'` 28 | 29 | - For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ 30 | 31 | _Some drivers have dependencies that cannot be installed using `pip` and still need to be installed manually._ 32 | 33 | 34 | It is also possible to install several databases at once. For example: 35 | 36 | ```bash 37 | pip install 'sqeleton[mysql, postgresql]' 38 | ``` 39 | 40 | Note: Some shells use `"` for escaping instead, like: 41 | 42 | ```bash 43 | pip install "sqeleton[mysql, postgresql]" 44 | ``` 45 | 46 | ## Connection editor 47 | 48 | Sqeleton provides a TUI connection editor, that can be installed using: 49 | 50 | ```bash 51 | pip install 'sqeleton[tui]' 52 | ``` 53 | 54 | Read more [here](conn_editor.md). 55 | 56 | ## What's next? 57 | 58 | Read the [introduction](intro.md) and start coding! -------------------------------------------------------------------------------- /sqeleton/databases/trino.py: -------------------------------------------------------------------------------- 1 | from ..abcs.database_types import TemporalType, ColType_UUID 2 | from . 
import presto 3 | from .base import import_helper 4 | from .base import TIMESTAMP_PRECISION_POS 5 | 6 | 7 | @import_helper("trino") 8 | def import_trino(): 9 | import trino 10 | 11 | return trino 12 | 13 | 14 | Mixin_MD5 = presto.Mixin_MD5 15 | 16 | 17 | class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): 18 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 19 | if coltype.rounds: 20 | s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')" 21 | else: 22 | s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" 23 | 24 | return ( 25 | f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" 26 | ) 27 | 28 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 29 | return f"TRIM({value})" 30 | 31 | 32 | class Dialect(presto.Dialect): 33 | name = "Trino" 34 | 35 | 36 | class Trino(presto.Presto): 37 | dialect = Dialect() 38 | CONNECT_URI_HELP = "trino://@//" 39 | CONNECT_URI_PARAMS = ["catalog", "schema"] 40 | 41 | def __init__(self, **kw): 42 | trino = import_trino() 43 | 44 | if kw.get("schema"): 45 | self.default_schema = kw.get("schema") 46 | 47 | self._conn = trino.dbapi.connect(**kw) 48 | -------------------------------------------------------------------------------- /tests/test_mixins.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sqeleton import connect 4 | 5 | from sqeleton.abcs import AbstractDialect, AbstractDatabase 6 | from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue, AbstractMixin_RandomSample, AbstractMixin_TimeTravel 7 | 8 | 9 | class TestMixins(unittest.TestCase): 10 | def test_normalize(self): 11 | # - Test sanity 12 | ddb1 = connect("duckdb://:memory:") 13 | assert not hasattr(ddb1.dialect, "normalize_boolean") 14 | 15 | # - Test abstract mixins 16 | class NewAbstractDialect(AbstractDialect, AbstractMixin_NormalizeValue, AbstractMixin_RandomSample): 17 | pass 18 | 19 | new_connect = connect.load_mixins(AbstractMixin_NormalizeValue, AbstractMixin_RandomSample) 20 | ddb2: AbstractDatabase[NewAbstractDialect] = new_connect("duckdb://:memory:") 21 | # Implementation may change; Just update the test 22 | assert ddb2.dialect.normalize_boolean("bool", None) == "bool::INTEGER::VARCHAR" 23 | assert ddb2.dialect.random_sample_n("x", 10) 24 | 25 | # - Test immutability 26 | ddb3 = connect("duckdb://:memory:") 27 | assert not hasattr(ddb3.dialect, "normalize_boolean") 28 | 29 | self.assertRaises(TypeError, connect.load_mixins, AbstractMixin_TimeTravel) 30 | 31 | new_connect = connect.for_databases("bigquery", "snowflake").load_mixins(AbstractMixin_TimeTravel) 32 | self.assertRaises(NotImplementedError, new_connect, "duckdb://:memory:") 33 | -------------------------------------------------------------------------------- /sqeleton/query_utils.py: -------------------------------------------------------------------------------- 1 | "Module for query utilities that didn't make it into the query-builder (yet)" 2 | 3 | from contextlib import suppress 4 | 5 | from sqeleton.databases import DbPath, QueryError, Oracle 6 | from sqeleton.queries import table, commit, Expr 7 | 8 | 9 | def _drop_table_oracle(name: DbPath): 10 | t = table(name) 11 | # Experience shows double drop is necessary 12 | with suppress(QueryError): 13 | yield t.drop() 14 | yield t.drop() 15 | yield commit 16 | 17 | 18 | def _drop_table(name: DbPath): 19 | t = table(name) 20 | yield 
t.drop(if_exists=True) 21 | yield commit 22 | 23 | 24 | def drop_table(db, tbl): 25 | if isinstance(db, Oracle): 26 | db.query(_drop_table_oracle(tbl)) 27 | else: 28 | db.query(_drop_table(tbl)) 29 | 30 | 31 | def _append_to_table_oracle(path: DbPath, expr: Expr): 32 | """See append_to_table""" 33 | assert expr.schema, expr 34 | t = table(path, schema=expr.schema) 35 | with suppress(QueryError): 36 | yield t.create() # uses expr.schema 37 | yield commit 38 | yield t.insert_expr(expr) 39 | yield commit 40 | 41 | 42 | def _append_to_table(path: DbPath, expr: Expr): 43 | """Append to table""" 44 | assert expr.schema, expr 45 | t = table(path, schema=expr.schema) 46 | yield t.create(if_not_exists=True) # uses expr.schema 47 | yield commit 48 | yield t.insert_expr(expr) 49 | yield commit 50 | 51 | 52 | def append_to_table(db, path, expr): 53 | f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table 54 | db.query(f(path, expr)) 55 | -------------------------------------------------------------------------------- /.github/workflows/ci_full.yml: -------------------------------------------------------------------------------- 1 | name: CI-COVER-DATABASES 2 | 3 | on: 4 | # push: 5 | # paths: 6 | # - '**.py' 7 | # - '.github/workflows/**' 8 | # - '!dev/**' 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | 13 | jobs: 14 | unit_tests: 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest] 19 | python-version: 20 | - "3.10" 21 | 22 | name: Check Python ${{ matrix.python-version }} on ${{ matrix.os }} 23 | runs-on: ${{ matrix.os }} 24 | steps: 25 | - uses: actions/checkout@v3 26 | 27 | - name: Setup Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v3 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Build the stack 33 | run: docker-compose up -d mysql postgres presto trino clickhouse vertica 34 | 35 | - name: Install Poetry 36 | run: pip install poetry 37 | 38 | - name: Install package 39 | run: "poetry install" 40 | 41 | - name: Run unit tests 42 | env: 43 | PRESTO_URI: 'presto://presto@127.0.0.1/postgresql/public' 44 | TRINO_URI: 'trino://postgres@127.0.0.1:8081/postgresql/public' 45 | CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' 46 | VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' 47 | SNOWFLAKE_URI: '${{ secrets.SNOWFLAKE_URI }}' 48 | REDSHIFT_URI: '${{ secrets.REDSHIFT_URI }}' 49 | run: | 50 | chmod +x tests/waiting_for_stack_up.sh 51 | ./tests/waiting_for_stack_up.sh && poetry run unittest-parallel -j 16 52 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI-COVER-VERSIONS 2 | 3 | on: 4 | # push: 5 | # paths: 6 | # - '**.py' 7 | # - '.github/workflows/**' 8 | # - '!dev/**' 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | 13 | jobs: 14 | unit_tests: 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest] 19 | python-version: 20 | - "3.7" 21 | - "3.8" 22 | - "3.9" 23 | - "3.10" 24 | 25 | name: Check Python ${{ matrix.python-version }} on ${{ matrix.os }} 26 | runs-on: ${{ matrix.os }} 27 | steps: 28 | - uses: actions/checkout@v3 29 | 30 | - name: Setup Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v3 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | 35 | - name: Build the stack 36 | run: docker-compose up -d 
mysql postgres presto trino clickhouse vertica 37 | 38 | - name: Install Poetry 39 | run: pip install poetry 40 | 41 | - name: Install package 42 | run: "poetry install" 43 | 44 | - name: Run unit tests 45 | env: 46 | PRESTO_URI: 'presto://presto@127.0.0.1/postgresql/public' 47 | TRINO_URI: 'trino://postgres@127.0.0.1:8081/postgresql/public' 48 | CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' 49 | VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' 50 | SNOWFLAKE_URI: '${{ secrets.SNOWFLAKE_URI }}' 51 | REDSHIFT_URI: '${{ secrets.REDSHIFT_URI }}' 52 | run: | 53 | chmod +x tests/waiting_for_stack_up.sh 54 | ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=0 poetry run unittest-parallel -j 16 55 | -------------------------------------------------------------------------------- /sqeleton/queries/extras.py: -------------------------------------------------------------------------------- 1 | "Useful AST classes that don't quite fall within the scope of regular SQL" 2 | 3 | from typing import Callable, Sequence 4 | from runtype import dataclass 5 | 6 | from ..abcs.database_types import ColType, Native_UUID 7 | 8 | from .compiler import Compiler 9 | from .ast_classes import Expr, ExprNode, Concat, Code 10 | 11 | 12 | @dataclass 13 | class NormalizeAsString(ExprNode): 14 | expr: ExprNode 15 | expr_type: ColType = None 16 | type = str 17 | 18 | def compile(self, c: Compiler) -> str: 19 | expr = c.compile(self.expr) 20 | return c.dialect.normalize_value_by_type(expr, self.expr_type or self.expr.type) 21 | 22 | 23 | @dataclass 24 | class ApplyFuncAndNormalizeAsString(ExprNode): 25 | expr: ExprNode 26 | apply_func: Callable = None 27 | 28 | def compile(self, c: Compiler) -> str: 29 | expr = self.expr 30 | expr_type = expr.type 31 | 32 | if isinstance(expr_type, Native_UUID): 33 | # Normalize first, apply template after (for uuids) 34 | # Needed because min/max(uuid) fails in postgresql 35 | expr = NormalizeAsString(expr, expr_type) 36 | if self.apply_func is not None: 37 | expr = self.apply_func(expr) # Apply template using Python's string formatting 38 | 39 | else: 40 | # Apply template before normalizing (for ints) 41 | if self.apply_func is not None: 42 | expr = self.apply_func(expr) # Apply template using Python's string formatting 43 | expr = NormalizeAsString(expr, expr_type) 44 | 45 | return c.compile(expr) 46 | 47 | 48 | @dataclass 49 | class Checksum(ExprNode): 50 | exprs: Sequence[Expr] 51 | 52 | def compile(self, c: Compiler): 53 | if len(self.exprs) > 1: 54 | exprs = [Code(f"coalesce({c.compile(expr)}, '')") for expr in self.exprs] 55 | # exprs = [c.compile(e) for e in exprs] 56 | expr = Concat(exprs, "|") 57 | else: 58 | # No need to coalesce - safe to assume that key cannot be null 59 | (expr,) = self.exprs 60 | expr = c.compile(expr) 61 | md5 = c.dialect.md5_as_int(expr) 62 | return f"sum({md5})" 63 | -------------------------------------------------------------------------------- /sqeleton/repl.py: -------------------------------------------------------------------------------- 1 | import rich.table 2 | import logging 3 | 4 | # logging.basicConfig(level=logging.DEBUG) 5 | 6 | from . 
import connect 7 | 8 | import sys 9 | 10 | 11 | def print_table(rows, schema, table_name=""): 12 | # Print rows in a rich table 13 | t = rich.table.Table(title=table_name, caption=f"{len(rows)} rows") 14 | for col in schema: 15 | t.add_column(col) 16 | for r in rows: 17 | t.add_row(*map(str, r)) 18 | rich.print(t) 19 | 20 | 21 | def help(): 22 | rich.print("Commands:") 23 | rich.print(" ?mytable - shows schema of table 'mytable'") 24 | rich.print(" * - shows list of all tables") 25 | rich.print(" *pattern - shows list of all tables with name like pattern") 26 | rich.print("Otherwise, runs regular SQL query") 27 | 28 | 29 | def repl(uri): 30 | db = connect(uri) 31 | db_name = db.name 32 | 33 | while True: 34 | try: 35 | q = input(f"{db_name}> ").strip() 36 | except EOFError: 37 | return 38 | if not q: 39 | continue 40 | if q.startswith("*"): 41 | pattern = q[1:] 42 | names = db.query(db.dialect.list_tables(db.default_schema, like=f"%{pattern}%" if pattern else None)) 43 | print_table(names, ["name"], "List of tables") 44 | elif q.startswith("?"): 45 | table_name = q[1:] 46 | if not table_name: 47 | help() 48 | continue 49 | try: 50 | path = db.parse_table_name(table_name) 51 | print("->", path) 52 | schema = db.query_table_schema(path) 53 | except Exception as e: 54 | logging.error(e) 55 | else: 56 | print_table([(k, v[1]) for k, v in schema.items()], ["name", "type"], f"Table '{table_name}'") 57 | else: 58 | # Normal SQL query 59 | try: 60 | res = db.query(q) 61 | except Exception as e: 62 | logging.error(e) 63 | else: 64 | if res: 65 | print_table(res.rows, res.columns, None) 66 | 67 | 68 | def main(): 69 | uri = sys.argv[1] 70 | return repl(uri) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sqeleton" 3 | version = "0.0.8" 4 | description = "Python library for querying SQL databases" 5 | authors = ["Erez Shinan "] 6 | license = "MIT" 7 | readme = "README.md" 8 | repository = "https://github.com/datafold/sqeleton" 9 | documentation = "" 10 | classifiers = [ 11 | "Intended Audience :: Developers", 12 | "Intended Audience :: Information Technology", 13 | "Intended Audience :: System Administrators", 14 | "Programming Language :: Python :: 3.7", 15 | "Programming Language :: Python :: 3.8", 16 | "Programming Language :: Python :: 3.9", 17 | "Programming Language :: Python :: 3.10", 18 | "Development Status :: 2 - Pre-Alpha", 19 | "Environment :: Console", 20 | "Topic :: Database :: Database Engines/Servers", 21 | "Typing :: Typed" 22 | ] 23 | packages = [{ include = "sqeleton" }] 24 | 25 | [tool.poetry.dependencies] 26 | python = "^3.7" 27 | runtype = "^0.2.6" 28 | dsnparse = "*" 29 | click = "^8.1" 30 | rich = "*" 31 | toml = "^0.10.2" 32 | mysql-connector-python = {version="8.0.29", optional=true} 33 | psycopg2 = {version="*", optional=true} 34 | snowflake-connector-python = {version="^2.7.2", optional=true} 35 | cryptography = {version="*", optional=true} 36 | trino = {version="^0.314.0", optional=true} 37 | presto-python-client = {version="*", optional=true} 38 | clickhouse-driver = {version="*", optional=true} 39 | duckdb = {version="^0.7.0", optional=true} 40 | textual = {version="^0.9.1", optional=true} 41 | textual-select = {version="*", optional=true} 42 | 43 | [tool.poetry.dev-dependencies] 44 | parameterized = "*" 45 | unittest-parallel = "*" 46 | 47 | duckdb = 
"^0.7.0" 48 | mysql-connector-python = "*" 49 | psycopg2 = "*" 50 | snowflake-connector-python = "^2.7.2" 51 | cryptography = "*" 52 | trino = "^0.314.0" 53 | presto-python-client = "*" 54 | clickhouse-driver = "*" 55 | vertica-python = "*" 56 | 57 | [tool.poetry.extras] 58 | mysql = ["mysql-connector-python"] 59 | postgresql = ["psycopg2"] 60 | snowflake = ["snowflake-connector-python", "cryptography"] 61 | presto = ["presto-python-client"] 62 | oracle = ["cx_Oracle"] 63 | # databricks = ["databricks-sql-connector"] 64 | trino = ["trino"] 65 | clickhouse = ["clickhouse-driver"] 66 | vertica = ["vertica-python"] 67 | duckdb = ["duckdb"] 68 | tui = ["textual", "textual-select"] 69 | 70 | [build-system] 71 | requires = ["poetry-core>=1.0.0"] 72 | build-backend = "poetry.core.masonry.api" 73 | 74 | [tool.poetry.scripts] 75 | sqeleton = 'sqeleton.__main__:main' 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mac 132 | .DS_Store 133 | 134 | # IntelliJ 135 | .idea 136 | 137 | # VSCode 138 | .vscode 139 | -------------------------------------------------------------------------------- /sqeleton/bound_exprs.py: -------------------------------------------------------------------------------- 1 | """Expressions bound to a specific database""" 2 | 3 | import inspect 4 | from functools import wraps 5 | from typing import Union, TYPE_CHECKING 6 | 7 | from runtype import dataclass 8 | 9 | from .abcs import AbstractDatabase, AbstractCompiler 10 | from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable 11 | from .queries.api import table 12 | from .schema import create_schema 13 | 14 | 15 | @dataclass 16 | class BoundNode(ExprNode): 17 | database: AbstractDatabase 18 | node: Compilable 19 | 20 | def __getattr__(self, attr): 21 | value = getattr(self.node, attr) 22 | if inspect.ismethod(value): 23 | 24 | @wraps(value) 25 | def bound_method(*args, **kw): 26 | return BoundNode(self.database, value(*args, **kw)) 27 | 28 | return bound_method 29 | return value 30 | 31 | def query(self, res_type=list): 32 | return self.database.query(self.node, res_type=res_type) 33 | 34 | @property 35 | def type(self): 36 | return self.node.type 37 | 38 | def compile(self, c: AbstractCompiler) -> str: 39 | assert c.database is self.database 40 | return self.node.compile(c) 41 | 42 | 43 | def bind_node(node, database): 44 | return BoundNode(database, node) 45 | 46 | 47 | ExprNode.bind = bind_node 48 | 49 | 50 | @dataclass 51 | class BoundTable(BoundNode): # ITable 52 | database: AbstractDatabase 53 | node: TablePath 54 | 55 | def with_schema(self, schema): 56 | table_path = self.node.replace(schema=schema) 57 | return self.replace(node=table_path) 58 | 59 | def query_schema(self, *, columns=None, where=None, case_sensitive=True): 60 | table_path = self.node 61 | 62 | if table_path.schema: 63 | return self 64 | 65 | raw_schema = self.database.query_table_schema(table_path.path) 66 | schema = self.database._process_table_schema(table_path.path, raw_schema, columns, where) 67 | schema = create_schema(self.database, table_path, schema, case_sensitive) 68 | return self.with_schema(schema) 69 | 70 | @property 71 | def schema(self): 72 | return self.node.schema 73 | 74 | 75 | def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tuple], **kw): 76 | return BoundTable(database, table(table_path, **kw)) 77 | 78 | 79 | # Database.table = bound_table 80 | 81 | # def test(): 82 | # from . 
import connect 83 | # from .queries.api import table 84 | # d = connect("mysql://erez:qweqwe123@localhost/erez") 85 | # t = table(('Rating',)) 86 | 87 | # b = BoundTable(d, t) 88 | # b2 = b.with_schema() 89 | 90 | # breakpoint() 91 | 92 | # test() 93 | 94 | if TYPE_CHECKING: 95 | 96 | class BoundTable(BoundTable, TablePath): 97 | pass 98 | -------------------------------------------------------------------------------- /sqeleton/queries/compiler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from datetime import datetime 3 | from typing import Any, Dict, Sequence, List 4 | 5 | from runtype import dataclass 6 | 7 | from ..utils import ArithString 8 | from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable 9 | 10 | import contextvars 11 | 12 | cv_params = contextvars.ContextVar("params") 13 | 14 | 15 | class CompileError(Exception): 16 | pass 17 | 18 | 19 | class Root: 20 | "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" 21 | 22 | 23 | @dataclass 24 | class Compiler(AbstractCompiler): 25 | database: AbstractDatabase 26 | params: dict = {} 27 | in_select: bool = False # Compilation runtime flag 28 | in_join: bool = False # Compilation runtime flag 29 | 30 | _table_context: List = [] # List[ITable] 31 | _subqueries: Dict[str, Any] = {} # XXX not thread-safe 32 | root: bool = True 33 | 34 | _counter: List = [0] 35 | 36 | @property 37 | def dialect(self) -> AbstractDialect: 38 | return self.database.dialect 39 | 40 | def compile(self, elem, params=None) -> str: 41 | if params: 42 | cv_params.set(params) 43 | 44 | if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root): 45 | from .ast_classes import Select 46 | 47 | elem = Select(columns=[elem]) 48 | 49 | res = self._compile(elem) 50 | if self.root and self._subqueries: 51 | subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) 52 | self._subqueries.clear() 53 | return f"WITH {subq}\n{res}" 54 | return res 55 | 56 | def _compile(self, elem) -> str: 57 | if elem is None: 58 | return "NULL" 59 | elif isinstance(elem, Compilable): 60 | return elem.compile(self.replace(root=False)) 61 | elif isinstance(elem, str): 62 | return f"'{elem}'" 63 | elif isinstance(elem, (int, float)): 64 | return str(elem) 65 | elif isinstance(elem, datetime): 66 | return self.dialect.timestamp_value(elem) 67 | elif isinstance(elem, bytes): 68 | return f"b'{elem.decode()}'" 69 | elif isinstance(elem, ArithString): 70 | return f"'{elem}'" 71 | assert False, elem 72 | 73 | def new_unique_name(self, prefix="tmp"): 74 | self._counter[0] += 1 75 | return f"{prefix}{self._counter[0]}" 76 | 77 | def new_unique_table_name(self, prefix="tmp") -> DbPath: 78 | self._counter[0] += 1 79 | return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") 80 | 81 | def add_table_context(self, *tables: Sequence, **kw): 82 | return self.replace(_table_context=self._table_context + list(tables), **kw) 83 | 84 | def quote(self, s: str): 85 | return self.dialect.quote(s) 86 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/combined.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIERDCCAyygAwIBAgIUBxO/CflDP+0yZAXt5FKm/WQ4538wDQYJKoZIhvcNAQEL 3 | BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM 4 | 
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X 5 | DTIyMDgyNDA4NTI1N1oXDTMyMDgyMTA4NTI1N1owWTELMAkGA1UEBhMCQVUxEzAR 6 | BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 7 | IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A 8 | MIIBCgKCAQEA3lNywkj/eGPGoFA3Lcx++98l17CRy+uzZMtJsr6lYAg1p/n1vPw0 9 | BQXI5TSBJ6vM/axtwgwrfXQsjQ/GYJKQkb6eEBCc3xb+Rk5HNBiBBZsIjYm0U1zz 10 | 7dKnNwAznjx3j72s2ZQiqkoxcu7Bctw28ynbg0rjNkuUk3QESKuOgaTltpWKZiiu 11 | XwWasREeH6MH7ROy8db6cz+MwGaig0mUvGPmD97bPRD/X683RyOiXzEaogl/rpGK 12 | qZ3jRsmS8ZwawzKxx16kqPsX8/01EruGIoubMttr3YoZG044zq7nQqdAAz6wXx6V 13 | mgzToCHI+/g+8JS/bgqJTyb2Y6aGXExiuQIDAQABo4IBAjCB/zAdBgNVHQ4EFgQU 14 | 5i1F8pTnwjFxw6W/0RjwpaJaK9MwgZYGA1UdIwSBjjCBi4AU5i1F8pTnwjFxw6W/ 15 | 0RjwpaJaK9OhXaRbMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRl 16 | MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxv 17 | Y2FsaG9zdIIUBxO/CflDP+0yZAXt5FKm/WQ4538wDAYDVR0TBAUwAwEB/zALBgNV 18 | HQ8EBAMCAvwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MBQGA1UdEgQNMAuCCWxvY2Fs 19 | aG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAjBQLl/UFSd9TH2VLM1GH8bixtEJ9+rm7 20 | x+Jw665+XLjW107dJ33qxy9zjd3cZ2fynKg2Tb7+9QAvSlqpt2YMGP9jr4W2w16u 21 | ngbNB+kfoOotcUChk90aHmdHLZgOOve/ArFIvbr8douLOn0NAJBrj+iX4zC1pgEC 22 | 9hsMUekkAPIcCGc0rEkEc8r8uiUBWNAdEWpBt0X2fE1ownLuB/E/+3HutLTw8Lv0 23 | b+jNt/vogVixcw/FF4atoO+F7S5FYzAb0U7YXaNISfVPVBsA89oPy7PlxULHDUIF 24 | Iq+vVqKdj1EXR+Iec0TMiMsa3MnIGkpL7ZuUXaG+xGBaVhGrUp67lQ== 25 | -----END CERTIFICATE----- 26 | -----BEGIN RSA PRIVATE KEY----- 27 | MIIEogIBAAKCAQEA3lNywkj/eGPGoFA3Lcx++98l17CRy+uzZMtJsr6lYAg1p/n1 28 | vPw0BQXI5TSBJ6vM/axtwgwrfXQsjQ/GYJKQkb6eEBCc3xb+Rk5HNBiBBZsIjYm0 29 | U1zz7dKnNwAznjx3j72s2ZQiqkoxcu7Bctw28ynbg0rjNkuUk3QESKuOgaTltpWK 30 | ZiiuXwWasREeH6MH7ROy8db6cz+MwGaig0mUvGPmD97bPRD/X683RyOiXzEaogl/ 31 | rpGKqZ3jRsmS8ZwawzKxx16kqPsX8/01EruGIoubMttr3YoZG044zq7nQqdAAz6w 32 | Xx6VmgzToCHI+/g+8JS/bgqJTyb2Y6aGXExiuQIDAQABAoIBAD3pKwnjXhDeaA94 33 | hwUf7zSgfV9E8jTBHCGzYoB+CntljduLBd1stee4Jqt9JYIwm1MA00e4L9wtn8Jg 34 | ZDO8XLnZRRbgKW8ObhyR684cDMHM3GLdt/OG7P6LLLlqOvWTjQ/gF+Q3FjgplP+W 35 | cRRVMpAgVdqH3iHehi9RnWfHLlX3WkBC97SumWFzWqBnqUQAC8AvFHiUCqA9qIeA 36 | 8ieOEoE17yv2nkmu+A5OZoCXtVfc2lQ90Fj9QZiv4rIVXBtTRRURJuvi2iX3nOPl 37 | MsjAUIBK1ndzpJ7wuLICSR1U3/npPC6Va06lTm0H/Q6DEqZjEHbx9TGY3pTgVXuA 38 | +G0C5GkCgYEA/spvoDUMZrH2JE43TT/qMEHPtHW4qT6fTmzu08Gx++8nFmddNgSD 39 | zrdjxqPMUGV7Q5smwoHaQyqFxHMM2jh8icnV6VoBDrDdZM0eGFAs6HUjKmyaAdQO 40 | dC4kPiy3LX5pJUnQnmwq1fVsgXWGQF/LhD0L/y6xOiqdhZp/8nv6SFMCgYEA32GR 41 | gWJQdgWXTXxSEDn0twKPevgAT0s778/7h5osCLG82Q7ab2+Fc1qTleiwiQ2SAuOl 42 | mWvtz0Eg4dYe/q6jugqkEgAYZvdHGL7CSmC28O98fTLapgKQC5GUUan90sCbRec4 43 | kjbyx5scICNBYJVchdFg6UUSNz5czORUVgQEF0MCgYB1toUX2Spfj7yOTWyTTgIe 44 | RWl2kCS+XGYxT3aPcp+OK5E9cofH2xIiQOvh6+8K/beTJm0j0+ZIva6LcjPv5cTz 45 | y8H+S0zNwrymQ3Wx+eilhOi4QvBsA9KhrmekKfh/FjXxukadyo+HxhlZPjjGKPvX 46 | nnSacrICk4mvHhAasViSbQKBgD7mZGiAXJO/I0moVhtHlobp66j+qGerkacHc5ZN 47 | bVTNZ5XfPtbeGj/PI3u01/Dfp1u06m53G7GebznoZzXjyyqZ0HVZHYXw304yeNck 48 | wJ67cNx4M2VHl3QKfC86pMRxg8d9Qkq5ukdGf/b0tnYR2Mm9mYJV9rkjkFIJgU3v 49 | N4+tAoGABOlVGuRx2cSQ9QeC0AcqKlxXygdrzyadA7i0KNBZGGyrMSpJDrl2rrRn 50 | ylzAgGjvfilwQzZuqTm6Vo2yvaX+TTGS44B+DnxCZvuviftea++sNMjuEkBLTCpF 51 | xk2yOzsOnx652kWO4L+dVrDAxl65f3v0YaKWZI504LFYl18uS/E= 52 | -----END RSA PRIVATE KEY----- 53 | -------------------------------------------------------------------------------- /docs/supported-databases.md: -------------------------------------------------------------------------------- 1 | # List of supported databases 2 | 3 | | Database | Status | 
Connection string |
4 | |---------------|--------|-------------------|
5 | | PostgreSQL >=10 | 💚 | `postgresql://<user>:<password>@<host>:5432/<database>` |
6 | | MySQL | 💚 | `mysql://<user>:<password>@<host>:3306/<database>` |
7 | | Snowflake | 💚 | `"snowflake://<user>[:<password>]@<account>/<database>/<schema>?warehouse=<warehouse>&role=<role>[&authenticator=externalbrowser]"` |
8 | | BigQuery | 💚 | `bigquery://<project>/<dataset>` |
9 | | Redshift | 💚 | `redshift://<user>:<password>@<host>:5439/<database>` |
10 | | Oracle | 💛 | `oracle://<user>:<password>@<host>/database` |
11 | | Presto | 💛 | `presto://<user>:<password>@<host>:8080/<database>` |
12 | | Databricks | 💛 | `databricks://<http_path>:<access_token>@<server_hostname>/<catalog>/<schema>` |
13 | | Trino | 💛 | `trino://<user>:<password>@<host>:8080/<database>` |
14 | | Clickhouse | 💛 | `clickhouse://<user>:<password>@<host>:9000/<database>` |
15 | | Vertica | 💛 | `vertica://<user>:<password>@<host>:5433/<database>` |
16 | | DuckDB | 💛 | |
17 | | ElasticSearch | 📝 | |
18 | | Planetscale | 📝 | |
19 | | Pinot | 📝 | |
20 | | Druid | 📝 | |
21 | | Kafka | 📝 | |
22 | | SQLite | 📝 | |
23 | 
24 | * 💚: Implemented and thoroughly tested.
25 | * 💛: Implemented, but not thoroughly tested yet.
26 | * ⏳: Implementation in progress.
27 | * 📝: Implementation planned. Contributions welcome.
28 | 
29 | Is your database not listed here? We accept pull-requests!
30 | 
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.8"
2 | 
3 | services:
4 |   postgres:
5 |     container_name: dd-postgresql
6 |     image: postgres:14.1-alpine
7 |     # work_mem: less tmp files
8 |     # maintenance_work_mem: improve table-level op perf
9 |     # max_wal_size: allow more time before merging to heap
10 |     command: >
11 |       -c work_mem=1GB
12 |       -c maintenance_work_mem=1GB
13 |       -c max_wal_size=8GB
14 |     restart: always
15 |     volumes:
16 |       - postgresql-data:/var/lib/postgresql/data:delegated
17 |     ports:
18 |       - '5432:5432'
19 |     expose:
20 |       - '5432'
21 |     env_file:
22 |       - dev/dev.env
23 |     tty: true
24 |     networks:
25 |       - local
26 | 
27 |   mysql:
28 |     container_name: dd-mysql
29 |     image: mysql:oracle
30 |     # fsync less aggressively for insertion perf for test setup
31 |     command: >
32 |       --default-authentication-plugin=mysql_native_password
33 |       --binlog-cache-size=16M
34 |       --key_buffer_size=0
35 |       --max_connections=1000
36 |       --innodb_flush_log_at_trx_commit=2
37 |       --innodb_flush_log_at_timeout=10
38 |       --innodb_log_compressed_pages=OFF
39 |       --sync_binlog=0
40 |     restart: always
41 |     volumes:
42 |       - mysql-data:/var/lib/mysql:delegated
43 |     user: mysql
44 |     ports:
45 |       - '3306:3306'
46 |     expose:
47 |       - '3306'
48 |     env_file:
49 |       - dev/dev.env
50 |     tty: true
51 |     networks:
52 |       - local
53 | 
54 |   clickhouse:
55 |     container_name: dd-clickhouse
56 |     image: clickhouse/clickhouse-server:21.12.3.32
57 |     restart: always
58 |     volumes:
59 |       - clickhouse-data:/var/lib/clickhouse:delegated
60 |     ulimits:
61 |       nproc: 65535
62 |       nofile:
63 |         soft: 262144
64 |         hard: 262144
65 |     ports:
66 |       - '8123:8123'
67 |       - '9000:9000'
68 |     expose:
69 |       - '8123'
70 |       - '9000'
71 |     env_file:
72 |       - dev/dev.env
73 |     tty: true
74 |     networks:
75 |       - local
76 | 
77 |   # prestodb.dbapi.connect(host="127.0.0.1", user="presto").cursor().execute('SELECT * FROM system.runtime.nodes')
78 |   presto:
79 |     container_name: dd-presto
80 |     build:
81 |       context: ./dev
82 |       dockerfile: ./Dockerfile.prestosql.340
83 |     volumes:
84 |       - ./dev/presto-conf/standalone:/opt/presto/etc:ro
85 |     ports:
86 |       - '8080:8080'
87 |     tty: true
88 |     networks:
89 |       - local
90 | 
91 |   trino:
92 |     container_name: dd-trino
93 |     image: 'trinodb/trino:389'
94 |     hostname: trino
95 |     ports:
96 |       - '8081:8080'
97 | volumes: 98 | - ./dev/trino-conf/etc:/etc/trino:ro 99 | networks: 100 | - local 101 | 102 | vertica: 103 | container_name: dd-vertica 104 | image: vertica/vertica-ce:12.0.0-0 105 | restart: always 106 | volumes: 107 | - vertica-data:/data:delegated 108 | ports: 109 | - '5433:5433' 110 | - '5444:5444' 111 | expose: 112 | - '5433' 113 | - '5444' 114 | env_file: 115 | - dev/dev.env 116 | tty: true 117 | networks: 118 | - local 119 | 120 | 121 | 122 | volumes: 123 | postgresql-data: 124 | mysql-data: 125 | clickhouse-data: 126 | vertica-data: 127 | 128 | networks: 129 | local: 130 | driver: bridge 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **THE PROJECT IS MERGED INTO [data-diff](https://github.com/datafold/data-diff) ON 2023-04-20 AND IS MAINTAINED THERE.** 2 | 3 | **THIS REPO IS FOR HISTORICAL REFERENCE ONLY.** 4 | 5 | # Sqeleton 6 | 7 | **Under construction!** 8 | 9 | Sqeleton is a Python library for querying SQL databases. 10 | 11 | It consists of: 12 | 13 | - A fast and concise query builder, inspired by PyPika and SQLAlchemy 14 | 15 | - A modular database interface, with drivers for a long list of SQL databases. 16 | 17 | It is comparable to other libraries such as SQLAlchemy or PyPika, in terms of API and intended audience. However, there are several notable ways in which it differs. 18 | 19 | ## Overview 20 | 21 | ### Built for performance 22 | 23 | - Multi-threaded by default - 24 | The same connection object can be used from multiple threads without any additional setup. 25 | 26 | - No ORM 27 | ORMs are easy and familiar, but they encourage bad and slow code. Sqeleton is designed to push the compute to SQL. 28 | 29 | - Fast query-builder 30 | Sqeleton's query-builder runs about 4 times faster than SQLAlchemy's. 31 | 32 | ### Type-aware 33 | 34 | Sqeleton has a built-in feature to query the schemas of the databases it supports. 35 | 36 | This feature can also be used to inform the query-builder, either as an alternative to defining the tables yourself, or to validate that your definitions match the actual schema. 37 | 38 | The schema is used for validation when building expressions, making sure that the names are correct and that the data-types align. 39 | 40 | (Still WIP) 41 | 42 | ### Multi-database access 43 | 44 | Sqeleton is designed to work with several databases at the same time. Its API abstracts away as many implementation details as possible. 45 | 46 | Databases we fully support: 47 | 48 | - PostgreSQL >=10 49 | - MySQL 50 | - Snowflake 51 | - BigQuery 52 | - Redshift 53 | - Oracle 54 | - Presto 55 | - Databricks 56 | - Trino 57 | - Clickhouse 58 | - Vertica 59 | - DuckDB >=0.6 60 | - SQLite (coming soon) 61 | 62 | ## Documentation 63 | 64 | [Read the docs!](https://sqeleton.readthedocs.io) 65 | 66 | Or jump straight to the [introduction](https://sqeleton.readthedocs.io/en/latest/intro.html).
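As a quick taste of the type-aware workflow described above, here is a minimal sketch. It only reuses calls that appear elsewhere in this repo (`connect`, `table`, `this`, and `Database.table(...).query_schema()` from the test suite); the table and column names are made up for illustration, and since schema validation is still WIP, treat this as a sketch rather than a reference:

```python
from typing import List, Tuple

from sqeleton import connect, table, this

db = connect("duckdb://:memory:")

# Define and create a table with an explicit schema
people = table('people', schema={'name': str, 'age': int})
db.query(people.create())
db.query(people.insert_rows([f"user{i}", 20 + i] for i in range(10)))

# Query the actual schema back from the database, and use it to
# build a schema-aware table expression
queried = db.table('people').query_schema()
people = table('people', schema=queried.schema)

# Names used below are checked against the queried schema while building the query
rows = db.query(people.select(this.name).where(this.age > 25), List[Tuple])
print(rows)
```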
67 | 68 | ### Install 69 | 70 | Install using pip: 71 | 72 | ```bash 73 | pip install sqeleton 74 | ``` 75 | 76 | It is recommended to install the driver dependencies using pip's `[]` syntax: 77 | 78 | ```bash 79 | pip install 'sqeleton[mysql, postgresql]' 80 | ``` 81 | 82 | Read more in [install / getting started.](https://sqeleton.readthedocs.io/en/latest/install.html) 83 | 84 | ### Basic usage 85 | 86 | ```python 87 | from sqeleton import connect, table, this 88 | 89 | # Create a new database connection 90 | ddb = connect("duckdb://:memory:") 91 | 92 | # Define a table with one int column 93 | tbl = table('my_list', schema={'item': int}) 94 | 95 | # Make a bunch of queries 96 | queries = [ 97 | # Create table 'my_list' 98 | tbl.create(), 99 | 100 | # Insert 100 numbers 101 | tbl.insert_rows([x] for x in range(100)), 102 | 103 | # Get the sum of the numbers 104 | tbl.select(this.item.sum()) 105 | ] 106 | # Query in order, and return the last result as an int 107 | result = ddb.query(queries, int) 108 | 109 | # Prints: Total sum of 0..100 = 4950 110 | print(f"Total sum of 0..100 = {result}") 111 | ``` 112 | 113 | 114 | # TODO 115 | 116 | - Transactions 117 | 118 | - Indexes 119 | 120 | - Date/time expressions 121 | 122 | - Window functions 123 | 124 | ## Possible plans for the future (not determined yet) 125 | 126 | - Cache the compilation of repetitive queries for even faster query-building 127 | 128 | - Compile control flow, functions 129 | 130 | - Define tables using type-annotated classes (SQLModel style) 131 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human, WeakCache 4 | 5 | 6 | class TestUtils(unittest.TestCase): 7 | def test_remove_passwords_in_dict(self): 8 | # Test replacing password value 9 | d = {"password": "mypassword"} 10 | remove_passwords_in_dict(d) 11 | assert d["password"] == "***" 12 | 13 | # Test replacing password in database URL 14 | d = {"database_url": "mysql://user:mypassword@localhost/db"} 15 | remove_passwords_in_dict(d, "$$$$") 16 | assert d["database_url"] == "mysql://user:$$$$@localhost/db" 17 | 18 | # Test replacing password in nested dictionary 19 | d = {"info": {"password": "mypassword"}} 20 | remove_passwords_in_dict(d, "%%") 21 | assert d["info"]["password"] == "%%" 22 | 23 | def test_match_regexps(self): 24 | def only_results(x): 25 | return [v for k, v in x] 26 | 27 | # Test with no matches 28 | regexps = {"a*": 1, "b*": 2} 29 | s = "c" 30 | assert only_results(match_regexps(regexps, s)) == [] 31 | 32 | # Test with one match 33 | regexps = {"a*": 1, "b*": 2} 34 | s = "b" 35 | assert only_results(match_regexps(regexps, s)) == [2] 36 | 37 | # Test with multiple matches 38 | regexps = {"abc": 1, "ab*c": 2, "c*": 3} 39 | s = "abc" 40 | assert only_results(match_regexps(regexps, s)) == [1, 2] 41 | 42 | # Test with regexp that doesn't match the end of the string 43 | regexps = {"a*b": 1} 44 | s = "acb" 45 | assert only_results(match_regexps(regexps, s)) == [] 46 | 47 | def test_match_like(self): 48 | strs = ["abc", "abcd", "ab", "bcd", "def"] 49 | 50 | # Test exact match 51 | pattern = "abc" 52 | result = list(match_like(pattern, strs)) 53 | assert result == ["abc"] 54 | 55 | # Test % match 56 | pattern = "a%" 57 | result = list(match_like(pattern, strs)) 58 | self.assertEqual(result, ["abc", "abcd", 
"ab"]) 59 | 60 | # Test ? match 61 | pattern = "a?c" 62 | result = list(match_like(pattern, strs)) 63 | self.assertEqual(result, ["abc"]) 64 | 65 | def test_number_to_human(self): 66 | # Test basic conversion 67 | assert number_to_human(1000) == "1k" 68 | assert number_to_human(1000000) == "1m" 69 | assert number_to_human(1000000000) == "1b" 70 | 71 | # Test decimal values 72 | assert number_to_human(1234) == "1k" 73 | assert number_to_human(12345) == "12k" 74 | assert number_to_human(123456) == "123k" 75 | assert number_to_human(1234567) == "1m" 76 | assert number_to_human(12345678) == "12m" 77 | assert number_to_human(123456789) == "123m" 78 | assert number_to_human(1234567890) == "1b" 79 | 80 | # Test negative values 81 | assert number_to_human(-1000) == "-1k" 82 | assert number_to_human(-1000000) == "-1m" 83 | assert number_to_human(-1000000000) == "-1b" 84 | 85 | def test_weak_cache(self): 86 | # Create cache 87 | cache = WeakCache() 88 | 89 | # Test adding and retrieving basic value 90 | o = {1, 2} 91 | cache.add("key", o) 92 | assert cache.get("key") is o 93 | 94 | # Test adding and retrieving dict value 95 | cache.add({"key": "value"}, o) 96 | assert cache.get({"key": "value"}) is o 97 | 98 | # Test deleting value when reference is lost 99 | del o 100 | try: 101 | cache.get({"key": "value"}) 102 | assert False, "KeyError should have been raised" 103 | except KeyError: 104 | pass 105 | -------------------------------------------------------------------------------- /tests/test_sql.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .common import TEST_MYSQL_CONN_STRING 4 | 5 | from sqeleton import connect 6 | from sqeleton.queries import Compiler, Count, Explain, Select, table, In, BinOp, Code 7 | 8 | 9 | class TestSQL(unittest.TestCase): 10 | def setUp(self): 11 | self.mysql = connect(TEST_MYSQL_CONN_STRING) 12 | self.compiler = Compiler(self.mysql) 13 | 14 | def test_compile_string(self): 15 | self.assertEqual("SELECT 1", self.compiler.compile(Code("SELECT 1"))) 16 | 17 | def test_compile_int(self): 18 | self.assertEqual("1", self.compiler.compile(1)) 19 | 20 | def test_compile_table_name(self): 21 | self.assertEqual( 22 | "`marine_mammals`.`walrus`", self.compiler.replace(root=False).compile(table("marine_mammals", "walrus")) 23 | ) 24 | 25 | def test_compile_select(self): 26 | expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" 27 | self.assertEqual( 28 | expected_sql, 29 | self.compiler.compile( 30 | Select( 31 | table("marine_mammals", "walrus"), 32 | [Code("name")], 33 | ) 34 | ), 35 | ) 36 | 37 | # def test_enum(self): 38 | # expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" 39 | # self.assertEqual( 40 | # expected_sql, 41 | # self.compiler.compile( 42 | # Enum( 43 | # ("walrus",), 44 | # "id", 45 | # ) 46 | # ), 47 | # ) 48 | 49 | # def test_checksum(self): 50 | # expected_sql = "SELECT name, sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) FROM `marine_mammals`.`walrus`" 51 | # self.assertEqual( 52 | # expected_sql, 53 | # self.compiler.compile( 54 | # Select( 55 | # ["name", Checksum(["id", "timestamp"])], 56 | # TableName(("marine_mammals", "walrus")), 57 | # ) 58 | # ), 59 | # ) 60 | 61 | def test_compare(self): 62 | expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id <= 1000) AND (id > 1)" 63 | self.assertEqual( 64 | expected_sql, 65 | self.compiler.compile( 66 | Select( 67 
| table("marine_mammals", "walrus"), 68 | [Code("name")], 69 | [BinOp("<=", [Code("id"), Code("1000")]), BinOp(">", [Code("id"), Code("1")])], 70 | ) 71 | ), 72 | ) 73 | 74 | def test_in(self): 75 | expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 76 | self.assertEqual( 77 | expected_sql, 78 | self.compiler.compile( 79 | Select(table("marine_mammals", "walrus"), [Code("name")], [In(Code("id"), [1, 2, 3])]) 80 | ), 81 | ) 82 | 83 | def test_count(self): 84 | expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 85 | self.assertEqual( 86 | expected_sql, 87 | self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In(Code("id"), [1, 2, 3])])), 88 | ) 89 | 90 | def test_count_with_column(self): 91 | expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 92 | self.assertEqual( 93 | expected_sql, 94 | self.compiler.compile( 95 | Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])]) 96 | ), 97 | ) 98 | 99 | def test_explain(self): 100 | expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 101 | self.assertEqual( 102 | expected_sql, 103 | self.compiler.compile( 104 | Explain(Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])])) 105 | ), 106 | ) 107 | -------------------------------------------------------------------------------- /sqeleton/databases/mysql.py: -------------------------------------------------------------------------------- 1 | from ..abcs.database_types import ( 2 | Datetime, 3 | Timestamp, 4 | Float, 5 | Decimal, 6 | Integer, 7 | Text, 8 | TemporalType, 9 | FractionalType, 10 | ColType_UUID, 11 | Boolean, 12 | Date, 13 | ) 14 | from ..abcs.mixins import ( 15 | AbstractMixin_MD5, 16 | AbstractMixin_NormalizeValue, 17 | AbstractMixin_Regex, 18 | AbstractMixin_RandomSample, 19 | ) 20 | from .base import Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable 21 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema, Mixin_RandomSample 22 | from ..queries.ast_classes import BinBoolOp 23 | 24 | 25 | @import_helper("mysql") 26 | def import_mysql(): 27 | import mysql.connector 28 | 29 | return mysql.connector 30 | 31 | 32 | class Mixin_MD5(AbstractMixin_MD5): 33 | def md5_as_int(self, s: str) -> str: 34 | return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" 35 | 36 | 37 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 38 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 39 | if coltype.rounds: 40 | return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") 41 | 42 | s = self.to_string(f"cast({value} as datetime(6))") 43 | return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" 44 | 45 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 46 | return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") 47 | 48 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 49 | return f"TRIM(CAST({value} AS char))" 50 | 51 | 52 | class Mixin_Regex(AbstractMixin_Regex): 53 | def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: 54 | return BinBoolOp("REGEXP", [string, pattern]) 55 | 56 | 57 | class Dialect(BaseDialect, 
Mixin_Schema, Mixin_OptimizerHints): 58 | name = "MySQL" 59 | ROUNDS_ON_PREC_LOSS = True 60 | SUPPORTS_PRIMARY_KEY = True 61 | SUPPORTS_INDEXES = True 62 | TYPE_CLASSES = { 63 | # Dates 64 | "datetime": Datetime, 65 | "timestamp": Timestamp, 66 | "date": Date, 67 | # Numbers 68 | "double": Float, 69 | "float": Float, 70 | "decimal": Decimal, 71 | "int": Integer, 72 | "bigint": Integer, 73 | "smallint": Integer, 74 | "tinyint": Integer, 75 | # Text 76 | "varchar": Text, 77 | "char": Text, 78 | "varbinary": Text, 79 | "binary": Text, 80 | "text": Text, 81 | "mediumtext": Text, 82 | "longtext": Text, 83 | "tinytext": Text, 84 | # Boolean 85 | "boolean": Boolean, 86 | } 87 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 88 | 89 | def quote(self, s: str): 90 | return f"`{s}`" 91 | 92 | def to_string(self, s: str): 93 | return f"cast({s} as char)" 94 | 95 | def is_distinct_from(self, a: str, b: str) -> str: 96 | return f"not ({a} <=> {b})" 97 | 98 | def random(self) -> str: 99 | return "RAND()" 100 | 101 | def type_repr(self, t) -> str: 102 | try: 103 | return { 104 | str: "VARCHAR(1024)", 105 | }[t] 106 | except KeyError: 107 | return super().type_repr(t) 108 | 109 | def explain_as_text(self, query: str) -> str: 110 | return f"EXPLAIN FORMAT=TREE {query}" 111 | 112 | def optimizer_hints(self, s: str): 113 | return f"/*+ {s} */ " 114 | 115 | def set_timezone_to_utc(self) -> str: 116 | return "SET @@session.time_zone='+00:00'" 117 | 118 | 119 | class MySQL(ThreadedDatabase): 120 | dialect = Dialect() 121 | SUPPORTS_ALPHANUMS = False 122 | SUPPORTS_UNIQUE_CONSTAINT = True 123 | CONNECT_URI_HELP = "mysql://:@/" 124 | CONNECT_URI_PARAMS = ["database?"] 125 | 126 | def __init__(self, *, thread_count, **kw): 127 | self._args = kw 128 | 129 | super().__init__(thread_count=thread_count) 130 | 131 | # In MySQL schema and database are synonymous 132 | try: 133 | self.default_schema = kw["database"] 134 | except KeyError: 135 | raise ValueError("MySQL URL must specify a database") 136 | 137 | def create_connection(self): 138 | mysql = import_mysql() 139 | try: 140 | return mysql.connect(charset="utf8", use_unicode=True, **self._args) 141 | except mysql.Error as e: 142 | if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: 143 | raise ConnectError("Bad user name or password") from e 144 | elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: 145 | raise ConnectError("Database does not exist") from e 146 | raise ConnectError(*e.args) from e 147 | -------------------------------------------------------------------------------- /tests/test_database.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from datetime import datetime 3 | from typing import Callable, List, Tuple 4 | 5 | import pytz 6 | 7 | from sqeleton import connect 8 | from sqeleton import databases as dbs 9 | from sqeleton.queries import table, current_timestamp, NormalizeAsString 10 | from .common import TEST_MYSQL_CONN_STRING 11 | from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix 12 | from sqeleton.abcs.database_types import TimestampTZ 13 | 14 | TEST_DATABASES = { 15 | dbs.MySQL, 16 | dbs.PostgreSQL, 17 | dbs.Oracle, 18 | dbs.Redshift, 19 | dbs.Snowflake, 20 | dbs.DuckDB, 21 | dbs.BigQuery, 22 | dbs.Presto, 23 | dbs.Trino, 24 | dbs.Vertica, 25 | } 26 | 27 | test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) 28 | 29 | 30 | class TestDatabase(unittest.TestCase): 31 | def setUp(self): 32 | 
self.mysql = connect(TEST_MYSQL_CONN_STRING) 33 | 34 | def test_connect_to_db(self): 35 | self.assertEqual(1, self.mysql.query("SELECT 1", int)) 36 | 37 | 38 | class TestMD5(unittest.TestCase): 39 | def test_md5_as_int(self): 40 | class MD5Dialect(dbs.mysql.Dialect, dbs.mysql.Mixin_MD5): 41 | pass 42 | 43 | self.mysql = connect(TEST_MYSQL_CONN_STRING) 44 | self.mysql.dialect = MD5Dialect() 45 | 46 | str = "hello world" 47 | query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str)) 48 | query = f"SELECT {query_fragment}" 49 | 50 | self.assertEqual(str_to_checksum(str), self.mysql.query(query, int)) 51 | 52 | 53 | class TestConnect(unittest.TestCase): 54 | def test_bad_uris(self): 55 | self.assertRaises(ValueError, connect, "p") 56 | self.assertRaises(ValueError, connect, "postgresql:///bla/foo") 57 | self.assertRaises(ValueError, connect, "snowflake://user:pass@foo/bar/TEST1") 58 | self.assertRaises(ValueError, connect, "snowflake://user:pass@foo/bar/TEST1?warehouse=ha&schema=dup") 59 | 60 | 61 | @test_each_database 62 | class TestSchema(unittest.TestCase): 63 | def test_table_list(self): 64 | name = "tbl_" + random_table_suffix() 65 | db = get_conn(self.db_cls) 66 | tbl = table(db.parse_table_name(name), schema={"id": int}) 67 | q = db.dialect.list_tables(db.default_schema, name) 68 | assert not db.query(q) 69 | 70 | db.query(tbl.create()) 71 | self.assertEqual(db.query(q, List[str]), [name]) 72 | 73 | db.query(tbl.drop()) 74 | assert not db.query(q) 75 | 76 | def test_type_mapping(self): 77 | name = "tbl_" + random_table_suffix() 78 | db = get_conn(self.db_cls) 79 | tbl = table(db.parse_table_name(name), schema={ 80 | "int": int, 81 | "float": float, 82 | "datetime": datetime, 83 | "str": str, 84 | "bool": bool, 85 | }) 86 | q = db.dialect.list_tables(db.default_schema, name) 87 | assert not db.query(q) 88 | 89 | db.query(tbl.create()) 90 | self.assertEqual(db.query(q, List[str]), [name]) 91 | 92 | db.query(tbl.drop()) 93 | assert not db.query(q) 94 | 95 | @test_each_database 96 | class TestQueries(unittest.TestCase): 97 | def test_current_timestamp(self): 98 | db = get_conn(self.db_cls) 99 | res = db.query(current_timestamp(), datetime) 100 | assert isinstance(res, datetime), (res, type(res)) 101 | 102 | def test_correct_timezone(self): 103 | name = "tbl_" + random_table_suffix() 104 | db = get_conn(self.db_cls) 105 | tbl = table(name, schema={ 106 | "id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9) 107 | }) 108 | 109 | db.query(tbl.create()) 110 | 111 | tz = pytz.timezone('Europe/Berlin') 112 | 113 | now = datetime.now(tz) 114 | if isinstance(db, dbs.Presto): 115 | ms = now.microsecond // 1000 * 1000 # Presto max precision is 3 116 | now = now.replace(microsecond = ms) 117 | 118 | db.query(table(name).insert_row(1, now, now)) 119 | db.query(db.dialect.set_timezone_to_utc()) 120 | 121 | t = db.table(name).query_schema() 122 | t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision) 123 | 124 | tbl = table(name, schema=t.schema) 125 | 126 | results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple]) 127 | 128 | created_at = results[0][0] 129 | updated_at = results[0][1] 130 | 131 | utc = now.astimezone(pytz.UTC) 132 | expected = utc.__format__("%Y-%m-%d %H:%M:%S.%f") 133 | 134 | 135 | self.assertEqual(created_at, expected) 136 | self.assertEqual(updated_at, expected) 137 | 138 | db.query(tbl.drop()) 139 | 140 | @test_each_database 141 | class 
TestThreePartIds(unittest.TestCase): 142 | def test_three_part_support(self): 143 | if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB]: 144 | self.skipTest("Limited support for 3 part ids") 145 | 146 | table_name = "tbl_" + random_table_suffix() 147 | db = get_conn(self.db_cls) 148 | db_res = db.query("SELECT CURRENT_DATABASE()") 149 | schema_res = db.query("SELECT CURRENT_SCHEMA()") 150 | db_name = db_res.rows[0][0] 151 | schema_name = schema_res.rows[0][0] 152 | 153 | table_one_part = table((table_name,), schema={"id": int}) 154 | table_two_part = table((schema_name, table_name), schema={"id": int}) 155 | table_three_part = table((db_name, schema_name, table_name), schema={"id": int}) 156 | 157 | for part in (table_one_part, table_two_part, table_three_part): 158 | db.query(part.create()) 159 | d = db.query_table_schema(part.path) 160 | assert len(d) == 1 161 | db.query(part.drop()) 162 | 163 | -------------------------------------------------------------------------------- /sqeleton/queries/api.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from ..utils import CaseAwareMapping, CaseSensitiveDict 4 | from .ast_classes import * 5 | from .base import args_as_tuple 6 | 7 | 8 | this = This() 9 | 10 | 11 | def join(*tables: ITable) -> Join: 12 | """Inner-join a sequence of table expressions" 13 | 14 | When joining, it's recommended to use explicit tables names, instead of `this`, in order to avoid potential name collisions. 15 | 16 | Example: 17 | :: 18 | 19 | person = table('person') 20 | city = table('city') 21 | 22 | name_and_city = ( 23 | join(person, city) 24 | .on(person['city_id'] == city['id']) 25 | .select(person['id'], city['name']) 26 | ) 27 | """ 28 | return Join(tables) 29 | 30 | 31 | def leftjoin(*tables: ITable): 32 | """Left-joins a sequence of table expressions. 33 | 34 | See Also: ``join()`` 35 | """ 36 | return Join(tables, "LEFT") 37 | 38 | 39 | def rightjoin(*tables: ITable): 40 | """Right-joins a sequence of table expressions. 41 | 42 | See Also: ``join()`` 43 | """ 44 | return Join(tables, "RIGHT") 45 | 46 | 47 | def outerjoin(*tables: ITable): 48 | """Outer-joins a sequence of table expressions. 49 | 50 | See Also: ``join()`` 51 | """ 52 | return Join(tables, "FULL OUTER") 53 | 54 | 55 | def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None): 56 | """Define a CTE""" 57 | return Cte(expr, name, params) 58 | 59 | 60 | def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath: 61 | """Defines a table with a path (dotted name), and optionally a schema. 62 | 63 | Parameters: 64 | path: A list of names that make up the path to the table. 65 | schema: a dictionary of {name: type} 66 | """ 67 | if len(path) == 1 and isinstance(path[0], tuple): 68 | (path,) = path 69 | if not all(isinstance(i, str) for i in path): 70 | raise TypeError(f"All elements of table path must be of type 'str'. 
Got: {path}") 71 | if schema and not isinstance(schema, CaseAwareMapping): 72 | assert isinstance(schema, dict) 73 | schema = CaseSensitiveDict(schema) 74 | return TablePath(path, schema) 75 | 76 | 77 | def or_(*exprs: Expr): 78 | """Apply OR between a sequence of boolean expressions""" 79 | exprs = args_as_tuple(exprs) 80 | if len(exprs) == 1: 81 | return exprs[0] 82 | return BinBoolOp("OR", exprs) 83 | 84 | 85 | def and_(*exprs: Expr): 86 | """Apply AND between a sequence of boolean expressions""" 87 | exprs = args_as_tuple(exprs) 88 | if len(exprs) == 1: 89 | return exprs[0] 90 | return BinBoolOp("AND", exprs) 91 | 92 | 93 | def sum_(expr: Expr): 94 | """Call SUM(expr)""" 95 | return Func("sum", [expr]) 96 | 97 | 98 | def avg(expr: Expr): 99 | """Call AVG(expr)""" 100 | return Func("avg", [expr]) 101 | 102 | 103 | def min_(expr: Expr): 104 | """Call MIN(expr)""" 105 | return Func("min", [expr]) 106 | 107 | 108 | def max_(expr: Expr): 109 | """Call MAX(expr)""" 110 | return Func("max", [expr]) 111 | 112 | 113 | def exists(expr: Expr): 114 | """Call EXISTS(expr)""" 115 | return Func("exists", [expr]) 116 | 117 | 118 | def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): 119 | """Conditional expression, shortcut to when-then-else. 120 | 121 | Example: 122 | :: 123 | 124 | # SELECT CASE WHEN b THEN c ELSE d END FROM foo 125 | table('foo').select(if_(b, c, d)) 126 | """ 127 | return when(cond).then(then).else_(else_) 128 | 129 | 130 | def when(*when_exprs: Expr): 131 | """Start a when-then expression 132 | 133 | Example: 134 | :: 135 | 136 | # SELECT CASE 137 | # WHEN (type = 'text') THEN text 138 | # WHEN (type = 'number') THEN number 139 | # ELSE 'unknown type' END 140 | # FROM foo 141 | rows = table('foo').select( 142 | when(this.type == 'text').then(this.text) 143 | .when(this.type == 'number').then(this.number) 144 | .else_('unknown type') 145 | ) 146 | """ 147 | return CaseWhen([]).when(*when_exprs) 148 | 149 | 150 | def coalesce(*exprs): 151 | "Returns a call to COALESCE" 152 | exprs = args_as_tuple(exprs) 153 | return Func("COALESCE", exprs) 154 | 155 | 156 | def insert_rows_in_batches(db, tbl: TablePath, rows, *, columns=None, batch_size=1024 * 8): 157 | assert batch_size > 0 158 | rows = list(rows) 159 | 160 | while rows: 161 | batch, rows = rows[:batch_size], rows[batch_size:] 162 | db.query(tbl.insert_rows(batch, columns=columns)) 163 | 164 | 165 | def current_timestamp(): 166 | """Returns CURRENT_TIMESTAMP() or NOW()""" 167 | return CurrentTimestamp() 168 | 169 | 170 | def code(code: str, **kw: Dict[str, Expr]) -> Code: 171 | """Inline raw SQL code. 172 | 173 | It allows users to use features and syntax that Sqeleton doesn't yet support. 174 | 175 | It's the user's responsibility to make sure the contents of the string given to `code()` are correct and safe for execution. 176 | 177 | Strings given to `code()` are actually templates, and can embed query expressions given as arguments: 178 | 179 | Parameters: 180 | code: template string of SQL code. Templated variables are signified with '{var}'. 181 | kw: optional parameters for SQL template. 
182 | 183 | Examples: 184 | :: 185 | 186 | # SELECT b, <x> FROM tmp WHERE <y> 187 | table('tmp').select(this.b, code("<x>")).where(code("<y>")) 188 | 189 | :: 190 | 191 | def tablesample(tbl, size): 192 | return code("SELECT * FROM {tbl} TABLESAMPLE BERNOULLI ({size})", tbl=tbl, size=size) 193 | 194 | nonzero = table('points').where(this.x > 0, this.y > 0) 195 | 196 | # SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10) 197 | sample_expr = tablesample(nonzero, 10) 198 | """ 199 | return Code(code, kw) 200 | 201 | 202 | commit = Commit() 203 | -------------------------------------------------------------------------------- /sqeleton/databases/vertica.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils import match_regexps 4 | from .base import ( 5 | CHECKSUM_HEXDIGITS, 6 | MD5_HEXDIGITS, 7 | TIMESTAMP_PRECISION_POS, 8 | BaseDialect, 9 | ConnectError, 10 | DbPath, 11 | ColType, 12 | ThreadedDatabase, 13 | import_helper, 14 | Mixin_RandomSample, 15 | ) 16 | from ..abcs.database_types import ( 17 | Decimal, 18 | Float, 19 | FractionalType, 20 | Integer, 21 | TemporalType, 22 | Text, 23 | Timestamp, 24 | TimestampTZ, 25 | Boolean, 26 | ColType_UUID, 27 | ) 28 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema 29 | from ..abcs import Compilable 30 | from ..queries import table, this, SKIP 31 | 32 | 33 | @import_helper("vertica") 34 | def import_vertica(): 35 | import vertica_python 36 | 37 | return vertica_python 38 | 39 | 40 | class Mixin_MD5(AbstractMixin_MD5): 41 | def md5_as_int(self, s: str) -> str: 42 | return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" 43 | 44 | 45 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 46 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 47 | if coltype.rounds: 48 | return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" 49 | 50 | timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" 51 | return ( 52 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 53 | ) 54 | 55 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 56 | return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") 57 | 58 | def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str: 59 | # Trim doesn't work on CHAR type 60 | return f"TRIM(CAST({value} AS VARCHAR))" 61 | 62 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 63 | return self.to_string(f"cast ({value} as int)") 64 | 65 | 66 | class Mixin_Schema(AbstractMixin_Schema): 67 | def table_information(self) -> Compilable: 68 | return table("v_catalog", "tables") 69 | 70 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 71 | return ( 72 | self.table_information() 73 | .where( 74 | this.table_schema == table_schema, 75 | this.table_name.like(like) if like is not None else SKIP, 76 | ) 77 | .select(this.table_name) 78 | ) 79 | 80 | 81 | class Dialect(BaseDialect, Mixin_Schema): 82 | name = "Vertica" 83 | ROUNDS_ON_PREC_LOSS = True 84 | 85 | TYPE_CLASSES = { 86 | # Timestamps 87 | "timestamp": Timestamp, 88 | "timestamptz": TimestampTZ, 89 | # Numbers 90 | "numeric": Decimal, 91 | "int": Integer, 92 | "float": Float, 93 | # Text 94 | "char": Text, 95 | "varchar": Text, 96 | # Boolean 97 | "boolean": Boolean, 98 | 
} 99 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 100 | 101 | def quote(self, s: str): 102 | return f'"{s}"' 103 | 104 | def concat(self, items: List[str]) -> str: 105 | return " || ".join(items) 106 | 107 | def to_string(self, s: str) -> str: 108 | return f"CAST({s} AS VARCHAR)" 109 | 110 | def is_distinct_from(self, a: str, b: str) -> str: 111 | return f"not ({a} <=> {b})" 112 | 113 | def parse_type( 114 | self, 115 | table_path: DbPath, 116 | col_name: str, 117 | type_repr: str, 118 | datetime_precision: int = None, 119 | numeric_precision: int = None, 120 | numeric_scale: int = None, 121 | ) -> ColType: 122 | timestamp_regexps = { 123 | r"timestamp\(?(\d?)\)?": Timestamp, 124 | r"timestamptz\(?(\d?)\)?": TimestampTZ, 125 | } 126 | for m, t_cls in match_regexps(timestamp_regexps, type_repr): 127 | precision = int(m.group(1)) if m.group(1) else 6 128 | return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) 129 | 130 | number_regexps = { 131 | r"numeric\((\d+),(\d+)\)": Decimal, 132 | } 133 | for m, n_cls in match_regexps(number_regexps, type_repr): 134 | _prec, scale = map(int, m.groups()) 135 | return n_cls(scale) 136 | 137 | string_regexps = { 138 | r"varchar\((\d+)\)": Text, 139 | r"char\((\d+)\)": Text, 140 | } 141 | for m, n_cls in match_regexps(string_regexps, type_repr): 142 | return n_cls() 143 | 144 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) 145 | 146 | def set_timezone_to_utc(self) -> str: 147 | return "SET TIME ZONE TO 'UTC'" 148 | 149 | def current_timestamp(self) -> str: 150 | return "current_timestamp(6)" 151 | 152 | 153 | class Vertica(ThreadedDatabase): 154 | dialect = Dialect() 155 | CONNECT_URI_HELP = "vertica://:@/" 156 | CONNECT_URI_PARAMS = ["database?"] 157 | 158 | default_schema = "public" 159 | 160 | def __init__(self, *, thread_count, **kw): 161 | self._args = kw 162 | self._args["AUTOCOMMIT"] = False 163 | 164 | super().__init__(thread_count=thread_count) 165 | 166 | def create_connection(self): 167 | vertica = import_vertica() 168 | try: 169 | c = vertica.connect(**self._args) 170 | return c 171 | except vertica.errors.ConnectionError as e: 172 | raise ConnectError(*e.args) from e 173 | 174 | def select_table_schema(self, path: DbPath) -> str: 175 | schema, name = self._normalize_table_path(path) 176 | 177 | return ( 178 | "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " 179 | "FROM V_CATALOG.COLUMNS " 180 | f"WHERE table_name = '{name}' AND table_schema = '{schema}'" 181 | ) 182 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import string 4 | import random 5 | from typing import Callable 6 | import unittest 7 | import logging 8 | import subprocess 9 | 10 | import sqeleton 11 | from parameterized import parameterized_class 12 | 13 | from sqeleton import databases as db 14 | from sqeleton import connect 15 | from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue 16 | from sqeleton.queries import table 17 | from sqeleton.databases import Database 18 | from sqeleton.query_utils import drop_table 19 | 20 | 21 | # We write 'or None' because Github sometimes creates empty env vars for secrets 22 | TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql" 23 | TEST_POSTGRESQL_CONN_STRING: str = 
"postgresql://postgres:Password1@localhost/postgres" 24 | TEST_SNOWFLAKE_CONN_STRING: str = os.environ.get("SNOWFLAKE_URI") or None 25 | TEST_PRESTO_CONN_STRING: str = os.environ.get("PRESTO_URI") or None 26 | TEST_BIGQUERY_CONN_STRING: str = os.environ.get("BIGQUERY_URI") or None 27 | TEST_REDSHIFT_CONN_STRING: str = os.environ.get("REDSHIFT_URI") or None 28 | TEST_ORACLE_CONN_STRING: str = None 29 | TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATABRICKS_URI") 30 | TEST_TRINO_CONN_STRING: str = os.environ.get("TRINO_URI") or None 31 | # clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" 32 | TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("CLICKHOUSE_URI") 33 | # vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica" 34 | TEST_VERTICA_CONN_STRING: str = os.environ.get("VERTICA_URI") 35 | TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" 36 | 37 | 38 | DEFAULT_N_SAMPLES = 50 39 | N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES)) 40 | BENCHMARK = os.environ.get("BENCHMARK", False) 41 | N_THREADS = int(os.environ.get("N_THREADS", 1)) 42 | TEST_ACROSS_ALL_DBS = os.environ.get("TEST_ACROSS_ALL_DBS", True) # Should we run the full db<->db test suite? 43 | 44 | 45 | def get_git_revision_short_hash() -> str: 46 | return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() 47 | 48 | 49 | GIT_REVISION = get_git_revision_short_hash() 50 | 51 | level = logging.ERROR 52 | if os.environ.get("LOG_LEVEL", False): 53 | level = getattr(logging, os.environ["LOG_LEVEL"].upper()) 54 | 55 | logging.basicConfig(level=level) 56 | logging.getLogger("database").setLevel(level) 57 | 58 | try: 59 | from .local_settings import * 60 | except ImportError: 61 | pass # No local settings 62 | 63 | 64 | CONN_STRINGS = { 65 | db.BigQuery: TEST_BIGQUERY_CONN_STRING, 66 | db.MySQL: TEST_MYSQL_CONN_STRING, 67 | db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING, 68 | db.Snowflake: TEST_SNOWFLAKE_CONN_STRING, 69 | db.Redshift: TEST_REDSHIFT_CONN_STRING, 70 | db.Oracle: TEST_ORACLE_CONN_STRING, 71 | db.Presto: TEST_PRESTO_CONN_STRING, 72 | db.Databricks: TEST_DATABRICKS_CONN_STRING, 73 | db.Trino: TEST_TRINO_CONN_STRING, 74 | db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING, 75 | db.Vertica: TEST_VERTICA_CONN_STRING, 76 | db.DuckDB: TEST_DUCKDB_CONN_STRING, 77 | } 78 | 79 | _database_instances = {} 80 | 81 | 82 | def get_conn(cls: type, shared: bool = True) -> Database: 83 | if shared: 84 | if cls not in _database_instances: 85 | _database_instances[cls] = get_conn(cls, shared=False) 86 | return _database_instances[cls] 87 | 88 | con = sqeleton.connect.load_mixins(AbstractMixin_NormalizeValue) 89 | return con(CONN_STRINGS[cls], N_THREADS) 90 | 91 | 92 | def _print_used_dbs(): 93 | used = {k.__name__ for k, v in CONN_STRINGS.items() if v is not None} 94 | unused = {k.__name__ for k, v in CONN_STRINGS.items() if v is None} 95 | 96 | print(f"Testing databases: {', '.join(used)}") 97 | if unused: 98 | logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}") 99 | if TEST_ACROSS_ALL_DBS: 100 | logging.info( 101 | f"Full tests enabled (every db<->db). May take very long when many dbs are involved. 
={TEST_ACROSS_ALL_DBS}" 102 | ) 103 | 104 | 105 | _print_used_dbs() 106 | CONN_STRINGS = {k: v for k, v in CONN_STRINGS.items() if v is not None} 107 | 108 | 109 | def random_table_suffix() -> str: 110 | char_set = string.ascii_lowercase + string.digits 111 | suffix = "_" 112 | suffix += "".join(random.choice(char_set) for _ in range(5)) 113 | return suffix 114 | 115 | 116 | def str_to_checksum(str: str): 117 | # hello world 118 | # => 5eb63bbbe01eeed093cb22bb8f5acdc3 119 | # => cb22bb8f5acdc3 120 | # => 273350391345368515 121 | m = hashlib.md5() 122 | m.update(str.encode("utf-8")) # encode to binary 123 | md5 = m.hexdigest() 124 | # 0-indexed, unlike DBs which are 1-indexed here, so +1 in dbs 125 | half_pos = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS 126 | return int(md5[half_pos:], 16) 127 | 128 | 129 | class DbTestCase(unittest.TestCase): 130 | "Sets up a table for testing" 131 | db_cls = None 132 | table1_schema = None 133 | shared_connection = True 134 | 135 | def setUp(self): 136 | assert self.db_cls, self.db_cls 137 | 138 | self.connection = get_conn(self.db_cls, self.shared_connection) 139 | 140 | table_suffix = random_table_suffix() 141 | self.table1_name = f"src{table_suffix}" 142 | 143 | self.table1_path = self.connection.parse_table_name(self.table1_name) 144 | 145 | drop_table(self.connection, self.table1_path) 146 | 147 | self.src_table = table(self.table1_path, schema=self.table1_schema) 148 | if self.table1_schema: 149 | self.connection.query(self.src_table.create()) 150 | 151 | return super().setUp() 152 | 153 | def tearDown(self): 154 | drop_table(self.connection, self.table1_path) 155 | 156 | 157 | def _parameterized_class_per_conn(test_databases): 158 | test_databases = set(test_databases) 159 | names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] 160 | return parameterized_class(("name", "db_cls"), names) 161 | 162 | 163 | def test_each_database_in_list(databases) -> Callable: 164 | def _test_per_database(cls): 165 | return _parameterized_class_per_conn(databases)(cls) 166 | 167 | return _test_per_database 168 | -------------------------------------------------------------------------------- /sqeleton/databases/postgresql.py: -------------------------------------------------------------------------------- 1 | from ..abcs.database_types import ( 2 | DbPath, 3 | Timestamp, 4 | TimestampTZ, 5 | Float, 6 | Decimal, 7 | Integer, 8 | TemporalType, 9 | Native_UUID, 10 | Text, 11 | FractionalType, 12 | Boolean, 13 | Date, 14 | PostgresqlJSON, 15 | PostgresqlJSONB 16 | ) 17 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 18 | from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema 19 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS, Mixin_RandomSample 20 | 21 | SESSION_TIME_ZONE = None # Changed by the tests 22 | 23 | 24 | @import_helper("postgresql") 25 | def import_postgresql(): 26 | import psycopg2 27 | import psycopg2.extras 28 | 29 | psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) 30 | return psycopg2 31 | 32 | 33 | class Mixin_MD5(AbstractMixin_MD5): 34 | def md5_as_int(self, s: str) -> str: 35 | return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" 36 | 37 | 38 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 39 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 40 | if coltype.rounds: 41 | return 
f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" 42 | 43 | timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" 44 | return ( 45 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 46 | ) 47 | 48 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 49 | return self.to_string(f"{value}::decimal(38, {coltype.precision})") 50 | 51 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 52 | return self.to_string(f"{value}::int") 53 | 54 | def normalize_json(self, value: str, _coltype: PostgresqlJSON) -> str: 55 | return f"{value}::text" 56 | 57 | 58 | class PostgresqlDialect(BaseDialect, Mixin_Schema): 59 | name = "PostgreSQL" 60 | ROUNDS_ON_PREC_LOSS = True 61 | SUPPORTS_PRIMARY_KEY = True 62 | SUPPORTS_INDEXES = True 63 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 64 | 65 | TYPE_CLASSES = { 66 | # Timestamps 67 | "timestamp with time zone": TimestampTZ, 68 | "timestamp without time zone": Timestamp, 69 | "timestamp": Timestamp, 70 | "date": Date, 71 | # Numbers 72 | "double precision": Float, 73 | "real": Float, 74 | "decimal": Decimal, 75 | "smallint": Integer, 76 | "integer": Integer, 77 | "numeric": Decimal, 78 | "bigint": Integer, 79 | # Text 80 | "character": Text, 81 | "character varying": Text, 82 | "varchar": Text, 83 | "text": Text, 84 | # JSON 85 | "json": PostgresqlJSON, 86 | "jsonb": PostgresqlJSONB, 87 | # UUID 88 | "uuid": Native_UUID, 89 | # Boolean 90 | "boolean": Boolean, 91 | } 92 | 93 | def quote(self, s: str): 94 | return f'"{s}"' 95 | 96 | def to_string(self, s: str): 97 | return f"{s}::varchar" 98 | 99 | def _convert_db_precision_to_digits(self, p: int) -> int: 100 | # Subtracting 2 due to wierd precision issues in PostgreSQL 101 | return super()._convert_db_precision_to_digits(p) - 2 102 | 103 | def set_timezone_to_utc(self) -> str: 104 | return "SET TIME ZONE 'UTC'" 105 | 106 | def current_timestamp(self) -> str: 107 | return "current_timestamp" 108 | 109 | def type_repr(self, t) -> str: 110 | if isinstance(t, TimestampTZ): 111 | return f"timestamp ({t.precision}) with time zone" 112 | return super().type_repr(t) 113 | 114 | 115 | class PostgreSQL(ThreadedDatabase): 116 | dialect = PostgresqlDialect() 117 | SUPPORTS_UNIQUE_CONSTAINT = True 118 | CONNECT_URI_HELP = "postgresql://:@/" 119 | CONNECT_URI_PARAMS = ["database?"] 120 | 121 | default_schema = "public" 122 | 123 | def __init__(self, *, thread_count, **kw): 124 | self._args = kw 125 | 126 | super().__init__(thread_count=thread_count) 127 | 128 | def create_connection(self): 129 | if not self._args: 130 | self._args["host"] = None # psycopg2 requires 1+ arguments 131 | 132 | pg = import_postgresql() 133 | try: 134 | c = pg.connect(**self._args) 135 | if SESSION_TIME_ZONE: 136 | c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") 137 | return c 138 | except pg.OperationalError as e: 139 | raise ConnectError(*e.args) from e 140 | 141 | def select_table_schema(self, path: DbPath) -> str: 142 | database, schema, table = self._normalize_table_path(path) 143 | 144 | info_schema_path = ["information_schema", "columns"] 145 | if database: 146 | info_schema_path.insert(0, database) 147 | 148 | return ( 149 | f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " 150 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 151 | ) 152 | 153 | def 
select_table_unique_columns(self, path: DbPath) -> str: 154 | database, schema, table = self._normalize_table_path(path) 155 | 156 | info_schema_path = ["information_schema", "key_column_usage"] 157 | if database: 158 | info_schema_path.insert(0, database) 159 | 160 | return ( 161 | "SELECT column_name " 162 | f"FROM {'.'.join(info_schema_path)} " 163 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 164 | ) 165 | 166 | def _normalize_table_path(self, path: DbPath) -> DbPath: 167 | if len(path) == 1: 168 | return None, self.default_schema, path[0] 169 | elif len(path) == 2: 170 | return None, path[0], path[1] 171 | elif len(path) == 3: 172 | return path 173 | 174 | raise ValueError( 175 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" 176 | ) 177 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Documentation build configuration file, created by 5 | # sphinx-quickstart on Sun Aug 16 13:09:41 2020. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | 23 | sys.path.insert(0, os.path.abspath("..")) 24 | sys.path.append(os.path.abspath("./_ext")) 25 | autodoc_member_order = "bysource" 26 | 27 | import builtins 28 | builtins.__sphinx_build__ = True 29 | 30 | # -- General configuration ------------------------------------------------ 31 | 32 | # If your documentation needs a minimal Sphinx version, state it here. 33 | # 34 | # needs_sphinx = '1.0' 35 | 36 | # Add any Sphinx extension module names here, as strings. They can be 37 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 38 | # ones. 39 | extensions = [ 40 | "sphinx.ext.autodoc", 41 | "sphinx.ext.napoleon", 42 | "sphinx.ext.coverage", 43 | "recommonmark", 44 | "sphinx_markdown_tables", 45 | "sphinx_copybutton", 46 | "enum_tools.autoenum", 47 | # 'sphinx_gallery.gen_gallery' 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ["_templates"] 52 | 53 | # The suffix(es) of source filenames. 54 | # You can specify multiple suffix as a list of string: 55 | # 56 | # source_suffix = ['.rst', '.md'] 57 | source_suffix = {".rst": "restructuredtext", ".md": "markdown"} 58 | 59 | 60 | # The master toctree document. 61 | master_doc = "index" 62 | 63 | # General information about the project. 64 | project = "Sqeleton" 65 | copyright = "Datafold" 66 | author = "Erez Shinan" 67 | 68 | # The version info for the project you're documenting, acts as replacement for 69 | # |version| and |release|, also used in various other places throughout the 70 | # built documents. 71 | # 72 | # The short X.Y version. 73 | version = "" 74 | # The full version, including alpha/beta/rc tags. 
75 | release = "" 76 | 77 | # The language for content autogenerated by Sphinx. Refer to documentation 78 | # for a list of supported languages. 79 | # 80 | # This is also used if you do content translation via gettext catalogs. 81 | # Usually you set "language" from the command line for these cases. 82 | language = None 83 | 84 | # List of patterns, relative to source directory, that match files and 85 | # directories to ignore when looking for source files. 86 | # This patterns also effect to html_static_path and html_extra_path 87 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 88 | 89 | # The name of the Pygments (syntax highlighting) style to use. 90 | pygments_style = "sphinx" 91 | 92 | # If true, `todo` and `todoList` produce output, else they produce nothing. 93 | todo_include_todos = False 94 | 95 | autodoc_default_options = { 96 | # 'special-members': '__init__', 97 | 'exclude-members': 'json,aslist,astuple,replace', 98 | } 99 | 100 | # -- Options for HTML output ---------------------------------------------- 101 | 102 | # The theme to use for HTML and HTML Help pages. See the documentation for 103 | # a list of builtin themes. 104 | # 105 | html_theme = "sphinx_rtd_theme" 106 | 107 | # Theme options are theme-specific and customize the look and feel of a theme 108 | # further. For a list of options available for each theme, see the 109 | # documentation. 110 | # 111 | # html_theme_options = {} 112 | 113 | # Add any paths that contain custom static files (such as style sheets) here, 114 | # relative to this directory. They are copied after the builtin static files, 115 | # so a file named "default.css" will overwrite the builtin "default.css". 116 | html_static_path = ["_static"] 117 | 118 | html_css_files = [ 119 | "custom.css", 120 | ] 121 | 122 | # Custom sidebar templates, must be a dictionary that maps document names 123 | # to template names. 124 | # 125 | # This is required for the alabaster theme 126 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 127 | html_sidebars = { 128 | "**": [ 129 | "relations.html", # needs 'show_related': True theme option to display 130 | "searchbox.html", 131 | ] 132 | } 133 | 134 | 135 | # -- Options for HTMLHelp output ------------------------------------------ 136 | 137 | # Output file base name for HTML help builder. 138 | htmlhelp_basename = "sqeletondoc" 139 | 140 | 141 | # -- Options for LaTeX output --------------------------------------------- 142 | 143 | latex_elements = { 144 | # The paper size ('letterpaper' or 'a4paper'). 145 | # 146 | # 'papersize': 'letterpaper', 147 | # The font size ('10pt', '11pt' or '12pt'). 148 | # 149 | # 'pointsize': '10pt', 150 | # Additional stuff for the LaTeX preamble. 151 | # 152 | # 'preamble': '', 153 | # Latex figure (float) alignment 154 | # 155 | # 'figure_align': 'htbp', 156 | } 157 | 158 | # Grouping the document tree into LaTeX files. List of tuples 159 | # (source start file, target name, title, 160 | # author, documentclass [howto, manual, or own class]). 161 | latex_documents = [ 162 | (master_doc, "Sqeleton.tex", "Sqeleton Documentation", "Erez Shinan", "manual"), 163 | ] 164 | 165 | 166 | # -- Options for manual page output --------------------------------------- 167 | 168 | # One entry per manual page. List of tuples 169 | # (source start file, name, description, authors, manual section). 
170 | man_pages = [(master_doc, "Sqeleton", "Sqeleton Documentation", [author], 1)] 171 | 172 | 173 | # -- Options for Texinfo output ------------------------------------------- 174 | 175 | # Grouping the document tree into Texinfo files. List of tuples 176 | # (source start file, target name, title, author, 177 | # dir menu entry, description, category) 178 | texinfo_documents = [ 179 | ( 180 | master_doc, 181 | "Sqeleton", 182 | "Sqeleton Documentation", 183 | author, 184 | "Sqeleton", 185 | "One line description of project.", 186 | "Miscellaneous", 187 | ), 188 | ] 189 | 190 | # -- Sphinx gallery config ------------------------------------------- 191 | 192 | # sphinx_gallery_conf = { 193 | # 'examples_dirs': ['../examples'], 194 | # 'gallery_dirs': ['examples'], 195 | # } 196 | -------------------------------------------------------------------------------- /sqeleton/abcs/mixins.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSONType 3 | from .compiler import Compilable 4 | 5 | 6 | class AbstractMixin(ABC): 7 | "A mixin for a database dialect" 8 | 9 | 10 | class AbstractMixin_NormalizeValue(AbstractMixin): 11 | @abstractmethod 12 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 13 | """Creates an SQL expression, that converts 'value' to a normalized timestamp. 14 | 15 | The returned expression must accept any SQL datetime/timestamp, and return a string. 16 | 17 | Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` 18 | 19 | Precision of dates should be rounded up/down according to coltype.rounds 20 | """ 21 | 22 | @abstractmethod 23 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 24 | """Creates an SQL expression, that converts 'value' to a normalized number. 25 | 26 | The returned expression must accept any SQL int/numeric/float, and return a string. 27 | 28 | Floats/Decimals are expected in the format 29 | "I.P" 30 | 31 | Where I is the integer part of the number (as many digits as necessary), 32 | and must be at least one digit (0). 33 | P is the fractional digits, the amount of which is specified with 34 | coltype.precision. Trailing zeroes may be necessary. 35 | If P is 0, the dot is omitted. 36 | 37 | Note: We use 'precision' differently than most databases. For decimals, 38 | it's the same as ``numeric_scale``, and for floats, who use binary precision, 39 | it can be calculated as ``log10(2**numeric_precision)``. 40 | """ 41 | 42 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 43 | """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" 44 | return self.to_string(value) 45 | 46 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 47 | """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" 48 | if isinstance(coltype, String_UUID): 49 | return f"TRIM({value})" 50 | return self.to_string(value) 51 | 52 | def normalize_json(self, value: str, _coltype: JSONType) -> str: 53 | """Creates an SQL expression, that converts 'value' to its minified json string representation.""" 54 | raise NotImplementedError() 55 | 56 | def normalize_value_by_type(self, value: str, coltype: ColType) -> str: 57 | """Creates an SQL expression, that converts 'value' to a normalized representation. 58 | 59 | The returned expression must accept any SQL value, and return a string. 
60 | 61 | The default implementation dispatches to a method according to `coltype`: 62 | 63 | :: 64 | 65 | TemporalType -> normalize_timestamp() 66 | FractionalType -> normalize_number() 67 | *else* -> to_string() 68 | 69 | (`Integer` falls in the *else* category) 70 | 71 | """ 72 | if isinstance(coltype, TemporalType): 73 | return self.normalize_timestamp(value, coltype) 74 | elif isinstance(coltype, FractionalType): 75 | return self.normalize_number(value, coltype) 76 | elif isinstance(coltype, ColType_UUID): 77 | return self.normalize_uuid(value, coltype) 78 | elif isinstance(coltype, Boolean): 79 | return self.normalize_boolean(value, coltype) 80 | elif isinstance(coltype, JSONType): 81 | return self.normalize_json(value, coltype) 82 | return self.to_string(value) 83 | 84 | 85 | class AbstractMixin_MD5(AbstractMixin): 86 | """Methods for calculating an MD5 hash as an integer.""" 87 | 88 | @abstractmethod 89 | def md5_as_int(self, s: str) -> str: 90 | "Provide SQL for computing md5 and returning an int" 91 | 92 | 93 | class AbstractMixin_Schema(AbstractMixin): 94 | """Methods for querying the database schema 95 | 96 | TODO: Move AbstractDatabase.query_table_schema() and friends over here 97 | """ 98 | 99 | def table_information(self) -> Compilable: 100 | "Query to return a table of schema information about existing tables" 101 | raise NotImplementedError() 102 | 103 | @abstractmethod 104 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 105 | """Query to select the list of tables in the schema. (query return type: table[str]) 106 | 107 | If 'like' is specified, the value is applied to the table name, using the 'like' operator. 108 | """ 109 | 110 | 111 | class AbstractMixin_Regex(AbstractMixin): 112 | @abstractmethod 113 | def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: 114 | """Tests whether the regex pattern matches the string. Returns a bool expression.""" 115 | 116 | 117 | class AbstractMixin_RandomSample(AbstractMixin): 118 | @abstractmethod 119 | def random_sample_n(self, tbl: str, size: int) -> str: 120 | """Take a random sample of the given size, i.e. return 'size' amount of rows""" 121 | 122 | @abstractmethod 123 | def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str: 124 | """Take a random sample of the approximate size determined by the ratio (0..1), where 0 means no rows, and 1 means all rows 125 | 126 | i.e. the actual amount of rows returned may vary by standard deviation. 127 | """ 128 | 129 | # def random_sample_ratio(self, table: AbstractTable, ratio: float): 130 | # """Take a random sample of the size determined by the ratio (0..1), where 0 means no rows, and 1 means all rows 131 | # """ 132 | 133 | 134 | class AbstractMixin_TimeTravel(AbstractMixin): 135 | @abstractmethod 136 | def time_travel( 137 | self, 138 | table: Compilable, 139 | before: bool = False, 140 | timestamp: Compilable = None, 141 | offset: Compilable = None, 142 | statement: Compilable = None, 143 | ) -> Compilable: 144 | """Selects historical data from a table 145 | 146 | Parameters: 147 | table - The name of the table whose history we're querying 148 | timestamp - A constant timestamp 149 | offset - the time 'offset' seconds before now 150 | statement - identifier for statement, e.g. query ID 151 | 152 | Must specify exactly one of `timestamp`, `offset` or `statement`. 
153 | """ 154 | 155 | 156 | class AbstractMixin_OptimizerHints(AbstractMixin): 157 | @abstractmethod 158 | def optimizer_hints(self, optimizer_hints: str) -> str: 159 | """Creates a compatible optimizer_hints string 160 | 161 | Parameters: 162 | optimizer_hints - string of optimizer hints 163 | """ 164 | -------------------------------------------------------------------------------- /sqeleton/databases/duckdb.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from ..utils import match_regexps 4 | from ..abcs.database_types import ( 5 | Timestamp, 6 | TimestampTZ, 7 | DbPath, 8 | ColType, 9 | Float, 10 | Decimal, 11 | Integer, 12 | TemporalType, 13 | Native_UUID, 14 | Text, 15 | FractionalType, 16 | Boolean, 17 | AbstractTable, 18 | ) 19 | from ..abcs.mixins import ( 20 | AbstractMixin_MD5, 21 | AbstractMixin_NormalizeValue, 22 | AbstractMixin_RandomSample, 23 | AbstractMixin_Regex, 24 | ) 25 | from .base import ( 26 | Database, 27 | BaseDialect, 28 | import_helper, 29 | ConnectError, 30 | ThreadLocalInterpreter, 31 | TIMESTAMP_PRECISION_POS, 32 | ) 33 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema 34 | from ..queries.ast_classes import Func, Compilable 35 | from ..queries.api import code 36 | 37 | 38 | @import_helper("duckdb") 39 | def import_duckdb(): 40 | import duckdb 41 | 42 | return duckdb 43 | 44 | 45 | class Mixin_MD5(AbstractMixin_MD5): 46 | def md5_as_int(self, s: str) -> str: 47 | return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" 48 | 49 | 50 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 51 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 52 | # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. 
53 | if coltype.rounds and coltype.precision > 0: 54 | return f"CONCAT(SUBSTRING(STRFTIME({value}::TIMESTAMP, '%Y-%m-%d %H:%M:%S.'),1,23), LPAD(((ROUND(strftime({value}::timestamp, '%f')::DECIMAL(15,7)/100000,{coltype.precision-1})*100000)::INT)::VARCHAR,6,'0'))" 55 | 56 | return f"rpad(substring(strftime({value}::timestamp, '%Y-%m-%d %H:%M:%S.%f'),1,{TIMESTAMP_PRECISION_POS+coltype.precision}),26,'0')" 57 | 58 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 59 | return self.to_string(f"{value}::DECIMAL(38, {coltype.precision})") 60 | 61 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 62 | return self.to_string(f"{value}::INTEGER") 63 | 64 | 65 | class Mixin_RandomSample(AbstractMixin_RandomSample): 66 | def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: 67 | return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) 68 | 69 | def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: 70 | return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) 71 | 72 | 73 | class Mixin_Regex(AbstractMixin_Regex): 74 | def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: 75 | return Func("regexp_matches", [string, pattern]) 76 | 77 | 78 | class Dialect(BaseDialect, Mixin_Schema): 79 | name = "DuckDB" 80 | ROUNDS_ON_PREC_LOSS = False 81 | SUPPORTS_PRIMARY_KEY = True 82 | SUPPORTS_INDEXES = True 83 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 84 | 85 | TYPE_CLASSES = { 86 | # Timestamps 87 | "TIMESTAMP WITH TIME ZONE": TimestampTZ, 88 | "TIMESTAMP": Timestamp, 89 | # Numbers 90 | "DOUBLE": Float, 91 | "FLOAT": Float, 92 | "DECIMAL": Decimal, 93 | "INTEGER": Integer, 94 | "BIGINT": Integer, 95 | # Text 96 | "VARCHAR": Text, 97 | "TEXT": Text, 98 | # UUID 99 | "UUID": Native_UUID, 100 | # Bool 101 | "BOOLEAN": Boolean, 102 | } 103 | 104 | def quote(self, s: str): 105 | return f'"{s}"' 106 | 107 | def to_string(self, s: str): 108 | return f"{s}::VARCHAR" 109 | 110 | def _convert_db_precision_to_digits(self, p: int) -> int: 111 | # Subtracting 2 due to wierd precision issues in PostgreSQL 112 | return super()._convert_db_precision_to_digits(p) - 2 113 | 114 | def parse_type( 115 | self, 116 | table_path: DbPath, 117 | col_name: str, 118 | type_repr: str, 119 | datetime_precision: int = None, 120 | numeric_precision: int = None, 121 | numeric_scale: int = None, 122 | ) -> ColType: 123 | regexps = { 124 | r"DECIMAL\((\d+),(\d+)\)": Decimal, 125 | } 126 | 127 | for m, t_cls in match_regexps(regexps, type_repr): 128 | precision = int(m.group(2)) 129 | return t_cls(precision=precision) 130 | 131 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) 132 | 133 | def set_timezone_to_utc(self) -> str: 134 | return "SET GLOBAL TimeZone='UTC'" 135 | 136 | def current_timestamp(self) -> str: 137 | return "current_timestamp" 138 | 139 | 140 | class DuckDB(Database): 141 | dialect = Dialect() 142 | SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it 143 | default_schema = "main" 144 | CONNECT_URI_HELP = "duckdb://@" 145 | CONNECT_URI_PARAMS = ["database", "dbpath"] 146 | 147 | def __init__(self, **kw): 148 | self._args = kw 149 | self._conn = self.create_connection() 150 | 151 | @property 152 | def is_autocommit(self) -> bool: 153 | return True 154 | 155 | def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): 156 | "Uses 
the standard SQL cursor interface" 157 | return self._query_conn(self._conn, sql_code) 158 | 159 | def close(self): 160 | super().close() 161 | self._conn.close() 162 | 163 | def create_connection(self): 164 | ddb = import_duckdb() 165 | try: 166 | return ddb.connect(self._args["filepath"]) 167 | except ddb.OperationalError as e: 168 | raise ConnectError(*e.args) from e 169 | 170 | def select_table_schema(self, path: DbPath) -> str: 171 | database, schema, table = self._normalize_table_path(path) 172 | 173 | info_schema_path = ["information_schema", "columns"] 174 | if database: 175 | info_schema_path.insert(0, database) 176 | 177 | return ( 178 | f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " 179 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 180 | ) 181 | 182 | def _normalize_table_path(self, path: DbPath) -> DbPath: 183 | if len(path) == 1: 184 | return None, self.default_schema, path[0] 185 | elif len(path) == 2: 186 | return None, path[0], path[1] 187 | elif len(path) == 3: 188 | return path 189 | 190 | raise ValueError( 191 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" 192 | ) 193 | -------------------------------------------------------------------------------- /sqeleton/databases/presto.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import re 3 | 4 | from ..utils import match_regexps 5 | 6 | from ..abcs.database_types import ( 7 | Timestamp, 8 | TimestampTZ, 9 | Integer, 10 | Float, 11 | Text, 12 | FractionalType, 13 | DbPath, 14 | DbTime, 15 | Decimal, 16 | ColType, 17 | ColType_UUID, 18 | TemporalType, 19 | Boolean, 20 | ) 21 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 22 | from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter, Mixin_Schema, Mixin_RandomSample 23 | from .base import ( 24 | MD5_HEXDIGITS, 25 | CHECKSUM_HEXDIGITS, 26 | TIMESTAMP_PRECISION_POS, 27 | ) 28 | 29 | 30 | def query_cursor(c, sql_code): 31 | c.execute(sql_code) 32 | if sql_code.lower().startswith("select"): 33 | return c.fetchall() 34 | # Required for the query to actually run 🤯 35 | if re.match(r"(insert|create|truncate|drop|explain)", sql_code, re.IGNORECASE): 36 | return c.fetchone() 37 | 38 | 39 | @import_helper("presto") 40 | def import_presto(): 41 | import prestodb 42 | 43 | return prestodb 44 | 45 | 46 | class Mixin_MD5(AbstractMixin_MD5): 47 | def md5_as_int(self, s: str) -> str: 48 | return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" 49 | 50 | 51 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 52 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 53 | # Trim doesn't work on CHAR type 54 | return f"TRIM(CAST({value} AS VARCHAR))" 55 | 56 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 57 | # TODO rounds 58 | if coltype.rounds: 59 | s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" 60 | else: 61 | s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" 62 | 63 | return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" 64 | 65 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 66 | return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") 
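        # Added note (not in the original file): per the AbstractMixin_NormalizeValue contract,
        # the cast above should yield a plain 'I.P' string with exactly coltype.precision
        # fractional digits -- e.g. with precision == 2, a value of 1.5 is expected to come
        # back as '1.50' -- so that normalized values compare equal across databases.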
67 | 68 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 69 | return self.to_string(f"cast ({value} as int)") 70 | 71 | 72 | class Dialect(BaseDialect, Mixin_Schema): 73 | name = "Presto" 74 | ROUNDS_ON_PREC_LOSS = True 75 | TYPE_CLASSES = { 76 | # Timestamps 77 | "timestamp with time zone": TimestampTZ, 78 | "timestamp without time zone": Timestamp, 79 | "timestamp": Timestamp, 80 | # Numbers 81 | "integer": Integer, 82 | "bigint": Integer, 83 | "real": Float, 84 | "double": Float, 85 | # Text 86 | "varchar": Text, 87 | # Boolean 88 | "boolean": Boolean, 89 | } 90 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 91 | 92 | def explain_as_text(self, query: str) -> str: 93 | return f"EXPLAIN (FORMAT TEXT) {query}" 94 | 95 | def type_repr(self, t) -> str: 96 | if isinstance(t, TimestampTZ): 97 | return f"timestamp with time zone" 98 | 99 | try: 100 | return {float: "REAL"}[t] 101 | except KeyError: 102 | return super().type_repr(t) 103 | 104 | def timestamp_value(self, t: DbTime) -> str: 105 | return f"timestamp '{t.isoformat(' ')}'" 106 | 107 | def quote(self, s: str): 108 | return f'"{s}"' 109 | 110 | def to_string(self, s: str): 111 | return f"cast({s} as varchar)" 112 | 113 | def parse_type( 114 | self, 115 | table_path: DbPath, 116 | col_name: str, 117 | type_repr: str, 118 | datetime_precision: int = None, 119 | numeric_precision: int = None, 120 | _numeric_scale: int = None, 121 | ) -> ColType: 122 | timestamp_regexps = { 123 | r"timestamp\((\d)\)": Timestamp, 124 | r"timestamp\((\d)\) with time zone": TimestampTZ, 125 | } 126 | for m, t_cls in match_regexps(timestamp_regexps, type_repr): 127 | precision = int(m.group(1)) 128 | return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) 129 | 130 | number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} 131 | for m, n_cls in match_regexps(number_regexps, type_repr): 132 | _prec, scale = map(int, m.groups()) 133 | return n_cls(scale) 134 | 135 | string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} 136 | for m, n_cls in match_regexps(string_regexps, type_repr): 137 | return n_cls() 138 | 139 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) 140 | 141 | def set_timezone_to_utc(self) -> str: 142 | return "SET TIME ZONE '+00:00'" 143 | 144 | def current_timestamp(self) -> str: 145 | return "current_timestamp" 146 | 147 | 148 | class Presto(Database): 149 | dialect = Dialect() 150 | CONNECT_URI_HELP = "presto://@//" 151 | CONNECT_URI_PARAMS = ["catalog", "schema"] 152 | 153 | default_schema = "public" 154 | 155 | def __init__(self, **kw): 156 | prestodb = import_presto() 157 | 158 | if kw.get("schema"): 159 | self.default_schema = kw.get("schema") 160 | 161 | if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto 162 | kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) 163 | 164 | if "cert" in kw: # if a certificate was specified in URI, verify session with cert 165 | cert = kw.pop("cert") 166 | self._conn = prestodb.dbapi.connect(**kw) 167 | self._conn._http_session.verify = cert 168 | else: 169 | self._conn = prestodb.dbapi.connect(**kw) 170 | 171 | def _query(self, sql_code: str) -> list: 172 | "Uses the standard SQL cursor interface" 173 | c = self._conn.cursor() 174 | 175 | if isinstance(sql_code, ThreadLocalInterpreter): 176 | return sql_code.apply_queries(partial(query_cursor, c)) 177 | 178 | return query_cursor(c, sql_code) 179 | 180 | def close(self): 181 
| super().close() 182 | self._conn.close() 183 | 184 | def select_table_schema(self, path: DbPath) -> str: 185 | schema, table = self._normalize_table_path(path) 186 | 187 | return ( 188 | "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " 189 | "FROM INFORMATION_SCHEMA.COLUMNS " 190 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 191 | ) 192 | 193 | @property 194 | def is_autocommit(self) -> bool: 195 | return False 196 | -------------------------------------------------------------------------------- /sqeleton/databases/redshift.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from ..abcs.database_types import ( 3 | Float, 4 | TemporalType, 5 | FractionalType, 6 | DbPath, 7 | TimestampTZ, 8 | RedShiftSuper 9 | ) 10 | from ..abcs.mixins import AbstractMixin_MD5 11 | from .postgresql import ( 12 | PostgreSQL, 13 | MD5_HEXDIGITS, 14 | CHECKSUM_HEXDIGITS, 15 | TIMESTAMP_PRECISION_POS, 16 | PostgresqlDialect, 17 | Mixin_NormalizeValue, 18 | ) 19 | 20 | 21 | class Mixin_MD5(AbstractMixin_MD5): 22 | def md5_as_int(self, s: str) -> str: 23 | return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" 24 | 25 | 26 | class Mixin_NormalizeValue(Mixin_NormalizeValue): 27 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 28 | if coltype.rounds: 29 | timestamp = f"{value}::timestamp(6)" 30 | # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. 31 | secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" 32 | # Get the milliseconds from timestamp. 33 | ms = f"extract(ms from {timestamp})" 34 | # Get the microseconds from timestamp, without the milliseconds! 35 | us = f"extract(us from {timestamp})" 36 | # epoch = Total time since epoch in microseconds. 
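            # Added commentary (not in the original file): the fragments above and below are only
            # valid SQL once concatenated -- 'secs' deliberately leaves the round( unclosed, and the
            # matching ')' appears inside 'timestamp6'. The assembled expression is roughly:
            #   to_char(timestamp 'epoch' + round(<microseconds since epoch>, -6 + precision)
            #           * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')
            # i.e. round to the requested precision, convert back to a timestamp, then format it.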
37 | epoch = f"{secs}*1000000 + {ms}*1000 + {us}" 38 | timestamp6 = ( 39 | f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" 40 | ) 41 | else: 42 | timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" 43 | return ( 44 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 45 | ) 46 | 47 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 48 | return self.to_string(f"{value}::decimal(38,{coltype.precision})") 49 | 50 | def normalize_json(self, value: str, _coltype: RedShiftSuper) -> str: 51 | return f'nvl2({value}, json_serialize({value}), NULL)' 52 | 53 | 54 | class Dialect(PostgresqlDialect): 55 | name = "Redshift" 56 | TYPE_CLASSES = { 57 | **PostgresqlDialect.TYPE_CLASSES, 58 | "double": Float, 59 | "real": Float, 60 | # JSON 61 | "super": RedShiftSuper 62 | } 63 | SUPPORTS_INDEXES = False 64 | 65 | def concat(self, items: List[str]) -> str: 66 | joined_exprs = " || ".join(items) 67 | return f"({joined_exprs})" 68 | 69 | def is_distinct_from(self, a: str, b: str) -> str: 70 | return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})" 71 | 72 | def type_repr(self, t) -> str: 73 | if isinstance(t, TimestampTZ): 74 | return f"timestamptz" 75 | return super().type_repr(t) 76 | 77 | 78 | class Redshift(PostgreSQL): 79 | dialect = Dialect() 80 | CONNECT_URI_HELP = "redshift://:@/" 81 | CONNECT_URI_PARAMS = ["database?"] 82 | 83 | def select_table_schema(self, path: DbPath) -> str: 84 | database, schema, table = self._normalize_table_path(path) 85 | 86 | info_schema_path = ["information_schema", "columns"] 87 | if database: 88 | info_schema_path.insert(0, database) 89 | 90 | return ( 91 | f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " 92 | f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" 93 | ) 94 | 95 | def select_external_table_schema(self, path: DbPath) -> str: 96 | database, schema, table = self._normalize_table_path(path) 97 | 98 | db_clause = "" 99 | if database: 100 | db_clause = f" AND redshift_database_name = '{database.lower()}'" 101 | 102 | return ( 103 | f"""SELECT 104 | columnname AS column_name 105 | , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type 106 | , NULL AS datetime_precision 107 | , NULL AS numeric_precision 108 | , NULL AS numeric_scale 109 | FROM svv_external_columns 110 | WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}' 111 | """ 112 | + db_clause 113 | ) 114 | 115 | def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: 116 | rows = self.query(self.select_external_table_schema(path), list) 117 | if not rows: 118 | raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") 119 | 120 | d = {r[0]: r for r in rows} 121 | assert len(d) == len(rows) 122 | return d 123 | 124 | def select_view_columns(self, path: DbPath) -> str: 125 | _, schema, table = self._normalize_table_path(path) 126 | 127 | return ( 128 | """select * from pg_get_cols('{}.{}') 129 | cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int) 130 | """.format(schema, table) 131 | ) 132 | 133 | def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: 134 | rows = self.query(self.select_view_columns(path), list) 135 | 136 | if not rows: 137 | raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or 
has no columns") 138 | 139 | output = {} 140 | for r in rows: 141 | col_name = r[2] 142 | type_info = r[3].split('(') 143 | base_type = type_info[0] 144 | precision = None 145 | scale = None 146 | 147 | if len(type_info) > 1: 148 | if base_type == 'numeric': 149 | precision, scale = type_info[1][:-1].split(',') 150 | precision = int(precision) 151 | scale = int(scale) 152 | 153 | out = [col_name, base_type, None, precision, scale] 154 | output[col_name] = tuple(out) 155 | 156 | return output 157 | 158 | def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: 159 | try: 160 | return super().query_table_schema(path) 161 | except RuntimeError: 162 | try: 163 | return self.query_external_table_schema(path) 164 | except RuntimeError: 165 | return self.query_pg_get_cols() 166 | 167 | def _normalize_table_path(self, path: DbPath) -> DbPath: 168 | if len(path) == 1: 169 | return None, self.default_schema, path[0] 170 | elif len(path) == 2: 171 | return None, path[0], path[1] 172 | elif len(path) == 3: 173 | return path 174 | 175 | raise ValueError( 176 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" 177 | ) 178 | -------------------------------------------------------------------------------- /sqeleton/databases/oracle.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | from ..utils import match_regexps 4 | from ..abcs.database_types import ( 5 | Decimal, 6 | Float, 7 | Text, 8 | DbPath, 9 | TemporalType, 10 | ColType, 11 | DbTime, 12 | ColType_UUID, 13 | Timestamp, 14 | TimestampTZ, 15 | FractionalType, 16 | ) 17 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema 18 | from ..abcs import Compilable 19 | from ..queries import this, table, SKIP 20 | from .base import ( 21 | BaseDialect, 22 | Mixin_OptimizerHints, 23 | ThreadedDatabase, 24 | import_helper, 25 | ConnectError, 26 | QueryError, 27 | Mixin_RandomSample, 28 | ) 29 | from .base import TIMESTAMP_PRECISION_POS 30 | 31 | SESSION_TIME_ZONE = None # Changed by the tests 32 | 33 | 34 | @import_helper("oracle") 35 | def import_oracle(): 36 | import cx_Oracle 37 | 38 | return cx_Oracle 39 | 40 | 41 | class Mixin_MD5(AbstractMixin_MD5): 42 | def md5_as_int(self, s: str) -> str: 43 | # standard_hash is faster than DBMS_CRYPTO.Hash 44 | # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? 45 | return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" 46 | 47 | 48 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 49 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 50 | # Cast is necessary for correct MD5 (trimming not enough) 51 | return f"CAST(TRIM({value}) AS VARCHAR(36))" 52 | 53 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 54 | if coltype.rounds: 55 | return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" 56 | 57 | if coltype.precision > 0: 58 | truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" 59 | else: 60 | truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" 61 | return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" 62 | 63 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 64 | # FM999.9990 65 | format_str = "FM" + "9" * (38 - coltype.precision) 66 | if coltype.precision: 67 | format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" 68 | return f"to_char({value}, '{format_str}')" 69 | 70 | 71 | class Mixin_Schema(AbstractMixin_Schema): 72 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 73 | return ( 74 | table("ALL_TABLES") 75 | .where( 76 | this.OWNER == table_schema, 77 | this.TABLE_NAME.like(like) if like is not None else SKIP, 78 | ) 79 | .select(table_name=this.TABLE_NAME) 80 | ) 81 | 82 | 83 | class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): 84 | name = "Oracle" 85 | SUPPORTS_PRIMARY_KEY = True 86 | SUPPORTS_INDEXES = True 87 | TYPE_CLASSES: Dict[str, type] = { 88 | "NUMBER": Decimal, 89 | "FLOAT": Float, 90 | # Text 91 | "CHAR": Text, 92 | "NCHAR": Text, 93 | "NVARCHAR2": Text, 94 | "VARCHAR2": Text, 95 | "DATE": Timestamp, 96 | } 97 | ROUNDS_ON_PREC_LOSS = True 98 | PLACEHOLDER_TABLE = "DUAL" 99 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 100 | 101 | def quote(self, s: str): 102 | return f'"{s}"' 103 | 104 | def to_string(self, s: str): 105 | return f"cast({s} as varchar(1024))" 106 | 107 | def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): 108 | if offset: 109 | raise NotImplementedError("No support for OFFSET in query") 110 | 111 | return f"FETCH NEXT {limit} ROWS ONLY" 112 | 113 | def concat(self, items: List[str]) -> str: 114 | joined_exprs = " || ".join(items) 115 | return f"({joined_exprs})" 116 | 117 | def timestamp_value(self, t: DbTime) -> str: 118 | return "timestamp '%s'" % t.isoformat(" ") 119 | 120 | def random(self) -> str: 121 | return "dbms_random.value" 122 | 123 | def is_distinct_from(self, a: str, b: str) -> str: 124 | return f"DECODE({a}, {b}, 1, 0) = 0" 125 | 126 | def type_repr(self, t) -> str: 127 | try: 128 | return { 129 | str: "VARCHAR(1024)", 130 | }[t] 131 | except KeyError: 132 | return super().type_repr(t) 133 | 134 | def constant_values(self, rows) -> str: 135 | return " UNION ALL ".join( 136 | "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows 137 | ) 138 | 139 | def explain_as_text(self, query: str) -> str: 140 | raise NotImplementedError("Explain not yet implemented in Oracle") 141 | 142 | def parse_type( 143 | self, 144 | table_path: DbPath, 145 | col_name: str, 146 | type_repr: str, 147 | datetime_precision: int = None, 148 | numeric_precision: int = None, 149 | numeric_scale: int = None, 150 | ) -> ColType: 151 | regexps = { 152 | r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, 153 | r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, 154 | r"TIMESTAMP\((\d)\)": Timestamp, 155 | } 156 | 157 | for m, t_cls in match_regexps(regexps, type_repr): 158 | precision = int(m.group(1)) 159 | return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) 160 | 161 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) 162 | 163 | def set_timezone_to_utc(self) -> str: 164 | return "ALTER SESSION SET TIME_ZONE = 'UTC'" 165 | 166 | def current_timestamp(self) -> str: 167 | return "LOCALTIMESTAMP" 168 | 169 | 170 | class Oracle(ThreadedDatabase): 171 | dialect = Dialect() 172 | CONNECT_URI_HELP = "oracle://:@/" 173 | CONNECT_URI_PARAMS = ["database?"] 174 | 175 | def __init__(self, *, host, database, thread_count, **kw): 176 | self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) 177 | 178 | self.default_schema = kw.get("user").upper() 179 | 180 | super().__init__(thread_count=thread_count) 181 | 182 | def 
create_connection(self): 183 | self._oracle = import_oracle() 184 | try: 185 | c = self._oracle.connect(**self.kwargs) 186 | if SESSION_TIME_ZONE: 187 | c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") 188 | return c 189 | except Exception as e: 190 | raise ConnectError(*e.args) from e 191 | 192 | def _query_cursor(self, c, sql_code: str): 193 | try: 194 | return super()._query_cursor(c, sql_code) 195 | except self._oracle.DatabaseError as e: 196 | raise QueryError(e) 197 | 198 | def select_table_schema(self, path: DbPath) -> str: 199 | schema, name = self._normalize_table_path(path) 200 | 201 | return ( 202 | f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" 203 | f" FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'" 204 | ) 205 | -------------------------------------------------------------------------------- /sqeleton/databases/clickhouse.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | from .base import ( 4 | MD5_HEXDIGITS, 5 | CHECKSUM_HEXDIGITS, 6 | TIMESTAMP_PRECISION_POS, 7 | BaseDialect, 8 | ThreadedDatabase, 9 | import_helper, 10 | ConnectError, 11 | Mixin_RandomSample, 12 | ) 13 | from ..abcs.database_types import ( 14 | ColType, 15 | Decimal, 16 | Float, 17 | Integer, 18 | FractionalType, 19 | Native_UUID, 20 | TemporalType, 21 | Text, 22 | Timestamp, 23 | Boolean, 24 | ) 25 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 26 | 27 | # https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database 28 | DEFAULT_DATABASE = "default" 29 | 30 | 31 | @import_helper("clickhouse") 32 | def import_clickhouse(): 33 | import clickhouse_driver 34 | 35 | return clickhouse_driver 36 | 37 | 38 | class Mixin_MD5(AbstractMixin_MD5): 39 | def md5_as_int(self, s: str) -> str: 40 | substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS 41 | return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" 42 | 43 | 44 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 45 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 46 | # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. 47 | # For example: 48 | # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 49 | # select toString(toDecimal128(1.00, 2)); -- the result is 1 50 | # So, we should use some custom approach to save these trailing zeros. 51 | # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. 52 | # For examples above it looks like: 53 | # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 54 | # After that, cut an extra symbol from the string, i.e. 1.101 -> 1.10 55 | # So, the algorithm is: 56 | # 1. Cast to decimal with precision + 1 57 | # 2. Add a small value 10^(-precision-1) 58 | # 3. Cast the result to string 59 | # 4. Drop the extra digit from the string. To do that, we need to slice the string 60 | # with length = digits in an integer part + 1 (symbol of ".") + precision 61 | 62 | if coltype.precision == 0: 63 | return self.to_string(f"round({value})") 64 | 65 | precision = coltype.precision 66 | # TODO: too complex, is there better performance way? 
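        # Added walk-through (not in the original file), tracing the expression below for
        # value = 1.10 and precision = 2:
        #   round(abs(1.10), 2)                              -> 1.10
        #   toDecimal128(1.10, 3) + toDecimal128(0.001, 3)   -> 1.101   (guard digit added)
        #   toString(...)                                    -> '1.101'
        #   left('1.101', greatest(floor(log10(1.10)) + 1, 1) + 1 + 2) = left('1.101', 4) -> '1.10'
        #   if(value >= 0, '', '-') restores the sign, so the trailing zero survives the cast.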
67 | value = f""" 68 | if({value} >= 0, '', '-') || left( 69 | toString( 70 | toDecimal128( 71 | round(abs({value}), {precision}), 72 | {precision} + 1 73 | ) 74 | + 75 | toDecimal128( 76 | exp10(-{precision + 1}), 77 | {precision} + 1 78 | ) 79 | ), 80 | toUInt8( 81 | greatest( 82 | floor(log10(abs({value}))) + 1, 83 | 1 84 | ) 85 | ) + 1 + {precision} 86 | ) 87 | """ 88 | return value 89 | 90 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 91 | prec = coltype.precision 92 | if coltype.rounds: 93 | timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" 94 | return self.to_string(timestamp) 95 | 96 | fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" 97 | fractional = f"lpad({self.to_string(fractional)}, 6, '0')" 98 | value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" 99 | return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" 100 | 101 | 102 | class Dialect(BaseDialect): 103 | name = "Clickhouse" 104 | ROUNDS_ON_PREC_LOSS = False 105 | TYPE_CLASSES = { 106 | "Int8": Integer, 107 | "Int16": Integer, 108 | "Int32": Integer, 109 | "Int64": Integer, 110 | "Int128": Integer, 111 | "Int256": Integer, 112 | "UInt8": Integer, 113 | "UInt16": Integer, 114 | "UInt32": Integer, 115 | "UInt64": Integer, 116 | "UInt128": Integer, 117 | "UInt256": Integer, 118 | "Float32": Float, 119 | "Float64": Float, 120 | "Decimal": Decimal, 121 | "UUID": Native_UUID, 122 | "String": Text, 123 | "FixedString": Text, 124 | "DateTime": Timestamp, 125 | "DateTime64": Timestamp, 126 | "Bool": Boolean, 127 | } 128 | MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 129 | 130 | def quote(self, s: str) -> str: 131 | return f'"{s}"' 132 | 133 | def to_string(self, s: str) -> str: 134 | return f"toString({s})" 135 | 136 | def _convert_db_precision_to_digits(self, p: int) -> int: 137 | # Done the same as for PostgreSQL but need to rewrite in another way 138 | # because it does not help for float with a big integer part. 
139 | return super()._convert_db_precision_to_digits(p) - 2 140 | 141 | def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: 142 | nullable_prefix = "Nullable(" 143 | if type_repr.startswith(nullable_prefix): 144 | type_repr = type_repr[len(nullable_prefix) :].rstrip(")") 145 | 146 | if type_repr.startswith("Decimal"): 147 | type_repr = "Decimal" 148 | elif type_repr.startswith("FixedString"): 149 | type_repr = "FixedString" 150 | elif type_repr.startswith("DateTime64"): 151 | type_repr = "DateTime64" 152 | 153 | return self.TYPE_CLASSES.get(type_repr) 154 | 155 | # def timestamp_value(self, t: DbTime) -> str: 156 | # # return f"'{t}'" 157 | # return f"'{str(t)[:19]}'" 158 | 159 | def set_timezone_to_utc(self) -> str: 160 | raise NotImplementedError() 161 | 162 | def current_timestamp(self) -> str: 163 | return "now()" 164 | 165 | 166 | class Clickhouse(ThreadedDatabase): 167 | dialect = Dialect() 168 | CONNECT_URI_HELP = "clickhouse://:@/" 169 | CONNECT_URI_PARAMS = ["database?"] 170 | 171 | def __init__(self, *, thread_count: int, **kw): 172 | super().__init__(thread_count=thread_count) 173 | 174 | self._args = kw 175 | # In Clickhouse database and schema are the same 176 | self.default_schema = kw.get("database", DEFAULT_DATABASE) 177 | 178 | def create_connection(self): 179 | clickhouse = import_clickhouse() 180 | 181 | class SingleConnection(clickhouse.dbapi.connection.Connection): 182 | """Not thread-safe connection to Clickhouse""" 183 | 184 | def cursor(self, cursor_factory=None): 185 | if not len(self.cursors): 186 | _ = super().cursor() 187 | return self.cursors[0] 188 | 189 | try: 190 | return SingleConnection(**self._args) 191 | except clickhouse.OperationError as e: 192 | raise ConnectError(*e.args) from e 193 | 194 | @property 195 | def is_autocommit(self) -> bool: 196 | return True 197 | -------------------------------------------------------------------------------- /sqeleton/databases/databricks.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Sequence 3 | import logging 4 | 5 | from ..abcs.database_types import ( 6 | Integer, 7 | Float, 8 | Decimal, 9 | Timestamp, 10 | Text, 11 | TemporalType, 12 | NumericType, 13 | DbPath, 14 | ColType, 15 | UnknownColType, 16 | Boolean, 17 | ) 18 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 19 | from .base import ( 20 | MD5_HEXDIGITS, 21 | CHECKSUM_HEXDIGITS, 22 | BaseDialect, 23 | ThreadedDatabase, 24 | import_helper, 25 | parse_table_name, 26 | Mixin_RandomSample, 27 | ) 28 | 29 | 30 | @import_helper(text="You can install it using 'pip install databricks-sql-connector'") 31 | def import_databricks(): 32 | import databricks.sql 33 | 34 | return databricks 35 | 36 | 37 | class Mixin_MD5(AbstractMixin_MD5): 38 | def md5_as_int(self, s: str) -> str: 39 | return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" 40 | 41 | 42 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 43 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 44 | """Databricks timestamp contains no more than 6 digits in precision""" 45 | 46 | if coltype.rounds: 47 | timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" 48 | return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" 49 | 50 | precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) 51 | return 
f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" 52 | 53 | def normalize_number(self, value: str, coltype: NumericType) -> str: 54 | value = f"cast({value} as decimal(38, {coltype.precision}))" 55 | if coltype.precision > 0: 56 | value = f"format_number({value}, {coltype.precision})" 57 | return f"replace({self.to_string(value)}, ',', '')" 58 | 59 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 60 | return self.to_string(f"cast ({value} as int)") 61 | 62 | 63 | class Dialect(BaseDialect): 64 | name = "Databricks" 65 | ROUNDS_ON_PREC_LOSS = True 66 | TYPE_CLASSES = { 67 | # Numbers 68 | "INT": Integer, 69 | "SMALLINT": Integer, 70 | "TINYINT": Integer, 71 | "BIGINT": Integer, 72 | "FLOAT": Float, 73 | "DOUBLE": Float, 74 | "DECIMAL": Decimal, 75 | # Timestamps 76 | "TIMESTAMP": Timestamp, 77 | # Text 78 | "STRING": Text, 79 | # Boolean 80 | "BOOLEAN": Boolean, 81 | } 82 | MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 83 | 84 | def quote(self, s: str): 85 | return f"`{s}`" 86 | 87 | def to_string(self, s: str) -> str: 88 | return f"cast({s} as string)" 89 | 90 | def _convert_db_precision_to_digits(self, p: int) -> int: 91 | # Subtracting 2 due to wierd precision issues 92 | return max(super()._convert_db_precision_to_digits(p) - 2, 0) 93 | 94 | def set_timezone_to_utc(self) -> str: 95 | return "SET TIME ZONE 'UTC'" 96 | 97 | 98 | class Databricks(ThreadedDatabase): 99 | dialect = Dialect() 100 | CONNECT_URI_HELP = "databricks://:@/" 101 | CONNECT_URI_PARAMS = ["catalog", "schema"] 102 | 103 | def __init__(self, *, thread_count, **kw): 104 | logging.getLogger("databricks.sql").setLevel(logging.WARNING) 105 | 106 | self._args = kw 107 | self.default_schema = kw.get("schema", "default") 108 | self.catalog = self._args.get("catalog", "hive_metastore") 109 | super().__init__(thread_count=thread_count) 110 | 111 | def create_connection(self): 112 | databricks = import_databricks() 113 | 114 | try: 115 | return databricks.sql.connect( 116 | server_hostname=self._args["server_hostname"], 117 | http_path=self._args["http_path"], 118 | access_token=self._args["access_token"], 119 | catalog=self.catalog, 120 | ) 121 | except databricks.sql.exc.Error as e: 122 | raise ConnectionError(*e.args) from e 123 | 124 | def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: 125 | # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. 126 | # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html 127 | # So, to obtain information about schema, we should use another approach. 
128 | 129 | conn = self.create_connection() 130 | 131 | catalog, schema, table = self._normalize_table_path(path) 132 | with conn.cursor() as cursor: 133 | cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) 134 | try: 135 | rows = cursor.fetchall() 136 | finally: 137 | conn.close() 138 | if not rows: 139 | raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") 140 | 141 | d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} 142 | assert len(d) == len(rows) 143 | return d 144 | 145 | def _process_table_schema( 146 | self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None 147 | ): 148 | accept = {i.lower() for i in filter_columns} 149 | rows = [row for name, row in raw_schema.items() if name.lower() in accept] 150 | 151 | resulted_rows = [] 152 | for row in rows: 153 | row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] 154 | type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) 155 | 156 | if issubclass(type_cls, Integer): 157 | row = (row[0], row_type, None, None, 0) 158 | 159 | elif issubclass(type_cls, Float): 160 | numeric_precision = math.ceil(row[2] / math.log(2, 10)) 161 | row = (row[0], row_type, None, numeric_precision, None) 162 | 163 | elif issubclass(type_cls, Decimal): 164 | items = row[1][8:].rstrip(")").split(",") 165 | numeric_precision, numeric_scale = int(items[0]), int(items[1]) 166 | row = (row[0], row_type, None, numeric_precision, numeric_scale) 167 | 168 | elif issubclass(type_cls, Timestamp): 169 | row = (row[0], row_type, row[2], None, None) 170 | 171 | else: 172 | row = (row[0], row_type, None, None, None) 173 | 174 | resulted_rows.append(row) 175 | 176 | col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} 177 | 178 | self._refine_coltypes(path, col_dict, where) 179 | return col_dict 180 | 181 | def parse_table_name(self, name: str) -> DbPath: 182 | path = parse_table_name(name) 183 | return tuple(i for i in self._normalize_table_path(path) if i is not None) 184 | 185 | @property 186 | def is_autocommit(self) -> bool: 187 | return True 188 | 189 | def _normalize_table_path(self, path: DbPath) -> DbPath: 190 | if len(path) == 1: 191 | return self.catalog, self.default_schema, path[0] 192 | elif len(path) == 2: 193 | return self.catalog, path[0], path[1] 194 | elif len(path) == 3: 195 | return path 196 | 197 | raise ValueError( 198 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or catalog.schema.table" 199 | ) 200 | -------------------------------------------------------------------------------- /sqeleton/databases/bigquery.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from ..abcs.database_types import ( 3 | Timestamp, 4 | Datetime, 5 | Integer, 6 | Decimal, 7 | Float, 8 | Text, 9 | DbPath, 10 | FractionalType, 11 | TemporalType, 12 | Boolean, 13 | ) 14 | from ..abcs.mixins import ( 15 | AbstractMixin_MD5, 16 | AbstractMixin_NormalizeValue, 17 | AbstractMixin_Schema, 18 | AbstractMixin_TimeTravel, 19 | ) 20 | from ..abcs import Compilable 21 | from ..queries import this, table, SKIP, code 22 | from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query, QueryResult 23 | from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample 24 | 25 | 26 | @import_helper(text="Please install BigQuery and configure your google-cloud access.") 27 | def import_bigquery(): 28 | from google.cloud import bigquery 29 | 30 | return bigquery 31 | 32 | 33 | class Mixin_MD5(AbstractMixin_MD5): 34 | def md5_as_int(self, s: str) -> str: 35 | return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" 36 | 37 | 38 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 39 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 40 | if coltype.rounds: 41 | timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" 42 | return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" 43 | 44 | if coltype.precision == 0: 45 | return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})" 46 | elif coltype.precision == 6: 47 | return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" 48 | 49 | timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" 50 | return ( 51 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 52 | ) 53 | 54 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 55 | return f"format('%.{coltype.precision}f', {value})" 56 | 57 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 58 | return self.to_string(f"cast({value} as int)") 59 | 60 | 61 | class Mixin_Schema(AbstractMixin_Schema): 62 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 63 | return ( 64 | table(table_schema, "INFORMATION_SCHEMA", "TABLES") 65 | .where( 66 | this.table_schema == table_schema, 67 | this.table_name.like(like) if like is not None else SKIP, 68 | this.table_type == "BASE TABLE", 69 | ) 70 | .select(this.table_name) 71 | ) 72 | 73 | 74 | class Mixin_TimeTravel(AbstractMixin_TimeTravel): 75 | def time_travel( 76 | self, 77 | table: Compilable, 78 | before: bool = False, 79 | timestamp: Compilable = None, 80 | offset: Compilable = None, 81 | statement: Compilable = None, 82 | ) -> Compilable: 83 | if before: 84 | raise NotImplementedError("before=True not supported for BigQuery time-travel") 85 | 86 | if statement is not None: 87 | raise NotImplementedError("BigQuery time-travel doesn't support querying by statement id") 88 | 89 | if timestamp is not None: 90 | assert offset is None 91 | return code("{table} FOR SYSTEM_TIME AS OF {timestamp}", table=table, timestamp=timestamp) 92 | 93 | assert offset is not None 94 | return code( 95 | "{table} FOR SYSTEM_TIME AS OF 
TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {offset} HOUR);", 96 | table=table, 97 | offset=offset, 98 | ) 99 | 100 | 101 | class Dialect(BaseDialect, Mixin_Schema): 102 | name = "BigQuery" 103 | ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation 104 | TYPE_CLASSES = { 105 | # Dates 106 | "TIMESTAMP": Timestamp, 107 | "DATETIME": Datetime, 108 | # Numbers 109 | "INT64": Integer, 110 | "INT32": Integer, 111 | "NUMERIC": Decimal, 112 | "BIGNUMERIC": Decimal, 113 | "FLOAT64": Float, 114 | "FLOAT32": Float, 115 | # Text 116 | "STRING": Text, 117 | # Boolean 118 | "BOOL": Boolean, 119 | } 120 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} 121 | 122 | def random(self) -> str: 123 | return "RAND()" 124 | 125 | def quote(self, s: str): 126 | return f"`{s}`" 127 | 128 | def to_string(self, s: str): 129 | return f"cast({s} as string)" 130 | 131 | def type_repr(self, t) -> str: 132 | try: 133 | return {str: "STRING", float: "FLOAT64"}[t] 134 | except KeyError: 135 | return super().type_repr(t) 136 | 137 | def set_timezone_to_utc(self) -> str: 138 | raise NotImplementedError() 139 | 140 | 141 | class BigQuery(Database): 142 | CONNECT_URI_HELP = "bigquery:///" 143 | CONNECT_URI_PARAMS = ["dataset"] 144 | dialect = Dialect() 145 | 146 | def __init__(self, project, *, dataset, **kw): 147 | bigquery = import_bigquery() 148 | 149 | self._client = bigquery.Client(project, **kw) 150 | self.project = project 151 | self.dataset = dataset 152 | 153 | self.default_schema = dataset 154 | 155 | def _normalize_returned_value(self, value): 156 | if isinstance(value, bytes): 157 | return value.decode() 158 | return value 159 | 160 | def _query_atom(self, sql_code: str): 161 | from google.cloud import bigquery 162 | 163 | try: 164 | result = self._client.query(sql_code).result() 165 | columns = [c.name for c in result.schema] 166 | rows = list(result) 167 | except Exception as e: 168 | msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s" 169 | raise ConnectError(msg % (sql_code, e)) 170 | 171 | if rows and isinstance(rows[0], bigquery.table.Row): 172 | rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows] 173 | return QueryResult(rows, columns) 174 | 175 | def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: 176 | return apply_query(self._query_atom, sql_code) 177 | 178 | def close(self): 179 | super().close() 180 | self._client.close() 181 | 182 | def select_table_schema(self, path: DbPath) -> str: 183 | project, schema, name = self._normalize_table_path(path) 184 | return ( 185 | "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale " 186 | f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS " 187 | f"WHERE table_name = '{name}' AND table_schema = '{schema}'" 188 | ) 189 | 190 | def query_table_unique_columns(self, path: DbPath) -> List[str]: 191 | return [] 192 | 193 | def _normalize_table_path(self, path: DbPath) -> DbPath: 194 | if len(path) == 0: 195 | raise ValueError(f"{self.name}: Bad table path for {self}: ()") 196 | elif len(path) == 1: 197 | return (self.project, self.default_schema, path[0]) 198 | elif len(path) == 2: 199 | return (self.project,) + path 200 | elif len(path) == 3: 201 | return path 202 | else: 203 | raise ValueError( 204 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: [project.]schema.table" 205 | ) 206 | 207 | def parse_table_name(self, name: str) -> DbPath: 208 | path = parse_table_name(name) 209 | return tuple(i for i in self._normalize_table_path(path) if i is not None) 210 | 211 | @property 212 | def is_autocommit(self) -> bool: 213 | return True 214 | -------------------------------------------------------------------------------- /sqeleton/databases/snowflake.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import logging 3 | 4 | from ..abcs.database_types import ( 5 | Timestamp, 6 | TimestampTZ, 7 | Decimal, 8 | Float, 9 | Text, 10 | FractionalType, 11 | TemporalType, 12 | DbPath, 13 | Boolean, 14 | Date, 15 | ) 16 | from ..abcs.mixins import ( 17 | AbstractMixin_MD5, 18 | AbstractMixin_NormalizeValue, 19 | AbstractMixin_Schema, 20 | AbstractMixin_TimeTravel, 21 | ) 22 | from ..abcs import Compilable 23 | from sqeleton.queries import table, this, SKIP, code 24 | from .base import ( 25 | BaseDialect, 26 | ConnectError, 27 | Database, 28 | import_helper, 29 | CHECKSUM_MASK, 30 | ThreadLocalInterpreter, 31 | Mixin_RandomSample, 32 | ) 33 | 34 | 35 | @import_helper("snowflake") 36 | def import_snowflake(): 37 | import snowflake.connector 38 | from cryptography.hazmat.primitives import serialization 39 | from cryptography.hazmat.backends import default_backend 40 | 41 | return snowflake, serialization, default_backend 42 | 43 | 44 | class Mixin_MD5(AbstractMixin_MD5): 45 | def md5_as_int(self, s: str) -> str: 46 | return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" 47 | 48 | 49 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 50 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 51 | if coltype.rounds: 52 | timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" 53 | else: 54 | timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" 55 | 56 | return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" 57 | 58 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 59 | return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") 60 | 61 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 62 | return self.to_string(f"{value}::int") 63 | 64 | 65 | class Mixin_Schema(AbstractMixin_Schema): 66 | def table_information(self) -> Compilable: 67 | return table("INFORMATION_SCHEMA", "TABLES") 68 | 69 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 70 | return ( 71 | self.table_information() 72 | .where( 73 | this.TABLE_SCHEMA == table_schema, 74 | this.TABLE_NAME.like(like) if like is not None else SKIP, 75 | this.TABLE_TYPE == "BASE TABLE", 76 | ) 77 | .select(table_name=this.TABLE_NAME) 78 | ) 79 | 80 | 81 | class Mixin_TimeTravel(AbstractMixin_TimeTravel): 82 | def time_travel( 83 | self, 84 | table: Compilable, 85 | before: bool = False, 86 | timestamp: Compilable = None, 87 | offset: Compilable = None, 88 | statement: Compilable = None, 89 | ) -> Compilable: 90 | at_or_before = "AT" if before else "BEFORE" 91 | if timestamp is not None: 92 | assert offset is None and statement is None 93 | key = "timestamp" 94 | value = timestamp 95 | elif offset is not None: 96 | assert statement is None 97 | key = "offset" 98 | value = offset 99 | else: 100 | assert statement is not None 101 | key = "statement" 102 | value = statement 103 
| 104 | return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value) 105 | 106 | 107 | class Dialect(BaseDialect, Mixin_Schema): 108 | name = "Snowflake" 109 | ROUNDS_ON_PREC_LOSS = False 110 | TYPE_CLASSES = { 111 | # Timestamps 112 | "TIMESTAMP_NTZ": Timestamp, 113 | "TIMESTAMP_LTZ": Timestamp, 114 | "TIMESTAMP_TZ": TimestampTZ, 115 | "DATE": Date, 116 | # Numbers 117 | "NUMBER": Decimal, 118 | "FLOAT": Float, 119 | # Text 120 | "TEXT": Text, 121 | # Boolean 122 | "BOOLEAN": Boolean, 123 | } 124 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} 125 | 126 | def explain_as_text(self, query: str) -> str: 127 | return f"EXPLAIN USING TEXT {query}" 128 | 129 | def quote(self, s: str): 130 | return f'"{s}"' 131 | 132 | def to_string(self, s: str): 133 | return f"cast({s} as string)" 134 | 135 | def table_information(self) -> Compilable: 136 | return table("INFORMATION_SCHEMA", "TABLES") 137 | 138 | def set_timezone_to_utc(self) -> str: 139 | return "ALTER SESSION SET TIMEZONE = 'UTC'" 140 | 141 | def optimizer_hints(self, hints: str) -> str: 142 | raise NotImplementedError("Optimizer hints not yet implemented in snowflake") 143 | 144 | def type_repr(self, t) -> str: 145 | if isinstance(t, TimestampTZ): 146 | return f"timestamp_tz({t.precision})" 147 | return super().type_repr(t) 148 | 149 | 150 | class Snowflake(Database): 151 | dialect = Dialect() 152 | CONNECT_URI_HELP = "snowflake://:@//?warehouse=" 153 | CONNECT_URI_PARAMS = ["database", "schema"] 154 | CONNECT_URI_KWPARAMS = ["warehouse"] 155 | 156 | def __init__(self, *, schema: str, **kw): 157 | snowflake, serialization, default_backend = import_snowflake() 158 | logging.getLogger("snowflake.connector").setLevel(logging.WARNING) 159 | 160 | # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state 161 | # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145 162 | logging.getLogger("snowflake.connector.network").disabled = True 163 | 164 | assert '"' not in schema, "Schema name should not contain quotes!" 165 | # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. 
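        # Illustrative example (not in the original file): key-pair authentication might be set up
        # with kwargs along the lines of
        #   Snowflake(account="...", user="...", warehouse="...", database="...", schema="PUBLIC",
        #             key="/path/to/rsa_key.p8", private_key_passphrase="...")
        # where 'key' points at a PEM-encoded private key; 'password' and 'key' are mutually exclusive.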
166 | if "key" in kw: 167 | with open(kw.get("key"), "rb") as key: 168 | if "password" in kw: 169 | raise ConnectError("Cannot use password and key at the same time") 170 | if kw.get("private_key_passphrase"): 171 | encoded_passphrase = kw.get("private_key_passphrase").encode() 172 | else: 173 | encoded_passphrase = None 174 | p_key = serialization.load_pem_private_key( 175 | key.read(), 176 | password=encoded_passphrase, 177 | backend=default_backend(), 178 | ) 179 | 180 | kw["private_key"] = p_key.private_bytes( 181 | encoding=serialization.Encoding.DER, 182 | format=serialization.PrivateFormat.PKCS8, 183 | encryption_algorithm=serialization.NoEncryption(), 184 | ) 185 | 186 | self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw) 187 | 188 | self.default_schema = schema 189 | 190 | def close(self): 191 | super().close() 192 | self._conn.close() 193 | 194 | def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): 195 | "Uses the standard SQL cursor interface" 196 | return self._query_conn(self._conn, sql_code) 197 | 198 | def select_table_schema(self, path: DbPath) -> str: 199 | """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" 200 | database, schema, name = self._normalize_table_path(path) 201 | info_schema_path = ["information_schema", "columns"] 202 | if database: 203 | info_schema_path.insert(0, database) 204 | 205 | return ( 206 | "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " 207 | f"FROM {'.'.join(info_schema_path)} " 208 | f"WHERE table_name = '{name}' AND table_schema = '{schema}'" 209 | ) 210 | 211 | def _normalize_table_path(self, path: DbPath) -> DbPath: 212 | if len(path) == 1: 213 | return None, self.default_schema, path[0] 214 | elif len(path) == 2: 215 | return None, path[0], path[1] 216 | elif len(path) == 3: 217 | return path 218 | 219 | raise ValueError( 220 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" 221 | ) 222 | 223 | @property 224 | def is_autocommit(self) -> bool: 225 | return True 226 | 227 | def query_table_unique_columns(self, path: DbPath) -> List[str]: 228 | return [] 229 | -------------------------------------------------------------------------------- /sqeleton/conn_editor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | from concurrent.futures import ThreadPoolExecutor 4 | import toml 5 | 6 | from runtype import dataclass 7 | 8 | try: 9 | import textual 10 | except ModuleNotFoundError: 11 | raise ModuleNotFoundError( 12 | "Error: Cannot find the TUI library 'textual'. Please install it using `pip install sqeleton[tui]`" 13 | ) 14 | 15 | from textual.app import App, ComposeResult 16 | from textual.containers import Vertical, Container, Horizontal 17 | from textual.widgets import Header, Footer, Input, Label, Button, Static, ListView, ListItem 18 | 19 | from textual_select import Select 20 | 21 | from .databases._connect import DATABASE_BY_SCHEME 22 | from . 
import connect 23 | 24 | 25 | class ContentSwapper(Container): 26 | def __init__(self, *initial_widgets, **kw): 27 | self.container = Container(*initial_widgets) 28 | super().__init__(**kw) 29 | 30 | def compose(self): 31 | yield self.container 32 | 33 | def new_content(self, *new_widgets): 34 | for c in self.container.children: 35 | c.remove() 36 | 37 | self.container.mount(*new_widgets) 38 | 39 | 40 | def test_connect(connect_dict): 41 | conn = connect(connect_dict) 42 | assert conn.query("select 1+1", int) == 2 43 | 44 | 45 | @dataclass 46 | class Config: 47 | config: dict 48 | 49 | @property 50 | def databases(self): 51 | if "database" not in self.config: 52 | self.config["database"] = {} 53 | assert isinstance(self.config["database"], dict) 54 | return self.config["database"] 55 | 56 | def add_database(self, name: str, db: dict): 57 | assert isinstance(db, dict) 58 | self.config["database"][name] = db 59 | 60 | def remove_database(self, name): 61 | del self.config["database"][name] 62 | 63 | 64 | class EditConnection(Vertical): 65 | def __init__(self, db_name: str, config: Config) -> None: 66 | self._db_name = db_name 67 | self._config = config 68 | super().__init__() 69 | 70 | def compose(self): 71 | self.params_container = ContentSwapper(id="params") 72 | self.driver_select = Select( 73 | placeholder="Select a database driver", 74 | search=True, 75 | items=[{"value": k, "text": k} for k in DATABASE_BY_SCHEME], 76 | list_mount="#driver_container", 77 | value=self._config.databases.get(self._db_name, {}).get("driver"), 78 | id="driver", 79 | ) 80 | 81 | yield Vertical( 82 | Label("Connection name:"), 83 | Input(id="conn_name", value=self._db_name, classes="margin-bottom-1"), 84 | Label("Database Driver:"), 85 | self.driver_select, 86 | self.params_container, 87 | id="driver_container", 88 | ) 89 | 90 | self.create_params() 91 | 92 | def create_params(self): 93 | driver = self.driver_select.value 94 | if not driver: 95 | return 96 | db_config = self._config.databases.get(self._db_name, {}) 97 | 98 | db_cls = DATABASE_BY_SCHEME[driver] 99 | # Filter out repetitions, but keep order 100 | base_params = re.findall("<(.*?)>", db_cls.CONNECT_URI_HELP) 101 | params = dict.fromkeys([p.lower() for p in base_params] + [k for k in db_config if k != "driver"]) 102 | 103 | widgets = [] 104 | for p in params: 105 | label = p 106 | if p in ("user", "pass", "port", "dbname"): 107 | label += " (optional)" 108 | widgets.append(Label(f"{label}:")) 109 | p = p.lower() 110 | widgets.append(Input(id="input_" + p, name=p, classes="param", value=db_config.get(p))) 111 | 112 | self.params_container.new_content( 113 | Vertical( 114 | *widgets, 115 | Horizontal( 116 | Button("Test & Save Connection", id="test_and_save"), 117 | Button("Save Connection Without Testing", id="save_without_test"), 118 | ), 119 | Vertical(id="result"), 120 | ) 121 | ) 122 | 123 | def on_select_changed(self, event: Select.Changed) -> None: 124 | driver = str(event.value) 125 | self.driver_select.text = driver 126 | self.driver_select.value = driver 127 | self.create_params() 128 | 129 | def _get_connect_dict(self): 130 | connect_dict = {"driver": self.driver_select.value} 131 | for p in self.query(".param"): 132 | if p.value: 133 | connect_dict[p.name] = p.value 134 | 135 | return connect_dict 136 | 137 | async def on_button_pressed(self, event: Button.Pressed) -> None: 138 | button_id = event.button.id 139 | if button_id == "test_and_save": 140 | connect_dict = self._get_connect_dict() 141 | 142 | result_container = 
self.query_one("#result") 143 | result_container.mount(Label(f"Trying to connect to {connect_dict}")) 144 | 145 | try: 146 | test_connect(connect_dict) 147 | except Exception as e: 148 | error = str(e) 149 | result_container.mount(Label(f"[red]Error: {error}[/red]")) 150 | else: 151 | result_container.mount(Label(f"[green]Success![green]")) 152 | self.save(connect_dict) 153 | 154 | elif button_id == "save_without_test": 155 | connect_dict = self._get_connect_dict() 156 | self.save(connect_dict) 157 | 158 | def save(self, connect_dict): 159 | assert isinstance(connect_dict, dict) 160 | result_container = self.query_one("#result") 161 | result_container.mount(Label(f"Database saved")) 162 | 163 | name = self.query_one("#conn_name").value 164 | self._config.add_database(name, connect_dict) 165 | self.app.config_changed() 166 | 167 | 168 | class ListOfConnections(Vertical): 169 | def __init__(self, config: Config, **kw): 170 | self.config = config 171 | super().__init__(**kw) 172 | 173 | def compose(self) -> ComposeResult: 174 | list_items = [ 175 | ListItem(Label(name, id="list_label_" + name), name=name) for name in self.config.databases.keys() 176 | ] 177 | 178 | self.lv = lv = ListView(*list_items, id="connection_list") 179 | yield lv 180 | 181 | lv.focus() 182 | 183 | def on_list_view_highlighted(self, h: ListView.Highlighted): 184 | name = h.item.name 185 | self.app.query_one("#edit_container").new_content(EditConnection(name, self.config)) 186 | 187 | 188 | class ConnectionEditor(App): 189 | CSS = """ 190 | #conn_list { 191 | display: block; 192 | dock: left; 193 | height: 100%; 194 | max-width: 30%; 195 | margin: 1 196 | } 197 | 198 | #edit { 199 | margin: 1 200 | } 201 | 202 | #test_and_save { 203 | margin: 1 204 | } 205 | #save_without_test { 206 | margin: 1 207 | } 208 | 209 | .margin-bottom-1 { 210 | margin-bottom: 1 211 | } 212 | """ 213 | 214 | BINDINGS = [ 215 | ("q", "quit", "Quit"), 216 | ("d", "toggle_dark", "Toggle dark mode"), 217 | ("insert", "add_conn", "Add new connection"), 218 | ("delete", "del_conn", "Delete selected connection"), 219 | ("t", "test_conn", "Test selected connection"), 220 | ("a", "test_all_conns", "Test all connections"), 221 | ] 222 | 223 | def run(self, toml_path, **kw): 224 | self.toml_path = toml_path 225 | try: 226 | with open(self.toml_path) as f: 227 | self.config = Config(toml.load(f)) 228 | except FileNotFoundError: 229 | self.config = Config({}) 230 | 231 | return super().run(**kw) 232 | 233 | def config_changed(self): 234 | self.list_swapper.new_content(ListOfConnections(self.config)) 235 | with open(self.toml_path, "w") as f: 236 | toml.dump(self.config.config, f) 237 | 238 | def compose(self) -> ComposeResult: 239 | """Create child widgets for the app.""" 240 | self.list_swapper = ContentSwapper(ListOfConnections(self.config), id="conn_list") 241 | self.edit_swapper = ContentSwapper(id="edit_container") 242 | self.edit_swapper.new_content(EditConnection("New", self.config)) 243 | 244 | yield Header() 245 | yield Container(self.list_swapper, self.edit_swapper) 246 | yield Footer() 247 | 248 | def action_toggle_dark(self) -> None: 249 | """An action to toggle dark mode.""" 250 | self.dark = not self.dark 251 | 252 | def action_test_all_conns(self): 253 | return self.action_quit() 254 | 255 | def action_add_conn(self): 256 | self.edit_swapper.new_content(EditConnection("New", self.config)) 257 | 258 | def _selected_connection(self): 259 | connection_list: ListView = self.query_one("#connection_list") 260 | return 
connection_list.children[connection_list.index] 261 | 262 | def action_del_conn(self): 263 | name = self._selected_connection().name 264 | self.config.remove_database(name) 265 | self.config_changed() 266 | 267 | def _test_existing_db(self, name): 268 | label: Label = self.query_one("#list_label_" + name) 269 | label.update(f"{name}🔃") 270 | connect_dict = self.config.databases[name] 271 | try: 272 | test_connect(connect_dict) 273 | except Exception as e: 274 | label.update(f"{name}❌ {str(e)[:16]}") 275 | else: 276 | label.update(f"{name}✅") 277 | 278 | def action_test_conn(self): 279 | name = self._selected_connection().name 280 | t = ThreadPoolExecutor() 281 | t.submit(self._test_existing_db, name) 282 | 283 | def action_test_all_conns(self): 284 | t = ThreadPoolExecutor() 285 | for name in self.config.databases: 286 | t.submit(self._test_existing_db, name) 287 | 288 | 289 | def main(toml_path: str): 290 | app = ConnectionEditor() 291 | app.run(toml_path) 292 | 293 | 294 | if __name__ == "__main__": 295 | main(sys.argv[1]) 296 | -------------------------------------------------------------------------------- /sqeleton/databases/_connect.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Optional, Union, Dict 2 | from itertools import zip_longest 3 | from contextlib import suppress 4 | import dsnparse 5 | import toml 6 | 7 | from runtype import dataclass 8 | 9 | from ..abcs.mixins import AbstractMixin 10 | from ..utils import WeakCache, Self 11 | from .base import Database, ThreadedDatabase 12 | from .postgresql import PostgreSQL 13 | from .mysql import MySQL 14 | from .oracle import Oracle 15 | from .snowflake import Snowflake 16 | from .bigquery import BigQuery 17 | from .redshift import Redshift 18 | from .presto import Presto 19 | from .databricks import Databricks 20 | from .trino import Trino 21 | from .clickhouse import Clickhouse 22 | from .vertica import Vertica 23 | from .duckdb import DuckDB 24 | 25 | 26 | @dataclass 27 | class MatchUriPath: 28 | database_cls: Type[Database] 29 | 30 | def match_path(self, dsn): 31 | help_str = self.database_cls.CONNECT_URI_HELP 32 | params = self.database_cls.CONNECT_URI_PARAMS 33 | kwparams = self.database_cls.CONNECT_URI_KWPARAMS 34 | 35 | dsn_dict = dict(dsn.query) 36 | matches = {} 37 | for param, arg in zip_longest(params, dsn.paths): 38 | if param is None: 39 | raise ValueError(f"Too many parts to path. Expected format: {help_str}") 40 | 41 | optional = param.endswith("?") 42 | param = param.rstrip("?") 43 | 44 | if arg is None: 45 | try: 46 | arg = dsn_dict.pop(param) 47 | except KeyError: 48 | if not optional: 49 | raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") 50 | 51 | arg = None 52 | 53 | assert param and param not in matches 54 | matches[param] = arg 55 | 56 | for param in kwparams: 57 | try: 58 | arg = dsn_dict.pop(param) 59 | except KeyError: 60 | raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") 61 | 62 | assert param and arg and param not in matches, (param, arg, matches.keys()) 63 | matches[param] = arg 64 | 65 | for param, value in dsn_dict.items(): 66 | if param in matches: 67 | raise ValueError( 68 | f"Parameter '{param}' already provided as positional argument. 
Expected format: {help_str}" 69 | ) 70 | 71 | matches[param] = value 72 | 73 | return matches 74 | 75 | 76 | DATABASE_BY_SCHEME = { 77 | "postgresql": PostgreSQL, 78 | "mysql": MySQL, 79 | "oracle": Oracle, 80 | "redshift": Redshift, 81 | "snowflake": Snowflake, 82 | "presto": Presto, 83 | "bigquery": BigQuery, 84 | "databricks": Databricks, 85 | "duckdb": DuckDB, 86 | "trino": Trino, 87 | "clickhouse": Clickhouse, 88 | "vertica": Vertica, 89 | } 90 | 91 | 92 | class Connect: 93 | """Provides methods for connecting to a supported database using a URL or connection dict.""" 94 | 95 | def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): 96 | self.database_by_scheme = database_by_scheme 97 | self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} 98 | self.conn_cache = WeakCache() 99 | 100 | def for_databases(self, *dbs): 101 | database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} 102 | return type(self)(database_by_scheme) 103 | 104 | def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self: 105 | "Extend all the databases with a list of mixins that implement the given abstract mixins." 106 | database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()} 107 | return type(self)(database_by_scheme) 108 | 109 | def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database: 110 | """Connect to the given database uri 111 | 112 | thread_count determines the max number of worker threads per database, 113 | if relevant. None means no limit. 114 | 115 | Parameters: 116 | db_uri (str): The URI for the database to connect 117 | thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) 118 | 119 | Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 120 | 121 | Supported schemes: 122 | - postgresql 123 | - mysql 124 | - oracle 125 | - snowflake 126 | - bigquery 127 | - redshift 128 | - presto 129 | - databricks 130 | - trino 131 | - clickhouse 132 | - vertica 133 | - duckdb 134 | """ 135 | 136 | dsn = dsnparse.parse(db_uri) 137 | if len(dsn.schemes) > 1: 138 | raise NotImplementedError("No support for multiple schemes") 139 | (scheme,) = dsn.schemes 140 | 141 | if scheme == "toml": 142 | toml_path = dsn.path or dsn.host 143 | database = dsn.fragment 144 | if not database: 145 | raise ValueError("Must specify a database name, e.g. 'toml://path#database'. 
") 146 | with open(toml_path) as f: 147 | config = toml.load(f) 148 | try: 149 | conn_dict = config["database"][database] 150 | except KeyError: 151 | raise ValueError(f"Cannot find database config named '{database}'.") 152 | return self.connect_with_dict(conn_dict, thread_count) 153 | 154 | try: 155 | matcher = self.match_uri_path[scheme] 156 | except KeyError: 157 | raise NotImplementedError(f"Scheme '{scheme}' currently not supported") 158 | 159 | cls = matcher.database_cls 160 | 161 | if scheme == "databricks": 162 | assert not dsn.user 163 | kw = {} 164 | kw["access_token"] = dsn.password 165 | kw["http_path"] = dsn.path 166 | kw["server_hostname"] = dsn.host 167 | kw.update(dsn.query) 168 | elif scheme == "duckdb": 169 | kw = {} 170 | kw["filepath"] = dsn.dbname 171 | kw["dbname"] = dsn.user 172 | else: 173 | kw = matcher.match_path(dsn) 174 | 175 | if scheme == "bigquery": 176 | kw["project"] = dsn.host 177 | return cls(**kw) 178 | 179 | if scheme == "snowflake": 180 | kw["account"] = dsn.host 181 | assert not dsn.port 182 | kw["user"] = dsn.user 183 | kw["password"] = dsn.password 184 | else: 185 | kw["host"] = dsn.host 186 | kw["port"] = dsn.port 187 | kw["user"] = dsn.user 188 | if dsn.password: 189 | kw["password"] = dsn.password 190 | 191 | kw = {k: v for k, v in kw.items() if v is not None} 192 | 193 | if issubclass(cls, ThreadedDatabase): 194 | db = cls(thread_count=thread_count, **kw) 195 | else: 196 | db = cls(**kw) 197 | 198 | return self._connection_created(db) 199 | 200 | def connect_with_dict(self, d, thread_count): 201 | d = dict(d) 202 | driver = d.pop("driver") 203 | try: 204 | matcher = self.match_uri_path[driver] 205 | except KeyError: 206 | raise NotImplementedError(f"Driver '{driver}' currently not supported") 207 | 208 | cls = matcher.database_cls 209 | if issubclass(cls, ThreadedDatabase): 210 | db = cls(thread_count=thread_count, **d) 211 | else: 212 | db = cls(**d) 213 | 214 | return self._connection_created(db) 215 | 216 | def _connection_created(self, db): 217 | "Nop function to be overridden by subclasses." 218 | return db 219 | 220 | def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database: 221 | """Connect to a database using the given database configuration. 222 | 223 | Configuration can be given either as a URI string, or as a dict of {option: value}. 224 | 225 | The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. 226 | 227 | thread_count determines the max number of worker threads per database, 228 | if relevant. None means no limit. 229 | 230 | Parameters: 231 | db_conf (str | dict): The configuration for the database to connect. URI or dict. 232 | thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) 233 | shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) 234 | 235 | Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 
236 | 237 | Supported drivers: 238 | - postgresql 239 | - mysql 240 | - oracle 241 | - snowflake 242 | - bigquery 243 | - redshift 244 | - presto 245 | - databricks 246 | - trino 247 | - clickhouse 248 | - vertica 249 | 250 | Example: 251 | >>> connect("mysql://localhost/db") 252 | 253 | >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) 254 | 255 | """ 256 | if shared: 257 | with suppress(KeyError): 258 | conn = self.conn_cache.get(db_conf) 259 | if not conn.is_closed: 260 | return conn 261 | 262 | if isinstance(db_conf, str): 263 | conn = self.connect_to_uri(db_conf, thread_count) 264 | elif isinstance(db_conf, dict): 265 | conn = self.connect_with_dict(db_conf, thread_count) 266 | else: 267 | raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") 268 | 269 | if shared: 270 | self.conn_cache.add(db_conf, conn) 271 | return conn 272 | -------------------------------------------------------------------------------- /sqeleton/utils.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Iterable, 3 | Iterator, 4 | MutableMapping, 5 | Union, 6 | Any, 7 | Sequence, 8 | Dict, 9 | Hashable, 10 | TypeVar, 11 | TYPE_CHECKING, 12 | List, 13 | ) 14 | from abc import abstractmethod 15 | from weakref import ref 16 | import math 17 | import string 18 | import re 19 | from uuid import UUID 20 | from urllib.parse import urlparse 21 | 22 | # -- Common -- 23 | 24 | try: 25 | from typing import Self 26 | except ImportError: 27 | Self = Any 28 | 29 | 30 | class WeakCache: 31 | def __init__(self): 32 | self._cache = {} 33 | 34 | def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable: 35 | if isinstance(k, dict): 36 | return tuple(k.items()) 37 | return k 38 | 39 | def add(self, key: Union[dict, Hashable], value: Any): 40 | key = self._hashable_key(key) 41 | self._cache[key] = ref(value) 42 | 43 | def get(self, key: Union[dict, Hashable]) -> Any: 44 | key = self._hashable_key(key) 45 | 46 | value = self._cache[key]() 47 | if value is None: 48 | del self._cache[key] 49 | raise KeyError(f"Key {key} not found, or no longer a valid reference") 50 | 51 | return value 52 | 53 | 54 | def join_iter(joiner: Any, iterable: Iterable) -> Iterable: 55 | it = iter(iterable) 56 | try: 57 | yield next(it) 58 | except StopIteration: 59 | return 60 | for i in it: 61 | yield joiner 62 | yield i 63 | 64 | 65 | def safezip(*args): 66 | "zip but makes sure all sequences are the same length" 67 | lens = list(map(len, args)) 68 | if len(set(lens)) != 1: 69 | raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") 70 | return zip(*args) 71 | 72 | 73 | def is_uuid(u): 74 | try: 75 | UUID(u) 76 | except ValueError: 77 | return False 78 | return True 79 | 80 | 81 | def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]: 82 | for regexp, v in regexps.items(): 83 | m = re.match(regexp + "$", s) 84 | if m: 85 | yield m, v 86 | 87 | 88 | # -- Schema -- 89 | 90 | V = TypeVar("V") 91 | 92 | 93 | class CaseAwareMapping(MutableMapping[str, V]): 94 | @abstractmethod 95 | def get_key(self, key: str) -> str: 96 | ... 
97 | 98 | def new(self, initial=()): 99 | return type(self)(initial) 100 | 101 | 102 | class CaseInsensitiveDict(CaseAwareMapping): 103 | def __init__(self, initial): 104 | self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} 105 | 106 | def __getitem__(self, key: str) -> V: 107 | return self._dict[key.lower()][1] 108 | 109 | def __iter__(self) -> Iterator[V]: 110 | return iter(self._dict) 111 | 112 | def __len__(self) -> int: 113 | return len(self._dict) 114 | 115 | def __setitem__(self, key: str, value): 116 | k = key.lower() 117 | if k in self._dict: 118 | key = self._dict[k][0] 119 | self._dict[k] = key, value 120 | 121 | def __delitem__(self, key: str): 122 | del self._dict[key.lower()] 123 | 124 | def get_key(self, key: str) -> str: 125 | return self._dict[key.lower()][0] 126 | 127 | def __repr__(self) -> str: 128 | return repr(dict(self.items())) 129 | 130 | 131 | class CaseSensitiveDict(dict, CaseAwareMapping): 132 | def get_key(self, key): 133 | self[key] # Throw KeyError if key doesn't exist 134 | return key 135 | 136 | def as_insensitive(self): 137 | return CaseInsensitiveDict(self) 138 | 139 | 140 | # -- Alphanumerics -- 141 | 142 | alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase 143 | 144 | 145 | class ArithString: 146 | @classmethod 147 | def new(cls, *args, **kw): 148 | return cls(*args, **kw) 149 | 150 | def range(self, other: "ArithString", count: int): 151 | assert isinstance(other, ArithString) 152 | checkpoints = split_space(self.int, other.int, count) 153 | return [self.new(int=i) for i in checkpoints] 154 | 155 | 156 | class ArithUUID(UUID, ArithString): 157 | "A UUID that supports basic arithmetic (add, sub)" 158 | 159 | def __int__(self): 160 | return self.int 161 | 162 | def __add__(self, other: int): 163 | if isinstance(other, int): 164 | return self.new(int=self.int + other) 165 | return NotImplemented 166 | 167 | def __sub__(self, other: Union[UUID, int]): 168 | if isinstance(other, int): 169 | return self.new(int=self.int - other) 170 | elif isinstance(other, UUID): 171 | return self.int - other.int 172 | return NotImplemented 173 | 174 | 175 | def numberToAlphanum(num: int, base: str = alphanums) -> str: 176 | digits = [] 177 | while num > 0: 178 | num, remainder = divmod(num, len(base)) 179 | digits.append(remainder) 180 | return "".join(base[i] for i in digits[::-1]) 181 | 182 | 183 | def alphanumToNumber(alphanum: str, base: str = alphanums) -> int: 184 | num = 0 185 | for c in alphanum: 186 | num = num * len(base) + base.index(c) 187 | return num 188 | 189 | 190 | def justify_alphanums(s1: str, s2: str): 191 | max_len = max(len(s1), len(s2)) 192 | s1 = s1.ljust(max_len) 193 | s2 = s2.ljust(max_len) 194 | return s1, s2 195 | 196 | 197 | def alphanums_to_numbers(s1: str, s2: str): 198 | s1, s2 = justify_alphanums(s1, s2) 199 | n1 = alphanumToNumber(s1) 200 | n2 = alphanumToNumber(s2) 201 | return n1, n2 202 | 203 | 204 | class ArithAlphanumeric(ArithString): 205 | def __init__(self, s: str, max_len=None): 206 | if s is None: 207 | raise ValueError("Alphanum string cannot be None") 208 | if max_len and len(s) > max_len: 209 | raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}") 210 | 211 | for ch in s: 212 | if ch not in alphanums: 213 | raise ValueError(f"Unexpected character {ch} in alphanum string") 214 | 215 | self._str = s 216 | self._max_len = max_len 217 | 218 | # @property 219 | # def int(self): 220 | # return alphanumToNumber(self._str, alphanums) 221 | 222 
| def __str__(self): 223 | s = self._str 224 | if self._max_len: 225 | s = s.rjust(self._max_len, alphanums[0]) 226 | return s 227 | 228 | def __len__(self): 229 | return len(self._str) 230 | 231 | def __repr__(self): 232 | return f'alphanum"{self._str}"' 233 | 234 | def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric": 235 | if isinstance(other, int): 236 | if other != 1: 237 | raise NotImplementedError("not implemented for arbitrary numbers") 238 | num = alphanumToNumber(self._str) 239 | return self.new(numberToAlphanum(num + 1)) 240 | 241 | return NotImplemented 242 | 243 | def range(self, other: "ArithAlphanumeric", count: int): 244 | assert isinstance(other, ArithAlphanumeric) 245 | n1, n2 = alphanums_to_numbers(self._str, other._str) 246 | split = split_space(n1, n2, count) 247 | return [self.new(numberToAlphanum(s)) for s in split] 248 | 249 | def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float: 250 | if isinstance(other, ArithAlphanumeric): 251 | n1, n2 = alphanums_to_numbers(self._str, other._str) 252 | return n1 - n2 253 | 254 | return NotImplemented 255 | 256 | def __ge__(self, other): 257 | if not isinstance(other, type(self)): 258 | return NotImplemented 259 | return self._str >= other._str 260 | 261 | def __lt__(self, other): 262 | if not isinstance(other, type(self)): 263 | return NotImplemented 264 | return self._str < other._str 265 | 266 | def __eq__(self, other): 267 | if not isinstance(other, type(self)): 268 | return NotImplemented 269 | return self._str == other._str 270 | 271 | def new(self, *args, **kw): 272 | return type(self)(*args, **kw, max_len=self._max_len) 273 | 274 | 275 | def number_to_human(n): 276 | millnames = ["", "k", "m", "b"] 277 | n = float(n) 278 | millidx = max( 279 | 0, 280 | min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))), 281 | ) 282 | 283 | return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx]) 284 | 285 | 286 | def split_space(start, end, count) -> List[int]: 287 | size = end - start 288 | assert count <= size, (count, size) 289 | return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] 290 | 291 | 292 | def remove_passwords_in_dict(d: dict, replace_with: str = "***"): 293 | for k, v in d.items(): 294 | if k == "password": 295 | d[k] = replace_with 296 | elif isinstance(v, dict): 297 | remove_passwords_in_dict(v, replace_with) 298 | elif k.startswith("database"): 299 | d[k] = remove_password_from_url(v, replace_with) 300 | 301 | 302 | def _join_if_any(sym, args): 303 | args = list(args) 304 | if not args: 305 | return "" 306 | return sym.join(str(a) for a in args if a) 307 | 308 | 309 | def remove_password_from_url(url: str, replace_with: str = "***") -> str: 310 | parsed = urlparse(url) 311 | account = parsed.username or "" 312 | if parsed.password: 313 | account += ":" + replace_with 314 | host = _join_if_any(":", filter(None, [parsed.hostname, parsed.port])) 315 | netloc = _join_if_any("@", filter(None, [account, host])) 316 | replaced = parsed._replace(netloc=netloc) 317 | return replaced.geturl() 318 | 319 | 320 | def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]: 321 | reo = re.compile(pattern.replace("%", ".*").replace("?", ".") + "$") 322 | for s in strs: 323 | if reo.match(s): 324 | yield s 325 | 326 | 327 | 328 | class UnknownMeta(type): 329 | def __instancecheck__(self, instance): 330 | return instance is Unknown 331 | 332 | def __repr__(self): 333 | return "Unknown" 334 | 335 | 336 | class 
Unknown(metaclass=UnknownMeta): 337 | def __nonzero__(self): 338 | raise TypeError() 339 | 340 | def __new__(class_, *args, **kwargs): 341 | raise RuntimeError("Unknown is a singleton") 342 | -------------------------------------------------------------------------------- /tests/test_query.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional 3 | import unittest 4 | from sqeleton.abcs import AbstractDatabase, AbstractDialect 5 | from sqeleton.utils import CaseInsensitiveDict, CaseSensitiveDict 6 | 7 | from sqeleton.queries import this, table, Compiler, outerjoin, cte, when, coalesce, CompileError 8 | from sqeleton.queries.ast_classes import Random 9 | from sqeleton import code, this, table 10 | 11 | 12 | def normalize_spaces(s: str): 13 | return " ".join(s.split()) 14 | 15 | 16 | class MockDialect(AbstractDialect): 17 | name = "MockDialect" 18 | 19 | PLACEHOLDER_TABLE = None 20 | ROUNDS_ON_PREC_LOSS = False 21 | 22 | def quote(self, s: str) -> str: 23 | return s 24 | 25 | def concat(self, l: List[str]) -> str: 26 | s = ", ".join(l) 27 | return f"concat({s})" 28 | 29 | def to_string(self, s: str) -> str: 30 | return f"cast({s} as varchar)" 31 | 32 | def is_distinct_from(self, a: str, b: str) -> str: 33 | return f"{a} is distinct from {b}" 34 | 35 | def random(self) -> str: 36 | return "random()" 37 | 38 | def current_timestamp(self) -> str: 39 | return "now()" 40 | 41 | def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): 42 | x = offset and f"OFFSET {offset}", limit and f"LIMIT {limit}" 43 | return " ".join(filter(None, x)) 44 | 45 | def explain_as_text(self, query: str) -> str: 46 | return f"explain {query}" 47 | 48 | def timestamp_value(self, t: datetime) -> str: 49 | return f"timestamp '{t}'" 50 | 51 | def set_timezone_to_utc(self) -> str: 52 | return "set timezone 'UTC'" 53 | 54 | def optimizer_hints(self, s: str): 55 | return f"/*+ {s} */ " 56 | 57 | def load_mixins(self): 58 | raise NotImplementedError() 59 | 60 | parse_type = NotImplemented 61 | 62 | 63 | class MockDatabase(AbstractDatabase): 64 | dialect = MockDialect() 65 | 66 | _query = NotImplemented 67 | query_table_schema = NotImplemented 68 | select_table_schema = NotImplemented 69 | _process_table_schema = NotImplemented 70 | parse_table_name = NotImplemented 71 | close = NotImplemented 72 | _normalize_table_path = NotImplemented 73 | is_autocommit = NotImplemented 74 | 75 | 76 | class TestQuery(unittest.TestCase): 77 | def setUp(self): 78 | pass 79 | 80 | def test_basic(self): 81 | c = Compiler(MockDatabase()) 82 | 83 | t = table("point") 84 | t2 = t.select(x=this.x + 1, y=t["y"] + this.x) 85 | assert c.compile(t2) == "SELECT (x + 1) AS x, (y + x) AS y FROM point" 86 | 87 | t = table("point").where(this.x == 1, this.y == 2) 88 | assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)" 89 | 90 | t = table("person").where(this.name == "Albert") 91 | self.assertEqual(c.compile(t), "SELECT * FROM person WHERE (name = 'Albert')") 92 | 93 | def test_outerjoin(self): 94 | c = Compiler(MockDatabase()) 95 | 96 | a = table("a") 97 | b = table("b") 98 | keys = ["x", "y"] 99 | cols = ["u", "v"] 100 | 101 | j = outerjoin(a, b).on(a[k] == b[k] for k in keys) 102 | 103 | self.assertEqual( 104 | c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)" 105 | ) 106 | 107 | def test_schema(self): 108 | c = Compiler(MockDatabase()) 109 | schema = 
dict(id="int", comment="varchar") 110 | 111 | # test table 112 | t = table("a", schema=CaseInsensitiveDict(schema)) 113 | q = t.select(this.Id, t["COMMENT"]) 114 | assert c.compile(q) == "SELECT id, comment FROM a" 115 | 116 | t = table("a", schema=CaseSensitiveDict(schema)) 117 | self.assertRaises(KeyError, t.__getitem__, "Id") 118 | self.assertRaises(KeyError, t.select, this.Id) 119 | 120 | # test select 121 | q = t.select(this.id) 122 | self.assertRaises(KeyError, q.__getitem__, "comment") 123 | 124 | # test join 125 | s = CaseInsensitiveDict({"x": int, "y": int}) 126 | a = table("a", schema=s) 127 | b = table("b", schema=s) 128 | keys = ["x", "y"] 129 | j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a["x"], b["y"], xsum=a["x"] + b["x"]) 130 | j["x"], j["y"], j["xsum"] 131 | self.assertRaises(KeyError, j.__getitem__, "ysum") 132 | 133 | def test_commutable_select(self): 134 | # c = Compiler(MockDatabase()) 135 | 136 | t = table("a") 137 | q1 = t.select("a").where("b") 138 | q2 = t.where("b").select("a") 139 | assert q1 == q2, (q1, q2) 140 | 141 | def test_cte(self): 142 | c = Compiler(MockDatabase()) 143 | 144 | t = table("a") 145 | 146 | # single cte 147 | t2 = cte(t.select(this.x)) 148 | t3 = t2.select(this.x) 149 | 150 | expected = "WITH tmp1 AS (SELECT x FROM a) SELECT x FROM tmp1" 151 | assert normalize_spaces(c.compile(t3)) == expected 152 | 153 | # nested cte 154 | c = Compiler(MockDatabase()) 155 | t4 = cte(t3).select(this.x) 156 | 157 | expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2" 158 | assert normalize_spaces(c.compile(t4)) == expected 159 | 160 | # parameterized cte 161 | c = Compiler(MockDatabase()) 162 | t2 = cte(t.select(this.x), params=["y"]) 163 | t3 = t2.select(this.y) 164 | 165 | expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1" 166 | assert normalize_spaces(c.compile(t3)) == expected 167 | 168 | def test_funcs(self): 169 | c = Compiler(MockDatabase()) 170 | t = table("a") 171 | 172 | q = c.compile(t.order_by(Random()).limit(10)) 173 | self.assertEqual(q, "SELECT * FROM a ORDER BY random() LIMIT 10") 174 | 175 | q = c.compile(t.select(coalesce(this.a, this.b))) 176 | self.assertEqual(q, "SELECT COALESCE(a, b) FROM a") 177 | 178 | def test_select_distinct(self): 179 | c = Compiler(MockDatabase()) 180 | t = table("a") 181 | 182 | q = c.compile(t.select(this.b, distinct=True)) 183 | assert q == "SELECT DISTINCT b FROM a" 184 | 185 | # selects merge 186 | q = c.compile(t.where(this.b > 10).select(this.b, distinct=True)) 187 | self.assertEqual(q, "SELECT DISTINCT b FROM a WHERE (b > 10)") 188 | 189 | # selects stay apart 190 | q = c.compile(t.limit(10).select(this.b, distinct=True)) 191 | self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM a LIMIT 10) tmp1") 192 | 193 | q = c.compile(t.select(this.b, distinct=True).select(distinct=False)) 194 | self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2") 195 | 196 | def test_select_with_optimizer_hints(self): 197 | c = Compiler(MockDatabase()) 198 | t = table("a") 199 | 200 | q = c.compile(t.select(this.b, optimizer_hints="PARALLEL(a 16)")) 201 | assert q == "SELECT /*+ PARALLEL(a 16) */ b FROM a" 202 | 203 | q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)")) 204 | self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)") 205 | 206 | q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)")) 207 | self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM 
a LIMIT 10) tmp1") 208 | 209 | q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)")) 210 | self.assertEqual( 211 | q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3" 212 | ) 213 | 214 | def test_table_ops(self): 215 | c = Compiler(MockDatabase()) 216 | a = table("a").select(this.x) 217 | b = table("b").select(this.y) 218 | 219 | q = c.compile(a.union(b)) 220 | assert q == "SELECT x FROM a UNION SELECT y FROM b" 221 | 222 | q = c.compile(a.union_all(b)) 223 | assert q == "SELECT x FROM a UNION ALL SELECT y FROM b" 224 | 225 | q = c.compile(a.minus(b)) 226 | assert q == "SELECT x FROM a EXCEPT SELECT y FROM b" 227 | 228 | q = c.compile(a.intersect(b)) 229 | assert q == "SELECT x FROM a INTERSECT SELECT y FROM b" 230 | 231 | def test_ops(self): 232 | c = Compiler(MockDatabase()) 233 | t = table("a") 234 | 235 | q = c.compile(t.select(this.b + this.c)) 236 | self.assertEqual(q, "SELECT (b + c) FROM a") 237 | 238 | q = c.compile(t.select(this.b.like(this.c))) 239 | self.assertEqual(q, "SELECT (b LIKE c) FROM a") 240 | 241 | q = c.compile(t.select(-this.b.sum())) 242 | self.assertEqual(q, "SELECT (-SUM(b)) FROM a") 243 | 244 | def test_group_by(self): 245 | c = Compiler(MockDatabase()) 246 | t = table("a") 247 | 248 | q = c.compile(t.group_by(this.b).agg(this.c)) 249 | self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1") 250 | 251 | q = c.compile(t.where(this.b > 1).group_by(this.b).agg(this.c)) 252 | self.assertEqual(q, "SELECT b, c FROM a WHERE (b > 1) GROUP BY 1") 253 | 254 | self.assertRaises(CompileError, c.compile, t.select(this.b).group_by(this.b)) 255 | 256 | q = c.compile(t.select(this.b).group_by(this.b).agg()) 257 | self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp1 GROUP BY 1") 258 | 259 | q = c.compile(t.group_by(this.b, this.c).agg(this.d, this.e)) 260 | self.assertEqual(q, "SELECT b, c, d, e FROM a GROUP BY 1, 2") 261 | 262 | # Having 263 | q = c.compile(t.group_by(this.b).agg(this.c).having(this.b > 1)) 264 | self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)") 265 | 266 | q = c.compile(t.group_by(this.b).having(this.b > 1).agg(this.c)) 267 | self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)") 268 | 269 | q = c.compile(t.select(this.b).group_by(this.b).agg().having(this.b > 1)) 270 | self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp2 GROUP BY 1 HAVING (b > 1)") 271 | 272 | # Having sum 273 | q = c.compile(t.group_by(this.b).agg(this.c, this.d).having(this.b.sum() > 1)) 274 | self.assertEqual(q, "SELECT b, c, d FROM a GROUP BY 1 HAVING (SUM(b) > 1)") 275 | 276 | # Select interaction 277 | q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(this.c + 1)) 278 | self.assertEqual(q, "SELECT (c + 1) FROM (SELECT b, c FROM (SELECT a FROM a) tmp3 GROUP BY 1) tmp4") 279 | 280 | def test_case_when(self): 281 | c = Compiler(MockDatabase()) 282 | t = table("a") 283 | 284 | q = c.compile(t.select(when(this.b).then(this.c))) 285 | self.assertEqual(q, "SELECT CASE WHEN b THEN c END FROM a") 286 | 287 | q = c.compile(t.select(when(this.b).then(this.c).else_(this.d))) 288 | self.assertEqual(q, "SELECT CASE WHEN b THEN c ELSE d END FROM a") 289 | 290 | q = c.compile( 291 | t.select( 292 | when(this.type == "text") 293 | .then(this.text) 294 | .when(this.type == "number") 295 | .then(this.number) 296 | .else_("unknown type") 297 | ) 298 | ) 299 | self.assertEqual( 300 | q, 301 | "SELECT CASE WHEN (type = 'text') THEN text WHEN (type = 'number') 
THEN number ELSE 'unknown type' END FROM a", 302 | ) 303 | 304 | def test_code(self): 305 | c = Compiler(MockDatabase()) 306 | t = table("a") 307 | 308 | q = c.compile(t.select(this.b, code("<x>")).where(code("<y>"))) 309 | self.assertEqual(q, "SELECT b, <x> FROM a WHERE <y>") 310 | 311 | def tablesample(t, size): 312 | return code("{t} TABLESAMPLE BERNOULLI ({size})", t=t, size=size) 313 | 314 | nonzero = table("points").where(this.x > 0, this.y > 0) 315 | 316 | q = c.compile(tablesample(nonzero, 10)) 317 | self.assertEqual(q, "SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10)") 318 | --------------------------------------------------------------------------------
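A minimal usage sketch tying the files above together: `Connect.__call__` from sqeleton/databases/_connect.py, the expression API exercised in tests/test_query.py, and the `test_connect` helper in sqeleton/conn_editor.py. The connection settings are placeholders copied from the `connect()` docstring (they assume a reachable MySQL server), and `Compiler(db).compile(...)` mirrors how the unit tests render SQL; treat this as an illustration of the API, not code shipped in the repository.

from sqeleton import connect, table, this
from sqeleton.queries import Compiler

# Dict-based configuration, using the same keys as the TOML 'database' section.
# Host and database name are placeholders, not values from this repo.
conn = connect({"driver": "mysql", "host": "localhost", "database": "db"})

# Compose a query lazily; nothing is sent to the database yet.
q = table("point").where(this.x == 1).select(this.x, this.y)

# Render it to SQL the same way the unit tests do ...
print(Compiler(conn).compile(q))

# ... and run a trivial query to verify the connection, as conn_editor.test_connect does.
assert conn.query("select 1+1", int) == 2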