├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── request-support-for-a-database.md └── workflows │ ├── ci.yml │ └── ci_full.yml ├── .gitignore ├── LICENSE ├── README.md ├── dev ├── Dockerfile.prestosql.340 ├── dev.env ├── presto-conf │ └── standalone │ │ ├── catalog │ │ ├── jmx.properties │ │ ├── memory.properties │ │ ├── postgresql.properties │ │ ├── tpcds.properties │ │ └── tpch.properties │ │ ├── combined.pem │ │ ├── config.properties │ │ ├── jvm.config │ │ ├── log.properties │ │ ├── node.properties │ │ ├── password-authenticator.properties │ │ └── password.db └── trino-conf │ ├── README.md │ ├── cert.conf │ └── etc │ ├── catalog │ ├── jms.properties │ ├── memory.properties │ ├── postgresql.properties │ ├── tpcds.properties │ └── tpch.properties │ ├── config.properties │ ├── jvm.config │ ├── node.properties │ └── password-authenticator.properties ├── docker-compose.yml ├── docs ├── Makefile ├── conf.py ├── conn_editor.md ├── index.rst ├── install.md ├── intro.md ├── make.bat ├── python-api.rst ├── requirements.txt └── supported-databases.md ├── poetry.lock ├── pyproject.toml ├── readthedocs.yml ├── sqeleton ├── __init__.py ├── __main__.py ├── abcs │ ├── __init__.py │ ├── compiler.py │ ├── database_types.py │ └── mixins.py ├── bound_exprs.py ├── conn_editor.py ├── databases │ ├── __init__.py │ ├── _connect.py │ ├── base.py │ ├── bigquery.py │ ├── clickhouse.py │ ├── databricks.py │ ├── duckdb.py │ ├── mssql.py │ ├── mysql.py │ ├── oracle.py │ ├── postgresql.py │ ├── presto.py │ ├── redshift.py │ ├── snowflake.py │ ├── trino.py │ └── vertica.py ├── queries │ ├── __init__.py │ ├── api.py │ ├── ast_classes.py │ ├── base.py │ ├── compiler.py │ └── extras.py ├── query_utils.py ├── repl.py ├── schema.py └── utils.py └── tests ├── __init__.py ├── common.py ├── test_database.py ├── test_mixins.py ├── test_query.py ├── test_sql.py ├── test_utils.py └── waiting_for_stack_up.sh /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Describe the environment** 14 | 15 | Describe which OS you're using, which sqeleton version, and any other information that might be relevant to this bug. 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/request-support-for-a-database.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Request support for a database 3 | about: 'Request a driver to support a new database ' 4 | title: 'Add support for ' 5 | labels: new-db-driver 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI-COVER-VERSIONS 2 | 3 | on: 4 | # push: 5 | # paths: 6 | # - '**.py' 7 | # - '.github/workflows/**' 8 | # - '!dev/**' 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | 13 | jobs: 14 | unit_tests: 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest] 19 | python-version: 20 | - "3.8" 21 | - "3.9" 22 | - "3.10" 23 | - "3.11" 24 | - "3.12" 25 | 26 | name: Check Python ${{ matrix.python-version }} on ${{ matrix.os }} 27 | runs-on: ${{ matrix.os }} 28 | steps: 29 | - uses: actions/checkout@v3 30 | 31 | - name: Setup Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v3 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | 36 | - name: Build the stack 37 | run: docker compose up -d mysql postgres trino clickhouse vertica 38 | 39 | - name: Install Poetry 40 | run: pip install poetry 41 | 42 | - name: Install package 43 | run: "poetry install" 44 | 45 | - name: Run unit tests 46 | env: 47 | # PRESTO_URI: 'presto://presto@127.0.0.1/postgresql/public' 48 | TRINO_URI: 'trino://postgres@127.0.0.1:8081/postgresql/public' 49 | CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' 50 | VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' 51 | SNOWFLAKE_URI: '${{ secrets.SNOWFLAKE_URI }}' 52 | REDSHIFT_URI: '${{ secrets.REDSHIFT_URI }}' 53 | run: | 54 | chmod +x tests/waiting_for_stack_up.sh 55 | ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=0 poetry run unittest-parallel -j 16 56 | -------------------------------------------------------------------------------- /.github/workflows/ci_full.yml: -------------------------------------------------------------------------------- 1 | name: CI-COVER-DATABASES 2 | 3 | on: 4 | # push: 5 | # paths: 6 | # - '**.py' 7 | # - '.github/workflows/**' 8 | # - '!dev/**' 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | 13 | jobs: 14 | unit_tests: 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest] 19 | python-version: 20 | - "3.10" 21 | 22 | name: Check Python ${{ matrix.python-version }} on ${{ matrix.os }} 23 | runs-on: ${{ matrix.os }} 24 | steps: 25 | - uses: actions/checkout@v3 26 | 27 | - name: Setup Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v3 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Build the stack 33 | run: docker compose up -d mysql postgres trino clickhouse vertica 34 | 35 | - name: Install Poetry 36 | run: pip install poetry 37 | 38 | - name: Install package 39 | run: "poetry install" 40 | 41 | - name: Run unit tests 42 | env: 43 | # PRESTO_URI: 'presto://presto@127.0.0.1/postgresql/public' 44 | TRINO_URI: 'trino://postgres@127.0.0.1:8081/postgresql/public' 45 | CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse' 46 | VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica' 47 | 
SNOWFLAKE_URI: '${{ secrets.SNOWFLAKE_URI }}' 48 | REDSHIFT_URI: '${{ secrets.REDSHIFT_URI }}' 49 | run: | 50 | chmod +x tests/waiting_for_stack_up.sh 51 | ./tests/waiting_for_stack_up.sh && poetry run unittest-parallel -j 16 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mac 132 | .DS_Store 133 | 134 | # IntelliJ 135 | .idea 136 | 137 | # VSCode 138 | .vscode 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Erez Shinan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sqeleton 2 | 3 | Sqeleton is a Python library for querying SQL databases. 4 | 5 | It consists of - 6 | 7 | - A fast and concise query builder, designed from scratch, but inspired by PyPika and SQLAlchemy 8 | 9 | - A modular database interface, with drivers for a long list of SQL databases (see below) 10 | 11 | It is comparable to other libraries such as SQLAlchemy or PyPika, in terms of API and intended audience. However, there are several notable ways in which it is different. 12 | 13 | ## **Features:** 14 | 15 | 🏃‍♂️ **High-performance**: Sqeleton's API is designed to maximize performance using batch operations 16 | 17 | - No ORM. Sqeleton takes the position that while ORMs feel easy and familiar, their granular operations are far too slow. 18 | - Compiles queries 4 times faster than SQLAlchemy 19 | 20 | 🙌 **Parallel**: Seamless multi-threading and multi-processing support 21 | 22 | 💖 **Well-tested**: In addition to having an extensive test-suite, sqeleton is used as the core of [data-diff](https://github.com/datafold/data-diff) (now [reladiff](https://github.com/erezsh/reladiff)). 23 | 24 | ✅ **Type-aware**: The schema is used for validation when building expressions, making sure the names are correct, and that the data-types align. 
(WIP) 25 | 26 | - The schema can be queried at run-time, if the tables already exist in the database 27 | 28 | ✨ **Multi-database access**: Sqeleton is designed to work with several databases at the same time. Its API abstracts away as many implementation details as possible. 29 | 30 | _Databases we fully support_: 31 | 32 | - PostgreSQL >=10 33 | - MySQL 34 | - Snowflake 35 | - BigQuery 36 | - Redshift 37 | - Oracle 38 | - Presto 39 | - Databricks 40 | - Trino 41 | - Clickhouse 42 | - Vertica 43 | - DuckDB >=0.6 44 | - SQLite (coming soon) 45 | 46 | 💻 **Built-in SQL client**: Connect to any of the supported databases with just one line. 47 | 48 | Example usage: `sqeleton repl snowflake://...` 49 | 50 | - Has syntax-highlighting and autocomplete 51 | - Use `*text` to find all tables like `%text%` (or just `*` to see all tables) 52 | - Use `?name` to see the schema of the table called `name`. 53 | 54 | ## Documentation 55 | 56 | [Read the docs!](https://sqeleton.readthedocs.io) 57 | 58 | Or jump straight to the [introduction](https://sqeleton.readthedocs.io/en/latest/intro.html). 59 | 60 | ### Install 61 | 62 | Install using pip: 63 | 64 | ```bash 65 | pip install sqeleton 66 | ``` 67 | 68 | It is recommended to install the driver dependencies using pip's `[]` syntax: 69 | 70 | ```bash 71 | pip install 'sqeleton[mysql, postgresql]' 72 | ``` 73 | 74 | Read more in [install / getting started.](https://sqeleton.readthedocs.io/en/latest/install.html) 75 | 76 | ### Example: Basic usage 77 | 78 | We will create a table with the numbers 0..100, and then sum them up. 79 | 80 | ```python 81 | from sqeleton import connect, table, this 82 | 83 | # Create a new database connection 84 | ddb = connect("duckdb://:memory:") 85 | 86 | # Define a table with one int column 87 | tbl = table('my_list', schema={'item': int}) 88 | 89 | # Make a bunch of queries 90 | queries = [ 91 | # Create table 'my_list' 92 | tbl.create(), 93 | 94 | # Insert 100 numbers 95 | tbl.insert_rows([x] for x in range(100)), 96 | 97 | # Get the sum of the numbers 98 | tbl.select(this.item.sum()) 99 | ] 100 | # Query in order, and return the last result as an int 101 | result = ddb.query(queries, int) 102 | 103 | # Prints: Total sum of 0..100 = 4950 104 | print(f"Total sum of 0..100 = {result}") 105 | ``` 106 | 107 | ### Example: Advanced usage 108 | 109 | We will define a function that performs an outer-join on any database, and adds two extra fields: `only_a` and `only_b`. 110 | 111 | ```python 112 | from typing import List, Dict 113 | from sqeleton.databases import Database, MySQL 114 | from sqeleton.queries import ITable, leftjoin, rightjoin, outerjoin, and_, Expr 115 | def my_outerjoin( 116 | db: Database, 117 | a: ITable, b: ITable, 118 | keys1: List[str], keys2: List[str], 119 | select_fields: Dict[str, Expr] 120 | ) -> ITable: 121 | """This function accepts two table expressions, and returns an outer-join query. 122 | 123 | The resulting rows will include two extra boolean fields: 124 | "only_a", and "only_b", describing whether there was a match for that row 125 | only in the first table, or only in the second table.
126 | 127 | Parameters: 128 | db - the database connection to use 129 | a, b - the tables to outer-join 130 | keys1, keys2 - the names of the columns to join on, for each table respectively 131 | select_fields - A dictionary of {column_name: expression} to select as a result of the outer-join 132 | """ 133 | # Predicates to join on 134 | on = [a[k1] == b[k2] for k1, k2 in zip(keys1, keys2)] 135 | 136 | # Define the new boolean fields 137 | # If all keys are None, it means there was no match 138 | # Compiles to " IS NULL AND IS NULL AND IS NULL..." etc. 139 | only_a = and_(b[k] == None for k in keys2) 140 | only_b = and_(a[k] == None for k in keys1) 141 | 142 | if isinstance(db, MySQL): 143 | # MySQL doesn't support "outer join" 144 | # Instead, we union "left join" and "right join" 145 | l = leftjoin(a, b).on(*on).select( 146 | only_a=only_a, 147 | only_b=False, 148 | **select_fields 149 | ) 150 | r = rightjoin(a, b).on(*on).select( 151 | only_a=False, 152 | only_b=only_b, 153 | **select_fields 154 | ) 155 | return l.union(r) 156 | 157 | # Other databases 158 | return outerjoin(a, b).on(*on).select( 159 | only_a=only_a, 160 | only_b=only_b, 161 | **select_fields 162 | ) 163 | ``` 164 | 165 | 166 | 167 | # TODO 168 | 169 | - Transactions 170 | 171 | - Indexes 172 | 173 | - Date/time expressions 174 | 175 | - Window functions 176 | 177 | ## Possible plans for the future (not determined yet) 178 | 179 | - Cache the compilation of repetitive queries for even faster query-building 180 | 181 | - Compile control flow, functions 182 | 183 | - Define tables using type-annotated classes (SQLModel style) 184 | 185 | ## Alternatives 186 | 187 | - [SQLAlchemy](https://www.sqlalchemy.org/) 188 | - [PyPika](https://github.com/kayak/pypika) 189 | - [PonyORM](https://ponyorm.org/) 190 | - [peewee](https://github.com/coleifer/peewee) 191 | 192 | # Thanks 193 | 194 | Thanks to Datafold for having sponsored Sqeleton in its initial stages. For reference, [the original repo](https://github.com/datafold/sqeleton/). 
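# Example: Using `my_outerjoin`

For reference, here is a minimal usage sketch for the `my_outerjoin` helper defined in the advanced example above. The connection string, the table names `users_a`/`users_b`, and the columns `id`/`name` are placeholders for illustration only, not part of the library:

```python
from sqeleton import connect, table

# Placeholder connection string; any supported database works here
db = connect("postgresql://user:password@localhost:5432/mydb")

# Two hypothetical tables that share an 'id' key and have a 'name' column
a = table('users_a')
b = table('users_b')

# Build the outer-join query with the helper defined above
query = my_outerjoin(
    db, a, b,
    keys1=['id'], keys2=['id'],
    select_fields={'name_a': a['name'], 'name_b': b['name']},
)

# Run it; each row carries name_a, name_b, plus the only_a/only_b flags
for row in db.query(query, list):
    print(row)
```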
195 | -------------------------------------------------------------------------------- /dev/Dockerfile.prestosql.340: -------------------------------------------------------------------------------- 1 | FROM openjdk:11-jdk-slim-buster 2 | 3 | ENV PRESTO_VERSION=340 4 | ENV PRESTO_SERVER_URL=https://repo1.maven.org/maven2/io/prestosql/presto-server/${PRESTO_VERSION}/presto-server-${PRESTO_VERSION}.tar.gz 5 | ENV PRESTO_CLI_URL=https://repo1.maven.org/maven2/io/prestosql/presto-cli/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar 6 | ENV PRESTO_HOME=/opt/presto 7 | ENV PATH=${PRESTO_HOME}/bin:${PATH} 8 | 9 | WORKDIR $PRESTO_HOME 10 | 11 | RUN set -xe \ 12 | && apt-get update \ 13 | && apt-get install -y curl less python \ 14 | && curl -sSL $PRESTO_SERVER_URL | tar xz --strip 1 \ 15 | && curl -sSL $PRESTO_CLI_URL > ./bin/presto \ 16 | && chmod +x ./bin/presto \ 17 | && apt-get remove -y curl \ 18 | && rm -rf /var/lib/apt/lists/* 19 | 20 | VOLUME /data 21 | 22 | EXPOSE 8080 23 | 24 | ENTRYPOINT ["launcher"] 25 | CMD ["run"] 26 | -------------------------------------------------------------------------------- /dev/dev.env: -------------------------------------------------------------------------------- 1 | POSTGRES_USER=postgres 2 | POSTGRES_PASSWORD=Password1 3 | POSTGRES_DB=postgres 4 | 5 | MYSQL_DATABASE=mysql 6 | MYSQL_USER=mysql 7 | MYSQL_PASSWORD=Password1 8 | MYSQL_ROOT_PASSWORD=RootPassword1 9 | 10 | CLICKHOUSE_USER=clickhouse 11 | CLICKHOUSE_PASSWORD=Password1 12 | CLICKHOUSE_DB=clickhouse 13 | CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT=1 14 | 15 | # Vertica credentials 16 | APP_DB_USER=vertica 17 | APP_DB_PASSWORD=Password1 18 | VERTICA_DB_NAME=vertica 19 | 20 | # To prevent generating sample demo VMart data (more about it here https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/GettingStartedGuide/IntroducingVMart/IntroducingVMart.htm), 21 | # leave VMART_DIR and VMART_ETL_SCRIPT empty. 
22 | VMART_DIR= 23 | VMART_ETL_SCRIPT= 24 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/jmx.properties: -------------------------------------------------------------------------------- 1 | connector.name=jmx 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/memory.properties: -------------------------------------------------------------------------------- 1 | connector.name=memory 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/postgresql.properties: -------------------------------------------------------------------------------- 1 | connector.name=postgresql 2 | connection-url=jdbc:postgresql://postgres:5432/postgres 3 | connection-user=postgres 4 | connection-password=Password1 5 | allow-drop-table=true 6 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/tpcds.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpcds 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/catalog/tpch.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpch 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/combined.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIERDCCAyygAwIBAgIUBxO/CflDP+0yZAXt5FKm/WQ4538wDQYJKoZIhvcNAQEL 3 | BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM 4 | GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X 5 | DTIyMDgyNDA4NTI1N1oXDTMyMDgyMTA4NTI1N1owWTELMAkGA1UEBhMCQVUxEzAR 6 | BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 7 | IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A 8 | MIIBCgKCAQEA3lNywkj/eGPGoFA3Lcx++98l17CRy+uzZMtJsr6lYAg1p/n1vPw0 9 | BQXI5TSBJ6vM/axtwgwrfXQsjQ/GYJKQkb6eEBCc3xb+Rk5HNBiBBZsIjYm0U1zz 10 | 7dKnNwAznjx3j72s2ZQiqkoxcu7Bctw28ynbg0rjNkuUk3QESKuOgaTltpWKZiiu 11 | XwWasREeH6MH7ROy8db6cz+MwGaig0mUvGPmD97bPRD/X683RyOiXzEaogl/rpGK 12 | qZ3jRsmS8ZwawzKxx16kqPsX8/01EruGIoubMttr3YoZG044zq7nQqdAAz6wXx6V 13 | mgzToCHI+/g+8JS/bgqJTyb2Y6aGXExiuQIDAQABo4IBAjCB/zAdBgNVHQ4EFgQU 14 | 5i1F8pTnwjFxw6W/0RjwpaJaK9MwgZYGA1UdIwSBjjCBi4AU5i1F8pTnwjFxw6W/ 15 | 0RjwpaJaK9OhXaRbMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRl 16 | MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxv 17 | Y2FsaG9zdIIUBxO/CflDP+0yZAXt5FKm/WQ4538wDAYDVR0TBAUwAwEB/zALBgNV 18 | HQ8EBAMCAvwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MBQGA1UdEgQNMAuCCWxvY2Fs 19 | aG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAjBQLl/UFSd9TH2VLM1GH8bixtEJ9+rm7 20 | x+Jw665+XLjW107dJ33qxy9zjd3cZ2fynKg2Tb7+9QAvSlqpt2YMGP9jr4W2w16u 21 | ngbNB+kfoOotcUChk90aHmdHLZgOOve/ArFIvbr8douLOn0NAJBrj+iX4zC1pgEC 22 | 9hsMUekkAPIcCGc0rEkEc8r8uiUBWNAdEWpBt0X2fE1ownLuB/E/+3HutLTw8Lv0 23 | b+jNt/vogVixcw/FF4atoO+F7S5FYzAb0U7YXaNISfVPVBsA89oPy7PlxULHDUIF 24 | Iq+vVqKdj1EXR+Iec0TMiMsa3MnIGkpL7ZuUXaG+xGBaVhGrUp67lQ== 25 | -----END CERTIFICATE----- 26 | -----BEGIN RSA PRIVATE KEY----- 27 | MIIEogIBAAKCAQEA3lNywkj/eGPGoFA3Lcx++98l17CRy+uzZMtJsr6lYAg1p/n1 28 | vPw0BQXI5TSBJ6vM/axtwgwrfXQsjQ/GYJKQkb6eEBCc3xb+Rk5HNBiBBZsIjYm0 29 | U1zz7dKnNwAznjx3j72s2ZQiqkoxcu7Bctw28ynbg0rjNkuUk3QESKuOgaTltpWK 30 | 
ZiiuXwWasREeH6MH7ROy8db6cz+MwGaig0mUvGPmD97bPRD/X683RyOiXzEaogl/ 31 | rpGKqZ3jRsmS8ZwawzKxx16kqPsX8/01EruGIoubMttr3YoZG044zq7nQqdAAz6w 32 | Xx6VmgzToCHI+/g+8JS/bgqJTyb2Y6aGXExiuQIDAQABAoIBAD3pKwnjXhDeaA94 33 | hwUf7zSgfV9E8jTBHCGzYoB+CntljduLBd1stee4Jqt9JYIwm1MA00e4L9wtn8Jg 34 | ZDO8XLnZRRbgKW8ObhyR684cDMHM3GLdt/OG7P6LLLlqOvWTjQ/gF+Q3FjgplP+W 35 | cRRVMpAgVdqH3iHehi9RnWfHLlX3WkBC97SumWFzWqBnqUQAC8AvFHiUCqA9qIeA 36 | 8ieOEoE17yv2nkmu+A5OZoCXtVfc2lQ90Fj9QZiv4rIVXBtTRRURJuvi2iX3nOPl 37 | MsjAUIBK1ndzpJ7wuLICSR1U3/npPC6Va06lTm0H/Q6DEqZjEHbx9TGY3pTgVXuA 38 | +G0C5GkCgYEA/spvoDUMZrH2JE43TT/qMEHPtHW4qT6fTmzu08Gx++8nFmddNgSD 39 | zrdjxqPMUGV7Q5smwoHaQyqFxHMM2jh8icnV6VoBDrDdZM0eGFAs6HUjKmyaAdQO 40 | dC4kPiy3LX5pJUnQnmwq1fVsgXWGQF/LhD0L/y6xOiqdhZp/8nv6SFMCgYEA32GR 41 | gWJQdgWXTXxSEDn0twKPevgAT0s778/7h5osCLG82Q7ab2+Fc1qTleiwiQ2SAuOl 42 | mWvtz0Eg4dYe/q6jugqkEgAYZvdHGL7CSmC28O98fTLapgKQC5GUUan90sCbRec4 43 | kjbyx5scICNBYJVchdFg6UUSNz5czORUVgQEF0MCgYB1toUX2Spfj7yOTWyTTgIe 44 | RWl2kCS+XGYxT3aPcp+OK5E9cofH2xIiQOvh6+8K/beTJm0j0+ZIva6LcjPv5cTz 45 | y8H+S0zNwrymQ3Wx+eilhOi4QvBsA9KhrmekKfh/FjXxukadyo+HxhlZPjjGKPvX 46 | nnSacrICk4mvHhAasViSbQKBgD7mZGiAXJO/I0moVhtHlobp66j+qGerkacHc5ZN 47 | bVTNZ5XfPtbeGj/PI3u01/Dfp1u06m53G7GebznoZzXjyyqZ0HVZHYXw304yeNck 48 | wJ67cNx4M2VHl3QKfC86pMRxg8d9Qkq5ukdGf/b0tnYR2Mm9mYJV9rkjkFIJgU3v 49 | N4+tAoGABOlVGuRx2cSQ9QeC0AcqKlxXygdrzyadA7i0KNBZGGyrMSpJDrl2rrRn 50 | ylzAgGjvfilwQzZuqTm6Vo2yvaX+TTGS44B+DnxCZvuviftea++sNMjuEkBLTCpF 51 | xk2yOzsOnx652kWO4L+dVrDAxl65f3v0YaKWZI504LFYl18uS/E= 52 | -----END RSA PRIVATE KEY----- 53 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/config.properties: -------------------------------------------------------------------------------- 1 | coordinator=true 2 | node-scheduler.include-coordinator=true 3 | http-server.http.port=8080 4 | query.max-memory=5GB 5 | query.max-memory-per-node=1GB 6 | query.max-total-memory-per-node=2GB 7 | discovery-server.enabled=true 8 | discovery.uri=http://127.0.0.1:8080 9 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/jvm.config: -------------------------------------------------------------------------------- 1 | -server 2 | -Xmx16G 3 | -XX:+UseG1GC 4 | -XX:G1HeapRegionSize=32M 5 | -XX:+UseGCOverheadLimit 6 | -XX:+ExplicitGCInvokesConcurrent 7 | -XX:+HeapDumpOnOutOfMemoryError 8 | -XX:+ExitOnOutOfMemoryError 9 | -XX:OnOutOfMemoryError=kill -9 %p 10 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/log.properties: -------------------------------------------------------------------------------- 1 | com.facebook.presto=INFO 2 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/node.properties: -------------------------------------------------------------------------------- 1 | node.environment=production 2 | node.data-dir=/data 3 | node.id=standalone 4 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/password-authenticator.properties: -------------------------------------------------------------------------------- 1 | password-authenticator.name=file 2 | file.password-file=/opt/presto/etc/password.db 3 | -------------------------------------------------------------------------------- /dev/presto-conf/standalone/password.db: -------------------------------------------------------------------------------- 1 | 
test:$2y$10$877iU3J5a26SPDjFSrrz2eFAq2DwMDsBAus92Dj0z5A5qNMNlnpHa 2 | -------------------------------------------------------------------------------- /dev/trino-conf/README.md: -------------------------------------------------------------------------------- 1 | ### Local Trino Cluster with Self-Signed Cert 2 | 3 | In order to test Trino with user@password you need a local cluster with Basic Auth enabled. 4 | 5 | ## First thing, we need a self signed cert. 6 | ``` 7 | openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout dev/trino-conf/etc/trino-key.pem -out dev/trino-conf/etc/trino-cert.pem -config dev/trino-conf/cert.conf 8 | cat dev/trino-conf/etc/trino-cert.pem dev/trino-conf/etc/trino-key.pem > dev/trino-conf/etc/trino-combined.pem 9 | ``` 10 | 11 | ## Include https settings in config.properties 12 | 13 | Include the following lines in dev/trino-conf/etc/config.properties 14 | 15 | ``` 16 | http-server.https.port=8443 17 | http-server.https.enabled=true 18 | http-server.https.keystore.path=/etc/trino/trino-combined.pem 19 | http-server.https.keystore.key= 20 | http-server.authentication.type=PASSWORD 21 | internal-communication.shared-secret="secret" 22 | ``` 23 | 24 | ## Create a password.db file 25 | 26 | ``` 27 | touch dev/trino-conf/etc/password.db 28 | htpasswd -B -C 10 dev/trino-conf/etc/password.db test 29 | ``` -------------------------------------------------------------------------------- /dev/trino-conf/cert.conf: -------------------------------------------------------------------------------- 1 | [req] 2 | distinguished_name = req_distinguished_name 3 | req_extensions = req_ext 4 | x509_extensions = v3_ca 5 | prompt = no 6 | 7 | [req_distinguished_name] 8 | CN = localhost 9 | 10 | [req_ext] 11 | subjectAltName = @alt_names 12 | 13 | [v3_ca] 14 | subjectAltName = @alt_names 15 | 16 | [alt_names] 17 | DNS.1 = localhost 18 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/jms.properties: -------------------------------------------------------------------------------- 1 | connector.name=jmx 2 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/memory.properties: -------------------------------------------------------------------------------- 1 | connector.name=memory 2 | memory.max-data-per-node=128MB 3 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/postgresql.properties: -------------------------------------------------------------------------------- 1 | connector.name=postgresql 2 | connection-url=jdbc:postgresql://postgres:5432/postgres 3 | connection-user=postgres 4 | connection-password=Password1 5 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/tpcds.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpcds 2 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/catalog/tpch.properties: -------------------------------------------------------------------------------- 1 | connector.name=tpch 2 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/config.properties: -------------------------------------------------------------------------------- 1 | coordinator=true 2 | node-scheduler.include-coordinator=true 3 | http-server.http.port=8080 4 | 
discovery.uri=http://localhost:8080 5 | discovery-server.enabled=true 6 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/jvm.config: -------------------------------------------------------------------------------- 1 | -server 2 | -Xmx1G 3 | -XX:+UseG1GC 4 | -XX:+UnlockDiagnosticVMOptions 5 | -XX:G1NumCollectionsKeepPinned=10000000 6 | -XX:G1HeapRegionSize=32M 7 | -XX:+ExplicitGCInvokesConcurrent 8 | -XX:+HeapDumpOnOutOfMemoryError 9 | -XX:+UseGCOverheadLimit 10 | -XX:+ExitOnOutOfMemoryError 11 | -XX:ReservedCodeCacheSize=256M 12 | -Djdk.attach.allowAttachSelf=true 13 | -Djdk.nio.maxCachedBufferSize=2000000 -------------------------------------------------------------------------------- /dev/trino-conf/etc/node.properties: -------------------------------------------------------------------------------- 1 | node.environment=docker 2 | node.data-dir=/data/trino 3 | plugin.dir=/usr/lib/trino/plugin 4 | -------------------------------------------------------------------------------- /dev/trino-conf/etc/password-authenticator.properties: -------------------------------------------------------------------------------- 1 | password-authenticator.name=file 2 | file.password-file=/etc/trino/password.db 3 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | postgres: 5 | container_name: dd-postgresql 6 | image: postgres:14.1-alpine 7 | # work_mem: less tmp files 8 | # maintenance_work_mem: improve table-level op perf 9 | # max_wal_size: allow more time before merging to heap 10 | command: > 11 | -c work_mem=1GB 12 | -c maintenance_work_mem=1GB 13 | -c max_wal_size=8GB 14 | restart: always 15 | volumes: 16 | - postgresql-data:/var/lib/postgresql/data:delegated 17 | ports: 18 | - '5432:5432' 19 | expose: 20 | - '5432' 21 | env_file: 22 | - dev/dev.env 23 | tty: true 24 | networks: 25 | - local 26 | 27 | mysql: 28 | container_name: dd-mysql 29 | image: mysql:oracle 30 | # fsync less aggressively for insertion perf for test setup 31 | command: > 32 | --binlog-cache-size=16M 33 | --key_buffer_size=0 34 | --max_connections=1000 35 | --innodb_flush_log_at_trx_commit=2 36 | --innodb_flush_log_at_timeout=10 37 | --innodb_log_compressed_pages=OFF 38 | --sync_binlog=0 39 | restart: always 40 | volumes: 41 | - mysql-data:/var/lib/mysql:delegated 42 | user: mysql 43 | ports: 44 | - '3306:3306' 45 | expose: 46 | - '3306' 47 | env_file: 48 | - dev/dev.env 49 | tty: true 50 | networks: 51 | - local 52 | 53 | clickhouse: 54 | container_name: dd-clickhouse 55 | image: clickhouse/clickhouse-server:21.12.3.32 56 | restart: always 57 | volumes: 58 | - clickhouse-data:/var/lib/clickhouse:delegated 59 | ulimits: 60 | nproc: 65535 61 | nofile: 62 | soft: 262144 63 | hard: 262144 64 | ports: 65 | - '8123:8123' 66 | - '9000:9000' 67 | expose: 68 | - '8123' 69 | - '9000' 70 | env_file: 71 | - dev/dev.env 72 | tty: true 73 | networks: 74 | - local 75 | 76 | # prestodb.dbapi.connect(host="127.0.0.1", user="presto").cursor().execute('SELECT * FROM system.runtime.nodes') 77 | presto: 78 | container_name: dd-presto 79 | build: 80 | context: ./dev 81 | dockerfile: ./Dockerfile.prestosql.340 82 | volumes: 83 | - ./dev/presto-conf/standalone:/opt/presto/etc:ro 84 | ports: 85 | - '8080:8080' 86 | tty: true 87 | networks: 88 | - local 89 | 90 | trino: 91 | container_name: dd-trino 92 | image: 'trinodb/trino:450' 
93 | hostname: trino 94 | ports: 95 | - '8081:8080' 96 | - '8443:8443' 97 | volumes: 98 | - ./dev/trino-conf/etc:/etc/trino:ro 99 | networks: 100 | - local 101 | 102 | vertica: 103 | container_name: dd-vertica 104 | image: vertica/vertica-ce:12.0.0-0 105 | restart: always 106 | volumes: 107 | - vertica-data:/data:delegated 108 | ports: 109 | - '5433:5433' 110 | - '5444:5444' 111 | expose: 112 | - '5433' 113 | - '5444' 114 | env_file: 115 | - dev/dev.env 116 | tty: true 117 | networks: 118 | - local 119 | 120 | 121 | 122 | volumes: 123 | postgresql-data: 124 | mysql-data: 125 | clickhouse-data: 126 | vertica-data: 127 | 128 | networks: 129 | local: 130 | driver: bridge 131 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = sqeleton 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Documentation build configuration file, created by 5 | # sphinx-quickstart on Sun Aug 16 13:09:41 2020. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | 23 | sys.path.insert(0, os.path.abspath("..")) 24 | sys.path.append(os.path.abspath("./_ext")) 25 | autodoc_member_order = "bysource" 26 | 27 | import builtins 28 | 29 | builtins.__sphinx_build__ = True 30 | 31 | # -- General configuration ------------------------------------------------ 32 | 33 | # If your documentation needs a minimal Sphinx version, state it here. 34 | # 35 | # needs_sphinx = '1.0' 36 | 37 | # Add any Sphinx extension module names here, as strings. They can be 38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 39 | # ones. 40 | extensions = [ 41 | "sphinx.ext.autodoc", 42 | "sphinx.ext.napoleon", 43 | "sphinx.ext.coverage", 44 | "recommonmark", 45 | "sphinx_markdown_tables", 46 | "sphinx_copybutton", 47 | "enum_tools.autoenum", 48 | # 'sphinx_gallery.gen_gallery' 49 | ] 50 | 51 | # Add any paths that contain templates here, relative to this directory. 52 | templates_path = ["_templates"] 53 | 54 | # The suffix(es) of source filenames. 
55 | # You can specify multiple suffix as a list of string: 56 | # 57 | # source_suffix = ['.rst', '.md'] 58 | source_suffix = {".rst": "restructuredtext", ".md": "markdown"} 59 | 60 | 61 | # The master toctree document. 62 | master_doc = "index" 63 | 64 | # General information about the project. 65 | project = "Sqeleton" 66 | copyright = "Erez Shinan" 67 | author = "Erez Shinan" 68 | 69 | # The version info for the project you're documenting, acts as replacement for 70 | # |version| and |release|, also used in various other places throughout the 71 | # built documents. 72 | # 73 | # The short X.Y version. 74 | version = "" 75 | # The full version, including alpha/beta/rc tags. 76 | release = "" 77 | 78 | # The language for content autogenerated by Sphinx. Refer to documentation 79 | # for a list of supported languages. 80 | # 81 | # This is also used if you do content translation via gettext catalogs. 82 | # Usually you set "language" from the command line for these cases. 83 | language = None 84 | 85 | # List of patterns, relative to source directory, that match files and 86 | # directories to ignore when looking for source files. 87 | # This patterns also effect to html_static_path and html_extra_path 88 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 89 | 90 | # The name of the Pygments (syntax highlighting) style to use. 91 | pygments_style = "sphinx" 92 | 93 | # If true, `todo` and `todoList` produce output, else they produce nothing. 94 | todo_include_todos = False 95 | 96 | autodoc_default_options = { 97 | # 'special-members': '__init__', 98 | "exclude-members": "json,aslist,astuple,replace", 99 | } 100 | 101 | # -- Options for HTML output ---------------------------------------------- 102 | 103 | # The theme to use for HTML and HTML Help pages. See the documentation for 104 | # a list of builtin themes. 105 | # 106 | html_theme = "sphinx_rtd_theme" 107 | 108 | # Theme options are theme-specific and customize the look and feel of a theme 109 | # further. For a list of options available for each theme, see the 110 | # documentation. 111 | # 112 | # html_theme_options = {} 113 | 114 | # Add any paths that contain custom static files (such as style sheets) here, 115 | # relative to this directory. They are copied after the builtin static files, 116 | # so a file named "default.css" will overwrite the builtin "default.css". 117 | html_static_path = ["_static"] 118 | 119 | html_css_files = [ 120 | "custom.css", 121 | ] 122 | 123 | # Custom sidebar templates, must be a dictionary that maps document names 124 | # to template names. 125 | # 126 | # This is required for the alabaster theme 127 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 128 | html_sidebars = { 129 | "**": [ 130 | "relations.html", # needs 'show_related': True theme option to display 131 | "searchbox.html", 132 | ] 133 | } 134 | 135 | 136 | # -- Options for HTMLHelp output ------------------------------------------ 137 | 138 | # Output file base name for HTML help builder. 139 | htmlhelp_basename = "sqeletondoc" 140 | 141 | 142 | # -- Options for LaTeX output --------------------------------------------- 143 | 144 | latex_elements = { 145 | # The paper size ('letterpaper' or 'a4paper'). 146 | # 147 | # 'papersize': 'letterpaper', 148 | # The font size ('10pt', '11pt' or '12pt'). 149 | # 150 | # 'pointsize': '10pt', 151 | # Additional stuff for the LaTeX preamble. 
152 | # 153 | # 'preamble': '', 154 | # Latex figure (float) alignment 155 | # 156 | # 'figure_align': 'htbp', 157 | } 158 | 159 | # Grouping the document tree into LaTeX files. List of tuples 160 | # (source start file, target name, title, 161 | # author, documentclass [howto, manual, or own class]). 162 | latex_documents = [ 163 | (master_doc, "Sqeleton.tex", "Sqeleton Documentation", "Erez Shinan", "manual"), 164 | ] 165 | 166 | 167 | # -- Options for manual page output --------------------------------------- 168 | 169 | # One entry per manual page. List of tuples 170 | # (source start file, name, description, authors, manual section). 171 | man_pages = [(master_doc, "Sqeleton", "Sqeleton Documentation", [author], 1)] 172 | 173 | 174 | # -- Options for Texinfo output ------------------------------------------- 175 | 176 | # Grouping the document tree into Texinfo files. List of tuples 177 | # (source start file, target name, title, author, 178 | # dir menu entry, description, category) 179 | texinfo_documents = [ 180 | ( 181 | master_doc, 182 | "Sqeleton", 183 | "Sqeleton Documentation", 184 | author, 185 | "Sqeleton", 186 | "One line description of project.", 187 | "Miscellaneous", 188 | ), 189 | ] 190 | 191 | # -- Sphinx gallery config ------------------------------------------- 192 | 193 | # sphinx_gallery_conf = { 194 | # 'examples_dirs': ['../examples'], 195 | # 'gallery_dirs': ['examples'], 196 | # } 197 | -------------------------------------------------------------------------------- /docs/conn_editor.md: -------------------------------------------------------------------------------- 1 | # Connection Editor 2 | 3 | A common complaint among new users was the difficulty in setting up the connections. 4 | 5 | Connection URLs are admittedly confusing, and editing `.toml` files isn't always straight-forward either. 6 | 7 | To ease this initial difficulty, we added a `textual`-based TUI tool to sqeleton, that allows users to edit configuration files and test the connections while editing them. 8 | 9 | ## Install 10 | 11 | This tool needs `textual` to run. You can install it using: 12 | 13 | ```bash 14 | pip install 'sqeleton[tui]' 15 | ``` 16 | 17 | Make sure you also have drivers for whatever database connection you're going to edit! 18 | 19 | ## Run 20 | 21 | Once everything is installed, you can run the editor with the following command: 22 | 23 | ```bash 24 | sqeleton conn-editor 25 | ``` 26 | 27 | Example: 28 | 29 | ```bash 30 | sqeleton conn-editor ~/dbs.toml 31 | ``` 32 | 33 | The available actions and hotkeys will be listed in the status bar. 34 | 35 | Note: using the connection editor will delete comments and reformat the file! 36 | 37 | We recommend backing up the configuration file before editing it. -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 2 3 | :caption: API Reference 4 | :hidden: 5 | 6 | install 7 | intro 8 | supported-databases 9 | conn_editor 10 | python-api 11 | 12 | Sqeleton 13 | --------- 14 | 15 | **Sqeleton** is a Python library for querying SQL databases. 16 | 17 | It consists of - 18 | 19 | - A fast and concise query builder, inspired by PyPika and SQLAlchemy 20 | 21 | - A modular database interface, with drivers for a long list of SQL databases. 22 | 23 | It is comparable to other libraries such as SQLAlchemy or PyPika, in terms of API and intended audience. 
However there are several notable ways in which it is different. 24 | 25 | For more information, `See our README `_ 26 | 27 | 28 | Resources 29 | --------- 30 | 31 | - :doc:`install` 32 | - :doc:`intro` 33 | - :doc:`supported-databases` 34 | - :doc:`conn_editor` 35 | - :doc:`python-api` 36 | - Source code (git): ``_ 37 | 38 | -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Install / Get started 2 | 3 | Sqeleton can be installed using pip: 4 | 5 | ``` 6 | pip install sqeleton 7 | ``` 8 | 9 | ## Database drivers 10 | 11 | To ensure that the database drivers are compatible with sqeleton, we recommend installing them along with sqeleton, using pip's `[]` syntax: 12 | 13 | - `pip install 'sqeleton[mysql]'` 14 | 15 | - `pip install 'sqeleton[postgresql]'` 16 | 17 | - `pip install 'sqeleton[snowflake]'` 18 | 19 | - `pip install 'sqeleton[presto]'` 20 | 21 | - `pip install 'sqeleton[oracle]'` 22 | 23 | - `pip install 'sqeleton[trino]'` 24 | 25 | - `pip install 'sqeleton[clickhouse]'` 26 | 27 | - `pip install 'sqeleton[vertica]'` 28 | 29 | - For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ 30 | 31 | _Some drivers have dependencies that cannot be installed using `pip` and still need to be installed manually._ 32 | 33 | 34 | It is also possible to install several databases at once. For example: 35 | 36 | ```bash 37 | pip install 'sqeleton[mysql, postgresql]' 38 | ``` 39 | 40 | Note: Some shells use `"` for escaping instead, like: 41 | 42 | ```bash 43 | pip install "sqeleton[mysql, postgresql]" 44 | ``` 45 | 46 | ### Postgresql + Debian 47 | 48 | Before installing the postgresql driver, ensure you have libpq-dev: 49 | 50 | ```bash 51 | apt-get install libpq-dev 52 | ``` 53 | 54 | ## Connection editor 55 | 56 | Sqeleton provides a TUI connection editor, that can be installed using: 57 | 58 | ```bash 59 | pip install 'sqeleton[tui]' 60 | ``` 61 | 62 | Read more [here](conn_editor.md). 63 | 64 | ## What's next? 65 | 66 | Read the [introduction](intro.md) and start coding! 67 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=sqeleton 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/python-api.rst: -------------------------------------------------------------------------------- 1 | ******************** 2 | Python API Reference 3 | ******************** 4 | 5 | .. 
py:module:: sqeleton 6 | 7 | 8 | User API 9 | ======== 10 | 11 | .. autofunction:: table 12 | 13 | .. autofunction:: code 14 | 15 | .. autoclass:: connect 16 | :members: 17 | 18 | .. autodata:: SKIP 19 | .. autodata:: this 20 | 21 | Database 22 | -------- 23 | 24 | .. automodule:: sqeleton.databases.base 25 | :members: 26 | 27 | 28 | Queries 29 | -------- 30 | 31 | .. automodule:: sqeleton.queries.api 32 | :members: 33 | 34 | Internals 35 | ========= 36 | 37 | This section is for developers who wish to improve sqeleton, or to extend it within their own project. 38 | 39 | Regular users might also find it useful for debugging and understanding, especially at this early stage of the project. 40 | 41 | 42 | Query ASTs 43 | ----------- 44 | 45 | .. automodule:: sqeleton.queries.ast_classes 46 | :members: 47 | 48 | Query Compiler 49 | -------------- 50 | 51 | .. automodule:: sqeleton.queries.compiler 52 | :members: 53 | 54 | ABCS 55 | ----- 56 | 57 | .. automodule:: sqeleton.abcs.database_types 58 | :members: 59 | 60 | .. automodule:: sqeleton.abcs.mixins 61 | :members: 62 | 63 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # https://docs.readthedocs.io/en/stable/guides/specifying-dependencies.html#specifying-a-requirements-file 2 | sphinx-gallery 3 | sphinx_markdown_tables 4 | sphinx-copybutton 5 | sphinx-rtd-theme 6 | recommonmark 7 | enum-tools[sphinx] 8 | 9 | sqeleton 10 | -------------------------------------------------------------------------------- /docs/supported-databases.md: -------------------------------------------------------------------------------- 1 | # List of supported databases 2 | 3 | | Database | Status | Connection string | 4 | |---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------| 5 | | PostgreSQL >=10 | 💚 | `postgresql://:@:5432/` | 6 | | MySQL | 💚 | `mysql://:@:5432/` | 7 | | Snowflake | 💚 | `"snowflake://[:]@//?warehouse=&role=[&authenticator=externalbrowser]"` | 8 | | BigQuery | 💚 | `bigquery:///` | 9 | | Redshift | 💚 | `redshift://:@:5439/` | 10 | | Oracle | 💛 | `oracle://:@/database` | 11 | | Presto | 💛 | `presto://:@:8080/` | 12 | | Databricks | 💛 | `databricks://:@//` | 13 | | Trino | 💛 | `trino://:@:8080/` | 14 | | Clickhouse | 💛 | `clickhouse://:@:9000/` | 15 | | Vertica | 💛 | `vertica://:@:5433/` | 16 | | DuckDB | 💛 | | 17 | | ElasticSearch | 📝 | | 18 | | Planetscale | 📝 | | 19 | | Pinot | 📝 | | 20 | | Druid | 📝 | | 21 | | Kafka | 📝 | | 22 | | SQLite | 📝 | | 23 | 24 | * 💚: Implemented and thoroughly tested. 25 | * 💛: Implemented, but not thoroughly tested yet. 26 | * ⏳: Implementation in progress. 27 | * 📝: Implementation planned. Contributions welcome. 28 | 29 | Is your database not listed here? We accept pull-requests! 
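The connection strings above are what `sqeleton.connect()` expects. As a minimal sketch, assuming a local PostgreSQL instance and a hypothetical table `my_table` with an `id` column (all names and credentials here are placeholders):

```python
from sqeleton import connect, table, this

# Connection string following the PostgreSQL format from the table above
db = connect("postgresql://user:password@localhost:5432/mydb")

# Build and run a trivial query against the hypothetical table
rows = db.query(table('my_table').select(this.id), list)
print(rows)
```

The same connection string also works with the built-in client: `sqeleton repl postgresql://...`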
30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sqeleton" 3 | version = "0.1.7" 4 | description = "Python library for querying SQL databases" 5 | authors = ["Erez Shinan "] 6 | license = "MIT" 7 | readme = "README.md" 8 | repository = "https://github.com/erezsh/sqeleton" 9 | documentation = "https://sqeleton.readthedocs.io/en/latest/" 10 | classifiers = [ 11 | "Intended Audience :: Developers", 12 | "Intended Audience :: Information Technology", 13 | "Intended Audience :: System Administrators", 14 | "Programming Language :: Python :: 3.8", 15 | "Programming Language :: Python :: 3.9", 16 | "Programming Language :: Python :: 3.10", 17 | "Development Status :: 2 - Pre-Alpha", 18 | "Environment :: Console", 19 | "Topic :: Database :: Database Engines/Servers", 20 | "Typing :: Typed" 21 | ] 22 | packages = [{ include = "sqeleton" }] 23 | 24 | [tool.poetry.dependencies] 25 | python = "^3.8" 26 | runtype = ">=0.5.0" 27 | dsnparse = "*" 28 | click = ">=8.1" 29 | rich = "*" 30 | toml = ">=0.10.2" 31 | mysql-connector-python = {version=">=8.0.29", optional=true} 32 | # psycopg2 = {version="*", optional=true} 33 | psycopg2-binary = {version="*", optional=true} 34 | snowflake-connector-python = {version=">=2.7.2", optional=true} 35 | cryptography = {version="*", optional=true} 36 | trino = {version=">=0.314.0", optional=true} 37 | presto-python-client = {version="*", optional=true} 38 | clickhouse-driver = {version="*", optional=true} 39 | duckdb = {version=">=0.7.0", optional=true} 40 | textual = {version=">=0.9.1", optional=true} 41 | textual-select = {version="*", optional=true} 42 | pygments = {version=">=2.13.0", optional=true} 43 | prompt-toolkit = {version=">=3.0.36", optional=true} 44 | 45 | [tool.poetry.dev-dependencies] 46 | parameterized = "*" 47 | unittest-parallel = "*" 48 | 49 | duckdb = "*" 50 | mysql-connector-python = "*" 51 | # psycopg2 = "*" 52 | psycopg2-binary = "*" 53 | snowflake-connector-python = ">=2.7.2" 54 | cryptography = "*" 55 | trino = ">=0.314.0" 56 | presto-python-client = "*" 57 | clickhouse-driver = "*" 58 | vertica-python = "*" 59 | 60 | [tool.poetry.extras] 61 | mysql = ["mysql-connector-python"] 62 | postgresql = ["psycopg2-binary"] 63 | snowflake = ["snowflake-connector-python", "cryptography"] 64 | presto = ["presto-python-client"] 65 | oracle = ["cx_Oracle"] 66 | databricks = ["databricks-sql-connector"] 67 | trino = ["trino"] 68 | clickhouse = ["clickhouse-driver"] 69 | vertica = ["vertica-python"] 70 | duckdb = ["duckdb"] 71 | tui = ["textual", "textual-select", "pygments", "prompt-toolkit"] 72 | 73 | [build-system] 74 | requires = ["poetry-core>=1.0.0"] 75 | build-backend = "poetry.core.masonry.api" 76 | 77 | [tool.poetry.scripts] 78 | sqeleton = 'sqeleton.__main__:main' 79 | 80 | [tool.ruff] 81 | line-length = 120 82 | target-version = "py38" 83 | 84 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: all 4 | 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "3.8" 9 | 10 | python: 11 | install: 12 | - requirements: docs/requirements.txt 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 
-------------------------------------------------------------------------------- /sqeleton/__init__.py: -------------------------------------------------------------------------------- 1 | from .databases import connect 2 | from .queries import table, this, SKIP, code, commit 3 | 4 | __version__ = "0.1.7" 5 | -------------------------------------------------------------------------------- /sqeleton/__main__.py: -------------------------------------------------------------------------------- 1 | import click 2 | from .repl import repl as repl_main 3 | 4 | 5 | @click.group(no_args_is_help=True) 6 | def main(): 7 | pass 8 | 9 | 10 | @main.command(no_args_is_help=True) 11 | @click.argument("database", required=True) 12 | def repl(database): 13 | return repl_main(database) 14 | 15 | 16 | CONN_EDITOR_HELP = """CONFIG_PATH - Path to a TOML config file of db connections, new or existing.""" 17 | 18 | 19 | @main.command(no_args_is_help=True, help=CONN_EDITOR_HELP) 20 | @click.argument("config_path", required=True) 21 | def conn_editor(config_path): 22 | from .conn_editor import main 23 | 24 | return main(config_path) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /sqeleton/abcs/__init__.py: -------------------------------------------------------------------------------- 1 | from .database_types import ( 2 | AbstractDatabase, 3 | AbstractDialect, 4 | DbKey, 5 | DbPath, 6 | DbTime, 7 | IKey, 8 | ColType_UUID, 9 | NumericType, 10 | PrecisionType, 11 | StringType, 12 | Boolean, 13 | Text, 14 | ) 15 | from .compiler import AbstractCompiler, Compilable 16 | -------------------------------------------------------------------------------- /sqeleton/abcs/compiler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class AbstractCompiler(ABC): 6 | @abstractmethod 7 | def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: 8 | ... 9 | 10 | 11 | class Compilable(ABC): 12 | "Marks an item as compilable. Needs to be implemented through multiple-dispatch." 13 | -------------------------------------------------------------------------------- /sqeleton/abcs/mixins.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID 3 | from .compiler import Compilable 4 | 5 | 6 | class AbstractMixin(ABC): 7 | "A mixin for a database dialect" 8 | 9 | 10 | class AbstractMixin_NormalizeValue(AbstractMixin): 11 | @abstractmethod 12 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 13 | """Creates an SQL expression, that converts 'value' to a normalized timestamp. 14 | 15 | The returned expression must accept any SQL datetime/timestamp, and return a string. 16 | 17 | Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` 18 | 19 | Precision of dates should be rounded up/down according to coltype.rounds 20 | """ 21 | 22 | @abstractmethod 23 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 24 | """Creates an SQL expression, that converts 'value' to a normalized number. 25 | 26 | The returned expression must accept any SQL int/numeric/float, and return a string. 
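For example (an illustration of the spec below, not an additional requirement): with ``coltype.precision == 3``, the numeric value 2.5 would normalize to the string '2.500'.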
27 | 28 | Floats/Decimals are expected in the format 29 | "I.P" 30 | 31 | Where I is the integer part of the number (as many digits as necessary), 32 | and must be at least one digit (0). 33 | P is the fractional digits, the amount of which is specified with 34 | coltype.precision. Trailing zeroes may be necessary. 35 | If P is 0, the dot is omitted. 36 | 37 | Note: We use 'precision' differently than most databases. For decimals, 38 | it's the same as ``numeric_scale``, and for floats, which use binary precision, 39 | it can be calculated as ``log10(2**numeric_precision)``. 40 | """ 41 | 42 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 43 | """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" 44 | return self.to_string(value) 45 | 46 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 47 | """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" 48 | if isinstance(coltype, String_UUID): 49 | return f"TRIM({value})" 50 | return self.to_string(value) 51 | 52 | def normalize_value_by_type(self, value: str, coltype: ColType) -> str: 53 | """Creates an SQL expression, that converts 'value' to a normalized representation. 54 | 55 | The returned expression must accept any SQL value, and return a string. 56 | 57 | The default implementation dispatches to a method according to `coltype`: 58 | 59 | :: 60 | 61 | TemporalType -> normalize_timestamp() 62 | FractionalType -> normalize_number() 63 | *else* -> to_string() 64 | 65 | (`Integer` falls in the *else* category) 66 | 67 | """ 68 | if isinstance(coltype, TemporalType): 69 | return self.normalize_timestamp(value, coltype) 70 | elif isinstance(coltype, FractionalType): 71 | return self.normalize_number(value, coltype) 72 | elif isinstance(coltype, ColType_UUID): 73 | return self.normalize_uuid(value, coltype) 74 | elif isinstance(coltype, Boolean): 75 | return self.normalize_boolean(value, coltype) 76 | return self.to_string(value) 77 | 78 | 79 | class AbstractMixin_MD5(AbstractMixin): 80 | """Methods for calculating an MD5 hash as an integer.""" 81 | 82 | @abstractmethod 83 | def md5_as_int(self, s: str) -> str: 84 | "Provide SQL for computing md5 and returning an int" 85 | 86 | 87 | class AbstractMixin_Schema(AbstractMixin): 88 | """Methods for querying the database schema 89 | 90 | TODO: Move AbstractDatabase.query_table_schema() and friends over here 91 | """ 92 | 93 | def table_information(self) -> Compilable: 94 | "Query to return a table of schema information about existing tables" 95 | raise NotImplementedError() 96 | 97 | @abstractmethod 98 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 99 | """Query to select the list of tables in the schema. (query return type: table[str]) 100 | 101 | If 'like' is specified, the value is applied to the table name, using the 'like' operator. 102 | """ 103 | 104 | 105 | class AbstractMixin_Regex(AbstractMixin): 106 | @abstractmethod 107 | def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: 108 | """Tests whether the regex pattern matches the string. Returns a bool expression.""" 109 | 110 | 111 | class AbstractMixin_RandomSample(AbstractMixin): 112 | @abstractmethod 113 | def random_sample_n(self, tbl: str, size: int) -> str: 114 | """Take a random sample of the given size, i.e.
return 'size' amount of rows""" 115 | 116 | @abstractmethod 117 | def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str: 118 | """Take a random sample of the approximate size determined by the ratio (0..1), 119 | ratio of 0 means no rows, and 1 means all rows 120 | 121 | i.e. the actual mount of rows returned may vary by standard deviation. 122 | """ 123 | 124 | # def random_sample_ratio(self, table: AbstractTable, ratio: float): 125 | # """Take a random sample of the size determined by the ratio (0..1), where 0 means no rows, and 1 means all rows 126 | # """ 127 | 128 | 129 | class AbstractMixin_TimeTravel(AbstractMixin): 130 | @abstractmethod 131 | def time_travel( 132 | self, 133 | table: Compilable, 134 | before: bool = False, 135 | timestamp: Compilable = None, 136 | offset: Compilable = None, 137 | statement: Compilable = None, 138 | ) -> Compilable: 139 | """Selects historical data from a table 140 | 141 | Parameters: 142 | table - The name of the table whose history we're querying 143 | timestamp - A constant timestamp 144 | offset - the time 'offset' seconds before now 145 | statement - identifier for statement, e.g. query ID 146 | 147 | Must specify exactly one of `timestamp`, `offset` or `statement`. 148 | """ 149 | 150 | 151 | class AbstractMixin_OptimizerHints(AbstractMixin): 152 | @abstractmethod 153 | def optimizer_hints(self, optimizer_hints: str) -> str: 154 | """Creates a compatible optimizer_hints string 155 | 156 | Parameters: 157 | optimizer_hints - string of optimizer hints 158 | """ 159 | -------------------------------------------------------------------------------- /sqeleton/bound_exprs.py: -------------------------------------------------------------------------------- 1 | """Expressions bound to a specific database""" 2 | 3 | import inspect 4 | from functools import wraps 5 | from typing import Union, TYPE_CHECKING 6 | 7 | from runtype import dataclass 8 | 9 | from .abcs import AbstractDatabase, AbstractCompiler 10 | from .queries.ast_classes import ExprNode, TablePath, Compilable 11 | from .queries.api import table 12 | from .schema import create_schema 13 | 14 | 15 | @dataclass 16 | class BoundNode(ExprNode): 17 | database: AbstractDatabase 18 | node: Compilable 19 | 20 | def __getattr__(self, attr): 21 | value = getattr(self.node, attr) 22 | if inspect.ismethod(value): 23 | 24 | @wraps(value) 25 | def bound_method(*args, **kw): 26 | return BoundNode(self.database, value(*args, **kw)) 27 | 28 | return bound_method 29 | return value 30 | 31 | def query(self, res_type=list): 32 | return self.database.query(self.node, res_type=res_type) 33 | 34 | @property 35 | def type(self): 36 | return self.node.type 37 | 38 | def compile(self, c: AbstractCompiler) -> str: 39 | assert c.database is self.database 40 | return c.compile_elem(self.node) 41 | 42 | 43 | def bind_node(node, database): 44 | return BoundNode(database, node) 45 | 46 | 47 | ExprNode.bind = bind_node 48 | 49 | 50 | @dataclass 51 | class BoundTable(BoundNode): # ITable 52 | database: AbstractDatabase 53 | node: TablePath 54 | 55 | def with_schema(self, schema): 56 | table_path = self.node.replace(schema=schema) 57 | return self.replace(node=table_path) 58 | 59 | def query_schema(self, *, refine: bool = True, refine_where = None, case_sensitive=True): 60 | table_path = self.node 61 | 62 | if table_path.schema: 63 | return self 64 | 65 | raw_schema = self.database.query_table_schema(table_path.path) 66 | schema, _samples = self.database.process_query_table_schema(table_path.path, raw_schema, 
refine, refine_where) 67 | schema = create_schema(self.database, table_path.path, schema, case_sensitive) 68 | return self.with_schema(schema) 69 | 70 | @property 71 | def schema(self): 72 | return self.node.schema 73 | 74 | 75 | def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tuple], **kw): 76 | return BoundTable(database, table(table_path, **kw)) 77 | 78 | 79 | if TYPE_CHECKING: 80 | class BoundTable(BoundTable, TablePath): 81 | pass 82 | -------------------------------------------------------------------------------- /sqeleton/conn_editor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | from concurrent.futures import ThreadPoolExecutor 4 | import toml 5 | 6 | from runtype import dataclass 7 | 8 | try: 9 | import textual 10 | except ModuleNotFoundError: 11 | raise ModuleNotFoundError( 12 | "Error: Cannot find the TUI library 'textual'. Please install it using `pip install sqeleton[tui]`" 13 | ) 14 | 15 | from textual.app import App, ComposeResult 16 | from textual.containers import Vertical, Container, Horizontal 17 | from textual.widgets import Header, Footer, Input, Label, Button, ListView, ListItem 18 | 19 | from textual_select import Select 20 | 21 | from .databases._connect import DATABASE_BY_SCHEME 22 | from . import connect 23 | 24 | 25 | class ContentSwapper(Container): 26 | def __init__(self, *initial_widgets, **kw): 27 | self.container = Container(*initial_widgets) 28 | super().__init__(**kw) 29 | 30 | def compose(self): 31 | yield self.container 32 | 33 | def new_content(self, *new_widgets): 34 | for c in self.container.children: 35 | c.remove() 36 | 37 | self.container.mount(*new_widgets) 38 | 39 | 40 | def test_connect(connect_dict): 41 | conn = connect(connect_dict) 42 | assert conn.query("select 1+1", int) == 2 43 | 44 | 45 | @dataclass 46 | class Config: 47 | config: dict 48 | 49 | @property 50 | def databases(self): 51 | if "database" not in self.config: 52 | self.config["database"] = {} 53 | assert isinstance(self.config["database"], dict) 54 | return self.config["database"] 55 | 56 | def add_database(self, name: str, db: dict): 57 | assert isinstance(db, dict) 58 | self.config["database"][name] = db 59 | 60 | def remove_database(self, name): 61 | del self.config["database"][name] 62 | 63 | 64 | class EditConnection(Vertical): 65 | def __init__(self, db_name: str, config: Config) -> None: 66 | self._db_name = db_name 67 | self._config = config 68 | super().__init__() 69 | 70 | def compose(self): 71 | self.params_container = ContentSwapper(id="params") 72 | self.driver_select = Select( 73 | placeholder="Select a database driver", 74 | search=True, 75 | items=[{"value": k, "text": k} for k in DATABASE_BY_SCHEME], 76 | list_mount="#driver_container", 77 | value=self._config.databases.get(self._db_name, {}).get("driver"), 78 | id="driver", 79 | ) 80 | 81 | yield Vertical( 82 | Label("Connection name:"), 83 | Input(id="conn_name", value=self._db_name, classes="margin-bottom-1"), 84 | Label("Database Driver:"), 85 | self.driver_select, 86 | self.params_container, 87 | id="driver_container", 88 | ) 89 | 90 | self.create_params() 91 | 92 | def create_params(self): 93 | driver = self.driver_select.value 94 | if not driver: 95 | return 96 | db_config = self._config.databases.get(self._db_name, {}) 97 | 98 | db_cls = DATABASE_BY_SCHEME[driver] 99 | # Filter out repetitions, but keep order 100 | base_params = re.findall("<(.*?)>", db_cls.CONNECT_URI_HELP) 101 | params = 
dict.fromkeys([p.lower() for p in base_params] + [k for k in db_config if k != "driver"]) 102 | 103 | widgets = [] 104 | for p in params: 105 | label = p 106 | if p in ("user", "pass", "port", "dbname"): 107 | label += " (optional)" 108 | widgets.append(Label(f"{label}:")) 109 | p = p.lower() 110 | widgets.append(Input(id="input_" + p, name=p, classes="param", value=db_config.get(p))) 111 | 112 | self.params_container.new_content( 113 | Vertical( 114 | *widgets, 115 | Horizontal( 116 | Button("Test & Save Connection", id="test_and_save"), 117 | Button("Save Connection Without Testing", id="save_without_test"), 118 | ), 119 | Vertical(id="result"), 120 | ) 121 | ) 122 | 123 | def on_select_changed(self, event: Select.Changed) -> None: 124 | driver = str(event.value) 125 | self.driver_select.text = driver 126 | self.driver_select.value = driver 127 | self.create_params() 128 | 129 | def _get_connect_dict(self): 130 | connect_dict = {"driver": self.driver_select.value} 131 | for p in self.query(".param"): 132 | if p.value: 133 | connect_dict[p.name] = p.value 134 | 135 | return connect_dict 136 | 137 | async def on_button_pressed(self, event: Button.Pressed) -> None: 138 | button_id = event.button.id 139 | if button_id == "test_and_save": 140 | connect_dict = self._get_connect_dict() 141 | 142 | result_container = self.query_one("#result") 143 | result_container.mount(Label(f"Trying to connect to {connect_dict}")) 144 | 145 | try: 146 | test_connect(connect_dict) 147 | except Exception as e: 148 | error = str(e) 149 | result_container.mount(Label(f"[red]Error: {error}[/red]")) 150 | else: 151 | result_container.mount(Label(f"[green]Success![green]")) 152 | self.save(connect_dict) 153 | 154 | elif button_id == "save_without_test": 155 | connect_dict = self._get_connect_dict() 156 | self.save(connect_dict) 157 | 158 | def save(self, connect_dict): 159 | assert isinstance(connect_dict, dict) 160 | result_container = self.query_one("#result") 161 | result_container.mount(Label(f"Database saved")) 162 | 163 | name = self.query_one("#conn_name").value 164 | self._config.add_database(name, connect_dict) 165 | self.app.config_changed() 166 | 167 | 168 | class ListOfConnections(Vertical): 169 | def __init__(self, config: Config, **kw): 170 | self.config = config 171 | super().__init__(**kw) 172 | 173 | def compose(self) -> ComposeResult: 174 | list_items = [ 175 | ListItem(Label(name, id="list_label_" + name), name=name) for name in self.config.databases.keys() 176 | ] 177 | 178 | self.lv = lv = ListView(*list_items, id="connection_list") 179 | yield lv 180 | 181 | lv.focus() 182 | 183 | def on_list_view_highlighted(self, h: ListView.Highlighted): 184 | name = h.item.name 185 | self.app.query_one("#edit_container").new_content(EditConnection(name, self.config)) 186 | 187 | 188 | class ConnectionEditor(App): 189 | CSS = """ 190 | #conn_list { 191 | display: block; 192 | dock: left; 193 | height: 100%; 194 | max-width: 30%; 195 | margin: 1 196 | } 197 | 198 | #edit { 199 | margin: 1 200 | } 201 | 202 | #test_and_save { 203 | margin: 1 204 | } 205 | #save_without_test { 206 | margin: 1 207 | } 208 | 209 | .margin-bottom-1 { 210 | margin-bottom: 1 211 | } 212 | """ 213 | 214 | BINDINGS = [ 215 | ("q", "quit", "Quit"), 216 | ("d", "toggle_dark", "Toggle dark mode"), 217 | ("insert", "add_conn", "Add new connection"), 218 | ("delete", "del_conn", "Delete selected connection"), 219 | ("t", "test_conn", "Test selected connection"), 220 | ("a", "test_all_conns", "Test all connections"), 221 | ] 222 | 223 | def 
run(self, toml_path, **kw): 224 | self.toml_path = toml_path 225 | try: 226 | with open(self.toml_path) as f: 227 | self.config = Config(toml.load(f)) 228 | except FileNotFoundError: 229 | self.config = Config({}) 230 | 231 | return super().run(**kw) 232 | 233 | def config_changed(self): 234 | self.list_swapper.new_content(ListOfConnections(self.config)) 235 | with open(self.toml_path, "w") as f: 236 | toml.dump(self.config.config, f) 237 | 238 | def compose(self) -> ComposeResult: 239 | """Create child widgets for the app.""" 240 | self.list_swapper = ContentSwapper(ListOfConnections(self.config), id="conn_list") 241 | self.edit_swapper = ContentSwapper(id="edit_container") 242 | self.edit_swapper.new_content(EditConnection("New", self.config)) 243 | 244 | yield Header() 245 | yield Container(self.list_swapper, self.edit_swapper) 246 | yield Footer() 247 | 248 | def action_toggle_dark(self) -> None: 249 | """An action to toggle dark mode.""" 250 | self.dark = not self.dark 251 | 252 | def action_add_conn(self): 253 | self.edit_swapper.new_content(EditConnection("New", self.config)) 254 | 255 | def _selected_connection(self): 256 | connection_list: ListView = self.query_one("#connection_list") 257 | return connection_list.children[connection_list.index] 258 | 259 | def action_del_conn(self): 260 | name = self._selected_connection().name 261 | self.config.remove_database(name) 262 | self.config_changed() 263 | 264 | def _test_existing_db(self, name): 265 | label: Label = self.query_one("#list_label_" + name) 266 | label.update(f"{name}🔃") 267 | connect_dict = self.config.databases[name] 268 | try: 269 | test_connect(connect_dict) 270 | except Exception as e: 271 | label.update(f"{name}❌ {str(e)[:16]}") 272 | else: 273 | label.update(f"{name}✅") 274 | 275 | def action_test_conn(self): 276 | name = self._selected_connection().name 277 | t = ThreadPoolExecutor() 278 | t.submit(self._test_existing_db, name) 279 | 280 | def action_test_all_conns(self): 281 | t = ThreadPoolExecutor() 282 | for name in self.config.databases: 283 | t.submit(self._test_existing_db, name) 284 | 285 | 286 | def main(toml_path: str): 287 | app = ConnectionEditor() 288 | app.run(toml_path) 289 | 290 | 291 | if __name__ == "__main__": 292 | main(sys.argv[1]) 293 | -------------------------------------------------------------------------------- /sqeleton/databases/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database, logger 2 | from ..abcs import DbPath, DbKey, DbTime 3 | from ._connect import Connect 4 | 5 | from .postgresql import PostgreSQL 6 | from .mysql import MySQL 7 | from .oracle import Oracle 8 | from .snowflake import Snowflake 9 | from .bigquery import BigQuery 10 | from .redshift import Redshift 11 | from .presto import Presto 12 | from .databricks import Databricks 13 | from .trino import Trino 14 | from .clickhouse import Clickhouse 15 | from .vertica import Vertica 16 | from .duckdb import DuckDB 17 | 18 | connect = Connect() 19 | -------------------------------------------------------------------------------- /sqeleton/databases/_connect.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | from typing import Type, Optional, Union, Dict 3 | from itertools import zip_longest 4 | from contextlib import suppress 5 | import dsnparse 6 | import toml 7 | 8 | from runtype import dataclass 9 | 10 | from ..abcs.mixins 
import AbstractMixin 11 | from ..utils import WeakCache, Self 12 | from .base import Database, ThreadedDatabase 13 | from .postgresql import PostgreSQL 14 | from .mysql import MySQL 15 | from .oracle import Oracle 16 | from .snowflake import Snowflake 17 | from .bigquery import BigQuery 18 | from .redshift import Redshift 19 | from .presto import Presto 20 | from .databricks import Databricks 21 | from .trino import Trino 22 | from .clickhouse import Clickhouse 23 | from .vertica import Vertica 24 | from .duckdb import DuckDB 25 | 26 | 27 | @dataclass 28 | class MatchUriPath: 29 | database_cls: Type[Database] 30 | 31 | def match_path(self, dsn): 32 | help_str = self.database_cls.CONNECT_URI_HELP 33 | params = self.database_cls.CONNECT_URI_PARAMS 34 | kwparams = self.database_cls.CONNECT_URI_KWPARAMS 35 | 36 | dsn_dict = dict(urllib.parse.parse_qsl(dsn.query)) 37 | matches = {} 38 | for param, arg in zip_longest(params, dsn.paths): 39 | if param is None: 40 | raise ValueError(f"Too many parts to path. Expected format: {help_str}") 41 | 42 | optional = param.endswith("?") 43 | param = param.rstrip("?") 44 | 45 | if arg is None: 46 | try: 47 | arg = dsn_dict.pop(param) 48 | except KeyError: 49 | if not optional: 50 | raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") 51 | 52 | arg = None 53 | 54 | assert param and param not in matches 55 | matches[param] = arg 56 | 57 | for param in kwparams: 58 | try: 59 | arg = dsn_dict.pop(param) 60 | except KeyError: 61 | raise ValueError(f"URI must specify '{param}'. Expected format: {help_str}") 62 | 63 | assert param and arg and param not in matches, (param, arg, matches.keys()) 64 | matches[param] = arg 65 | 66 | for param, value in dsn_dict.items(): 67 | if param in matches: 68 | raise ValueError( 69 | f"Parameter '{param}' already provided as positional argument. Expected format: {help_str}" 70 | ) 71 | 72 | matches[param] = value 73 | 74 | return matches 75 | 76 | 77 | DATABASE_BY_SCHEME = { 78 | "postgresql": PostgreSQL, 79 | "mysql": MySQL, 80 | "oracle": Oracle, 81 | "redshift": Redshift, 82 | "snowflake": Snowflake, 83 | "presto": Presto, 84 | "bigquery": BigQuery, 85 | "databricks": Databricks, 86 | "duckdb": DuckDB, 87 | "trino": Trino, 88 | "clickhouse": Clickhouse, 89 | "vertica": Vertica, 90 | } 91 | 92 | 93 | class Connect: 94 | """Provides methods for connecting to a supported database using a URL or connection dict.""" 95 | 96 | def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): 97 | self.database_by_scheme = database_by_scheme 98 | self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} 99 | self.conn_cache = WeakCache() 100 | 101 | def for_databases(self, *dbs): 102 | database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} 103 | return type(self)(database_by_scheme) 104 | 105 | def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self: 106 | "Extend all the databases with a list of mixins that implement the given abstract mixins." 107 | database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()} 108 | return type(self)(database_by_scheme) 109 | 110 | def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database: 111 | """Connect to the given database uri 112 | 113 | thread_count determines the max number of worker threads per database, 114 | if relevant. None means no limit. 
115 | 116 | Parameters: 117 | db_uri (str): The URI for the database to connect 118 | thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) 119 | 120 | Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 121 | 122 | Supported schemes: 123 | - postgresql 124 | - mysql 125 | - oracle 126 | - snowflake 127 | - bigquery 128 | - redshift 129 | - presto 130 | - databricks 131 | - trino 132 | - clickhouse 133 | - vertica 134 | - duckdb 135 | """ 136 | 137 | dsn = dsnparse.parse(db_uri) 138 | if len(dsn.schemes) > 1: 139 | raise NotImplementedError("No support for multiple schemes") 140 | (scheme,) = dsn.schemes 141 | 142 | if scheme == "toml": 143 | toml_path = dsn.path or dsn.host 144 | database = dsn.fragment 145 | if not database: 146 | raise ValueError("Must specify a database name, e.g. 'toml://path#database'. ") 147 | with open(toml_path) as f: 148 | config = toml.load(f) 149 | try: 150 | conn_dict = config["database"][database] 151 | except KeyError: 152 | raise ValueError(f"Cannot find database config named '{database}'.") 153 | return self.connect_with_dict(conn_dict, thread_count) 154 | 155 | try: 156 | matcher = self.match_uri_path[scheme] 157 | except KeyError: 158 | raise NotImplementedError(f"Scheme '{scheme}' currently not supported") 159 | 160 | cls = matcher.database_cls 161 | 162 | if scheme == "databricks": 163 | assert not dsn.user 164 | kw = {} 165 | kw["access_token"] = dsn.password 166 | kw["http_path"] = dsn.path 167 | kw["server_hostname"] = dsn.host 168 | kw.update(dsn.query) 169 | elif scheme == "duckdb": 170 | kw = {} 171 | kw["filepath"] = dsn.dbname 172 | kw["dbname"] = dsn.user 173 | else: 174 | kw = matcher.match_path(dsn) 175 | 176 | if scheme == "bigquery": 177 | kw["project"] = dsn.host 178 | return cls(**kw) 179 | 180 | if scheme == "snowflake": 181 | kw["account"] = dsn.host 182 | assert not dsn.port 183 | kw["user"] = dsn.user 184 | kw["password"] = dsn.password 185 | else: 186 | kw["host"] = dsn.host 187 | kw["port"] = dsn.port 188 | kw["user"] = dsn.user 189 | if dsn.password: 190 | kw["password"] = dsn.password 191 | 192 | kw = {k: v for k, v in kw.items() if v is not None} 193 | 194 | if issubclass(cls, ThreadedDatabase): 195 | db = cls(thread_count=thread_count, **kw) 196 | else: 197 | db = cls(**kw) 198 | 199 | return self._connection_created(db) 200 | 201 | def connect_with_dict(self, d, thread_count): 202 | d = dict(d) 203 | driver = d.pop("driver") 204 | try: 205 | matcher = self.match_uri_path[driver] 206 | except KeyError: 207 | raise NotImplementedError(f"Driver '{driver}' currently not supported") 208 | 209 | cls = matcher.database_cls 210 | if issubclass(cls, ThreadedDatabase): 211 | db = cls(thread_count=thread_count, **d) 212 | else: 213 | db = cls(**d) 214 | 215 | return self._connection_created(db) 216 | 217 | def _connection_created(self, db): 218 | "Nop function to be overridden by subclasses." 219 | return db 220 | 221 | def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database: 222 | """Connect to a database using the given database configuration. 223 | 224 | Configuration can be given either as a URI string, or as a dict of {option: value}. 225 | 226 | The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. 227 | 228 | thread_count determines the max number of worker threads per database, 229 | if relevant. None means no limit. 
230 | 231 | Parameters: 232 | db_conf (str | dict): The configuration for the database to connect. URI or dict. 233 | thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) 234 | shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True) 235 | 236 | Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. 237 | 238 | Supported drivers: 239 | - postgresql 240 | - mysql 241 | - oracle 242 | - snowflake 243 | - bigquery 244 | - redshift 245 | - presto 246 | - databricks 247 | - trino 248 | - clickhouse 249 | - vertica 250 | 251 | Example: 252 | >>> connect("mysql://localhost/db") 253 | 254 | >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) 255 | 256 | """ 257 | if shared: 258 | with suppress(KeyError): 259 | conn = self.conn_cache.get(db_conf) 260 | if not conn.is_closed: 261 | return conn 262 | 263 | if isinstance(db_conf, str): 264 | conn = self.connect_to_uri(db_conf, thread_count) 265 | elif isinstance(db_conf, dict): 266 | conn = self.connect_with_dict(db_conf, thread_count) 267 | else: 268 | raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") 269 | 270 | if shared: 271 | self.conn_cache.add(db_conf, conn) 272 | return conn 273 | -------------------------------------------------------------------------------- /sqeleton/databases/bigquery.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from ..abcs.database_types import ( 3 | Timestamp, 4 | Datetime, 5 | Integer, 6 | Decimal, 7 | Float, 8 | Text, 9 | DbPath, 10 | FractionalType, 11 | TemporalType, 12 | Boolean, 13 | ) 14 | from ..abcs.mixins import ( 15 | AbstractMixin_MD5, 16 | AbstractMixin_NormalizeValue, 17 | AbstractMixin_Schema, 18 | AbstractMixin_TimeTravel, 19 | ) 20 | from ..abcs import Compilable 21 | from ..queries import this, table, SKIP, code 22 | from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query, QueryResult 23 | from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter, Mixin_RandomSample 24 | 25 | 26 | @import_helper(text="Please install BigQuery and configure your google-cloud access.") 27 | def import_bigquery(): 28 | from google.cloud import bigquery 29 | 30 | return bigquery 31 | 32 | 33 | class Mixin_MD5(AbstractMixin_MD5): 34 | def md5_as_int(self, s: str) -> str: 35 | return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" 36 | 37 | 38 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 39 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 40 | if coltype.rounds: 41 | timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" 42 | return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" 43 | 44 | if coltype.precision == 0: 45 | return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})" 46 | elif coltype.precision == 6: 47 | return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" 48 | 49 | timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" 50 | return ( 51 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 52 | ) 53 | 54 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 55 | return f"format('%.{coltype.precision}f', {value})" 56 | 57 | def normalize_boolean(self, value: str, _coltype: Boolean) 
-> str: 58 | return self.to_string(f"cast({value} as int)") 59 | 60 | 61 | class Mixin_Schema(AbstractMixin_Schema): 62 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 63 | return ( 64 | table(table_schema, "INFORMATION_SCHEMA", "TABLES") 65 | .where( 66 | this.table_schema == table_schema, 67 | this.table_name.like(like) if like is not None else SKIP, 68 | this.table_type == "BASE TABLE", 69 | ) 70 | .select(this.table_name) 71 | ) 72 | 73 | 74 | class Mixin_TimeTravel(AbstractMixin_TimeTravel): 75 | def time_travel( 76 | self, 77 | table: Compilable, 78 | before: bool = False, 79 | timestamp: Compilable = None, 80 | offset: Compilable = None, 81 | statement: Compilable = None, 82 | ) -> Compilable: 83 | if before: 84 | raise NotImplementedError("before=True not supported for BigQuery time-travel") 85 | 86 | if statement is not None: 87 | raise NotImplementedError("BigQuery time-travel doesn't support querying by statement id") 88 | 89 | if timestamp is not None: 90 | assert offset is None 91 | return code("{table} FOR SYSTEM_TIME AS OF {timestamp}", table=table, timestamp=timestamp) 92 | 93 | assert offset is not None 94 | return code( 95 | "{table} FOR SYSTEM_TIME AS OF TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {offset} HOUR);", 96 | table=table, 97 | offset=offset, 98 | ) 99 | 100 | 101 | class Dialect(BaseDialect, Mixin_Schema): 102 | name = "BigQuery" 103 | ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation 104 | TYPE_CLASSES = { 105 | # Dates 106 | "TIMESTAMP": Timestamp, 107 | "DATETIME": Datetime, 108 | # Numbers 109 | "INT64": Integer, 110 | "INT32": Integer, 111 | "NUMERIC": Decimal, 112 | "BIGNUMERIC": Decimal, 113 | "FLOAT64": Float, 114 | "FLOAT32": Float, 115 | # Text 116 | "STRING": Text, 117 | # Boolean 118 | "BOOL": Boolean, 119 | } 120 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} 121 | 122 | def random(self) -> str: 123 | return "RAND()" 124 | 125 | def quote(self, s: str): 126 | return f"`{s}`" 127 | 128 | def to_string(self, s: str): 129 | return f"cast({s} as string)" 130 | 131 | def type_repr(self, t) -> str: 132 | try: 133 | return {str: "STRING", float: "FLOAT64"}[t] 134 | except KeyError: 135 | return super().type_repr(t) 136 | 137 | def set_timezone_to_utc(self) -> str: 138 | raise NotImplementedError() 139 | 140 | 141 | class BigQuery(Database): 142 | CONNECT_URI_HELP = "bigquery:///" 143 | CONNECT_URI_PARAMS = ["dataset"] 144 | dialect = Dialect() 145 | 146 | def __init__(self, project, *, dataset, **kw): 147 | bigquery = import_bigquery() 148 | 149 | self._client = bigquery.Client(project, **kw) 150 | self.project = project 151 | self.dataset = dataset 152 | 153 | self.default_schema = dataset 154 | 155 | def _normalize_returned_value(self, value): 156 | if isinstance(value, bytes): 157 | return value.decode() 158 | return value 159 | 160 | def _query_atom(self, sql_code: str): 161 | from google.cloud import bigquery 162 | 163 | try: 164 | result = self._client.query(sql_code).result() 165 | columns = [c.name for c in result.schema] 166 | rows = list(result) 167 | except Exception as e: 168 | msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s" 169 | raise ConnectError(msg % (sql_code, e)) 170 | 171 | if rows and isinstance(rows[0], bigquery.table.Row): 172 | rows = [tuple(self._normalize_returned_value(v) for v in row.values()) for row in rows] 173 | return QueryResult(rows, columns) 174 | 175 | def 
_query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult: 176 | return apply_query(self._query_atom, sql_code) 177 | 178 | def close(self): 179 | super().close() 180 | self._client.close() 181 | 182 | def select_table_schema(self, path: DbPath) -> str: 183 | project, schema, name = self._normalize_table_path(path) 184 | return ( 185 | "SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale " 186 | f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS " 187 | f"WHERE table_name = '{name}' AND table_schema = '{schema}'" 188 | ) 189 | 190 | def query_table_unique_columns(self, path: DbPath) -> List[str]: 191 | return [] 192 | 193 | def _normalize_table_path(self, path: DbPath) -> DbPath: 194 | if len(path) == 0: 195 | raise ValueError(f"{self.name}: Bad table path for {self}: ()") 196 | elif len(path) == 1: 197 | return (self.project, self.default_schema, path[0]) 198 | elif len(path) == 2: 199 | return (self.project,) + path 200 | elif len(path) == 3: 201 | return path 202 | else: 203 | raise ValueError( 204 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table" 205 | ) 206 | 207 | def parse_table_name(self, name: str) -> DbPath: 208 | path = parse_table_name(name) 209 | return tuple(i for i in self._normalize_table_path(path) if i is not None) 210 | 211 | @property 212 | def is_autocommit(self) -> bool: 213 | return True 214 | -------------------------------------------------------------------------------- /sqeleton/databases/clickhouse.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | from .base import ( 4 | MD5_HEXDIGITS, 5 | CHECKSUM_HEXDIGITS, 6 | TIMESTAMP_PRECISION_POS, 7 | BaseDialect, 8 | ThreadedDatabase, 9 | import_helper, 10 | ConnectError, 11 | Mixin_RandomSample, 12 | ) 13 | from ..abcs.database_types import ( 14 | ColType, 15 | Decimal, 16 | Float, 17 | Integer, 18 | FractionalType, 19 | Native_UUID, 20 | TemporalType, 21 | Text, 22 | Timestamp, 23 | Boolean, 24 | ) 25 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 26 | 27 | # https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database 28 | DEFAULT_DATABASE = "default" 29 | 30 | 31 | @import_helper("clickhouse") 32 | def import_clickhouse(): 33 | import clickhouse_driver 34 | 35 | return clickhouse_driver 36 | 37 | 38 | class Mixin_MD5(AbstractMixin_MD5): 39 | def md5_as_int(self, s: str) -> str: 40 | substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS 41 | return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" 42 | 43 | 44 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 45 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 46 | # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. 47 | # For example: 48 | # select toString(toDecimal128(1.10, 2)); -- the result is 1.1 49 | # select toString(toDecimal128(1.00, 2)); -- the result is 1 50 | # So, we should use some custom approach to save these trailing zeros. 51 | # To avoid it, we can add a small value like 0.000001 to prevent dropping of zeros from the end when casting. 52 | # For examples above it looks like: 53 | # select toString(toDecimal128(1.10, 2 + 1) + toDecimal128(0.001, 3)); -- the result is 1.101 54 | # After that, cut an extra symbol from the string, i.e. 
1.101 -> 1.10 55 | # So, the algorithm is: 56 | # 1. Cast to decimal with precision + 1 57 | # 2. Add a small value 10^(-precision-1) 58 | # 3. Cast the result to string 59 | # 4. Drop the extra digit from the string. To do that, we need to slice the string 60 | # with length = digits in an integer part + 1 (symbol of ".") + precision 61 | 62 | if coltype.precision == 0: 63 | return self.to_string(f"round({value})") 64 | 65 | precision = coltype.precision 66 | # TODO: too complex, is there better performance way? 67 | value = f""" 68 | if({value} >= 0, '', '-') || left( 69 | toString( 70 | toDecimal128( 71 | round(abs({value}), {precision}), 72 | {precision} + 1 73 | ) 74 | + 75 | toDecimal128( 76 | exp10(-{precision + 1}), 77 | {precision} + 1 78 | ) 79 | ), 80 | toUInt8( 81 | greatest( 82 | floor(log10(abs({value}))) + 1, 83 | 1 84 | ) 85 | ) + 1 + {precision} 86 | ) 87 | """ 88 | return value 89 | 90 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 91 | prec = coltype.precision 92 | if coltype.rounds: 93 | timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)" 94 | return self.to_string(timestamp) 95 | 96 | fractional = f"toUnixTimestamp64Micro(toDateTime64({value}, {prec})) % 1000000" 97 | fractional = f"lpad({self.to_string(fractional)}, 6, '0')" 98 | value = f"formatDateTime({value}, '%Y-%m-%d %H:%M:%S') || '.' || {self.to_string(fractional)}" 99 | return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" 100 | 101 | 102 | class Dialect(BaseDialect): 103 | name = "Clickhouse" 104 | ROUNDS_ON_PREC_LOSS = False 105 | ARG_SYMBOL = None # TODO Clickhouse only supports named parameters, not positional 106 | TYPE_CLASSES = { 107 | "Int8": Integer, 108 | "Int16": Integer, 109 | "Int32": Integer, 110 | "Int64": Integer, 111 | "Int128": Integer, 112 | "Int256": Integer, 113 | "UInt8": Integer, 114 | "UInt16": Integer, 115 | "UInt32": Integer, 116 | "UInt64": Integer, 117 | "UInt128": Integer, 118 | "UInt256": Integer, 119 | "Float32": Float, 120 | "Float64": Float, 121 | "Decimal": Decimal, 122 | "UUID": Native_UUID, 123 | "String": Text, 124 | "FixedString": Text, 125 | "DateTime": Timestamp, 126 | "DateTime64": Timestamp, 127 | "Bool": Boolean, 128 | } 129 | MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 130 | 131 | def quote(self, s: str) -> str: 132 | return f'"{s}"' 133 | 134 | def to_string(self, s: str) -> str: 135 | return f"toString({s})" 136 | 137 | def _convert_db_precision_to_digits(self, p: int) -> int: 138 | # Done the same as for PostgreSQL but need to rewrite in another way 139 | # because it does not help for float with a big integer part. 
140 | return super()._convert_db_precision_to_digits(p) - 2 141 | 142 | def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: 143 | nullable_prefix = "Nullable(" 144 | if type_repr.startswith(nullable_prefix): 145 | type_repr = type_repr[len(nullable_prefix) :].rstrip(")") 146 | 147 | if type_repr.startswith("Decimal"): 148 | type_repr = "Decimal" 149 | elif type_repr.startswith("FixedString"): 150 | type_repr = "FixedString" 151 | elif type_repr.startswith("DateTime64"): 152 | type_repr = "DateTime64" 153 | 154 | return self.TYPE_CLASSES.get(type_repr) 155 | 156 | # def timestamp_value(self, t: DbTime) -> str: 157 | # # return f"'{t}'" 158 | # return f"'{str(t)[:19]}'" 159 | 160 | def set_timezone_to_utc(self) -> str: 161 | raise NotImplementedError() 162 | 163 | def current_timestamp(self) -> str: 164 | return "now()" 165 | 166 | 167 | class Clickhouse(ThreadedDatabase): 168 | dialect = Dialect() 169 | CONNECT_URI_HELP = "clickhouse://:@/" 170 | CONNECT_URI_PARAMS = ["database?"] 171 | 172 | def __init__(self, *, thread_count: int, **kw): 173 | super().__init__(thread_count=thread_count) 174 | 175 | self._args = kw 176 | # In Clickhouse database and schema are the same 177 | self.default_schema = kw.get("database", DEFAULT_DATABASE) 178 | 179 | def create_connection(self): 180 | clickhouse = import_clickhouse() 181 | 182 | class SingleConnection(clickhouse.dbapi.connection.Connection): 183 | """Not thread-safe connection to Clickhouse""" 184 | 185 | def cursor(self, cursor_factory=None): 186 | if not len(self.cursors): 187 | _ = super().cursor() 188 | return self.cursors[0] 189 | 190 | try: 191 | return SingleConnection(**self._args) 192 | except clickhouse.OperationError as e: 193 | raise ConnectError(*e.args) from e 194 | 195 | @property 196 | def is_autocommit(self) -> bool: 197 | return True 198 | -------------------------------------------------------------------------------- /sqeleton/databases/databricks.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Sequence, Optional, Tuple 3 | import logging 4 | 5 | from ..abcs.database_types import ( 6 | Integer, 7 | Float, 8 | Decimal, 9 | Timestamp, 10 | Text, 11 | TemporalType, 12 | NumericType, 13 | DbPath, 14 | ColType, 15 | UnknownColType, 16 | Boolean, 17 | ) 18 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 19 | from .base import ( 20 | MD5_HEXDIGITS, 21 | CHECKSUM_HEXDIGITS, 22 | BaseDialect, 23 | ThreadedDatabase, 24 | import_helper, 25 | parse_table_name, 26 | Mixin_RandomSample, 27 | ) 28 | 29 | 30 | @import_helper(text="You can install it using 'pip install databricks-sql-connector'") 31 | def import_databricks(): 32 | import databricks.sql 33 | 34 | return databricks 35 | 36 | 37 | class Mixin_MD5(AbstractMixin_MD5): 38 | def md5_as_int(self, s: str) -> str: 39 | return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" 40 | 41 | 42 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 43 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 44 | """Databricks timestamp contains no more than 6 digits in precision""" 45 | 46 | if coltype.rounds: 47 | timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)" 48 | return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')" 49 | 50 | precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision) 51 | return 
f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" 52 | 53 | def normalize_number(self, value: str, coltype: NumericType) -> str: 54 | value = f"cast({value} as decimal(38, {coltype.precision}))" 55 | if coltype.precision > 0: 56 | value = f"format_number({value}, {coltype.precision})" 57 | return f"replace({self.to_string(value)}, ',', '')" 58 | 59 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 60 | return self.to_string(f"cast ({value} as int)") 61 | 62 | 63 | class Dialect(BaseDialect): 64 | name = "Databricks" 65 | ROUNDS_ON_PREC_LOSS = True 66 | TYPE_CLASSES = { 67 | # Numbers 68 | "INT": Integer, 69 | "SMALLINT": Integer, 70 | "TINYINT": Integer, 71 | "BIGINT": Integer, 72 | "FLOAT": Float, 73 | "DOUBLE": Float, 74 | "DECIMAL": Decimal, 75 | # Timestamps 76 | "TIMESTAMP": Timestamp, 77 | # Text 78 | "STRING": Text, 79 | # Boolean 80 | "BOOLEAN": Boolean, 81 | } 82 | MIXINS = {Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 83 | 84 | def quote(self, s: str): 85 | return f"`{s}`" 86 | 87 | def to_string(self, s: str) -> str: 88 | return f"cast({s} as string)" 89 | 90 | def _convert_db_precision_to_digits(self, p: int) -> int: 91 | # Subtracting 2 due to wierd precision issues 92 | return max(super()._convert_db_precision_to_digits(p) - 2, 0) 93 | 94 | def set_timezone_to_utc(self) -> str: 95 | return "SET TIME ZONE 'UTC'" 96 | 97 | 98 | class Databricks(ThreadedDatabase): 99 | dialect = Dialect() 100 | CONNECT_URI_HELP = "databricks://:@/" 101 | CONNECT_URI_PARAMS = ["catalog", "schema"] 102 | 103 | def __init__(self, *, thread_count, **kw): 104 | logging.getLogger("databricks.sql").setLevel(logging.WARNING) 105 | 106 | self._args = kw 107 | self.default_schema = kw.get("schema", "default") 108 | self.catalog = self._args.get("catalog", "hive_metastore") 109 | super().__init__(thread_count=thread_count) 110 | 111 | def create_connection(self): 112 | databricks = import_databricks() 113 | 114 | try: 115 | return databricks.sql.connect( 116 | server_hostname=self._args["server_hostname"], 117 | http_path=self._args["http_path"], 118 | access_token=self._args["access_token"], 119 | catalog=self.catalog, 120 | ) 121 | except databricks.sql.exc.Error as e: 122 | raise ConnectionError(*e.args) from e 123 | 124 | def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: 125 | # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. 126 | # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html 127 | # So, to obtain information about schema, we should use another approach. 
128 | 129 | conn = self.create_connection() 130 | 131 | catalog, schema, table = self._normalize_table_path(path) 132 | with conn.cursor() as cursor: 133 | cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table) 134 | try: 135 | rows = cursor.fetchall() 136 | finally: 137 | conn.close() 138 | if not rows: 139 | raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") 140 | 141 | d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} 142 | assert len(d) == len(rows) 143 | return d 144 | 145 | def _process_table_schema( 146 | self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None 147 | ): 148 | accept = {i.lower() for i in filter_columns} 149 | rows = [row for name, row in raw_schema.items() if name.lower() in accept] 150 | 151 | resulted_rows = [] 152 | for row in rows: 153 | row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] 154 | type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) 155 | 156 | if issubclass(type_cls, Integer): 157 | row = (row[0], row_type, None, None, 0) 158 | 159 | elif issubclass(type_cls, Float): 160 | numeric_precision = math.ceil(row[2] / math.log(2, 10)) 161 | row = (row[0], row_type, None, numeric_precision, None) 162 | 163 | elif issubclass(type_cls, Decimal): 164 | items = row[1][8:].rstrip(")").split(",") 165 | numeric_precision, numeric_scale = int(items[0]), int(items[1]) 166 | row = (row[0], row_type, None, numeric_precision, numeric_scale) 167 | 168 | elif issubclass(type_cls, Timestamp): 169 | row = (row[0], row_type, row[2], None, None) 170 | 171 | else: 172 | row = (row[0], row_type, None, None, None) 173 | 174 | resulted_rows.append(row) 175 | 176 | col_dict: Dict[str, ColType] = {row[0]: self.dialect.parse_type(path, *row) for row in resulted_rows} 177 | 178 | self._refine_coltypes(path, col_dict, where) 179 | return col_dict 180 | 181 | def process_query_table_schema(self, path: DbPath, raw_schema: Dict[str, Tuple], refine: bool = True, refine_where: Optional[str] = None) -> Tuple[Dict[str, ColType], Optional[list]]: 182 | if not refine: 183 | raise NotImplementedError() 184 | return self._process_table_schema(path, raw_schema, list(raw_schema), refine_where), None 185 | 186 | def parse_table_name(self, name: str) -> DbPath: 187 | path = parse_table_name(name) 188 | return tuple(i for i in self._normalize_table_path(path) if i is not None) 189 | 190 | @property 191 | def is_autocommit(self) -> bool: 192 | return True 193 | 194 | def _normalize_table_path(self, path: DbPath) -> DbPath: 195 | if len(path) == 1: 196 | return self.catalog, self.default_schema, path[0] 197 | elif len(path) == 2: 198 | return self.catalog, path[0], path[1] 199 | elif len(path) == 3: 200 | return path 201 | 202 | raise ValueError( 203 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or catalog.schema.table" 204 | ) 205 | -------------------------------------------------------------------------------- /sqeleton/databases/duckdb.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from ..utils import match_regexps 4 | from ..abcs.database_types import ( 5 | Timestamp, 6 | TimestampTZ, 7 | DbPath, 8 | ColType, 9 | Float, 10 | Decimal, 11 | Integer, 12 | TemporalType, 13 | Native_UUID, 14 | Text, 15 | FractionalType, 16 | Boolean, 17 | AbstractTable, 18 | ) 19 | from ..abcs.mixins import ( 20 | AbstractMixin_MD5, 21 | AbstractMixin_NormalizeValue, 22 | AbstractMixin_RandomSample, 23 | AbstractMixin_Regex, 24 | ) 25 | from .base import ( 26 | Database, 27 | BaseDialect, 28 | import_helper, 29 | ConnectError, 30 | ThreadLocalInterpreter, 31 | TIMESTAMP_PRECISION_POS, 32 | ) 33 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema 34 | from ..queries.ast_classes import Func, Compilable 35 | from ..queries.api import code 36 | 37 | 38 | @import_helper("duckdb") 39 | def import_duckdb(): 40 | import duckdb 41 | 42 | return duckdb 43 | 44 | 45 | class Mixin_MD5(AbstractMixin_MD5): 46 | def md5_as_int(self, s: str) -> str: 47 | return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" 48 | 49 | 50 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 51 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 52 | # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. 53 | if coltype.rounds and coltype.precision > 0: 54 | return f"CONCAT(SUBSTRING(STRFTIME({value}::TIMESTAMP, '%Y-%m-%d %H:%M:%S.'),1,23), LPAD(((ROUND(strftime({value}::timestamp, '%f')::DECIMAL(15,7)/100000,{coltype.precision-1})*100000)::INT)::VARCHAR,6,'0'))" 55 | 56 | return f"rpad(substring(strftime({value}::timestamp, '%Y-%m-%d %H:%M:%S.%f'),1,{TIMESTAMP_PRECISION_POS+coltype.precision}),26,'0')" 57 | 58 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 59 | return self.to_string(f"{value}::DECIMAL(38, {coltype.precision})") 60 | 61 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 62 | return self.to_string(f"{value}::INTEGER") 63 | 64 | 65 | class Mixin_RandomSample(AbstractMixin_RandomSample): 66 | def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: 67 | return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) 68 | 69 | def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: 70 | return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) 71 | 72 | 73 | class Mixin_Regex(AbstractMixin_Regex): 74 | def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: 75 | return Func("regexp_matches", [string, pattern]) 76 | 77 | 78 | class Dialect(BaseDialect, Mixin_Schema): 79 | name = "DuckDB" 80 | ROUNDS_ON_PREC_LOSS = False 81 | SUPPORTS_PRIMARY_KEY = True 82 | SUPPORTS_INDEXES = True 83 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 84 | ARG_SYMBOL = "?" 
85 | 86 | TYPE_CLASSES = { 87 | # Timestamps 88 | "TIMESTAMP WITH TIME ZONE": TimestampTZ, 89 | "TIMESTAMP": Timestamp, 90 | # Numbers 91 | "DOUBLE": Float, 92 | "FLOAT": Float, 93 | "DECIMAL": Decimal, 94 | "INTEGER": Integer, 95 | "BIGINT": Integer, 96 | # Text 97 | "VARCHAR": Text, 98 | "TEXT": Text, 99 | # UUID 100 | "UUID": Native_UUID, 101 | # Bool 102 | "BOOLEAN": Boolean, 103 | } 104 | 105 | def quote(self, s: str): 106 | return f'"{s}"' 107 | 108 | def to_string(self, s: str): 109 | return f"{s}::VARCHAR" 110 | 111 | def _convert_db_precision_to_digits(self, p: int) -> int: 112 | # Subtracting 2 due to wierd precision issues in PostgreSQL 113 | return super()._convert_db_precision_to_digits(p) - 2 114 | 115 | def parse_type( 116 | self, 117 | table_path: DbPath, 118 | col_name: str, 119 | type_repr: str, 120 | datetime_precision: int = None, 121 | numeric_precision: int = None, 122 | numeric_scale: int = None, 123 | ) -> ColType: 124 | regexps = { 125 | r"DECIMAL\((\d+),(\d+)\)": Decimal, 126 | } 127 | 128 | for m, t_cls in match_regexps(regexps, type_repr): 129 | precision = int(m.group(2)) 130 | return t_cls(precision=precision) 131 | 132 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) 133 | 134 | def set_timezone_to_utc(self) -> str: 135 | return "SET GLOBAL TimeZone='UTC'" 136 | 137 | def current_timestamp(self) -> str: 138 | return "current_timestamp" 139 | 140 | 141 | class DuckDB(Database): 142 | dialect = Dialect() 143 | SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it 144 | default_schema = "main" 145 | CONNECT_URI_HELP = "duckdb://@" 146 | CONNECT_URI_PARAMS = ["database", "dbpath"] 147 | 148 | def __init__(self, **kw): 149 | self._args = kw 150 | self._conn = self.create_connection() 151 | 152 | @property 153 | def is_autocommit(self) -> bool: 154 | return True 155 | 156 | def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): 157 | "Uses the standard SQL cursor interface" 158 | return self._query_conn(self._conn, sql_code) 159 | 160 | def close(self): 161 | super().close() 162 | self._conn.close() 163 | 164 | def create_connection(self): 165 | ddb = import_duckdb() 166 | try: 167 | return ddb.connect(self._args["filepath"]) 168 | except ddb.OperationalError as e: 169 | raise ConnectError(*e.args) from e 170 | 171 | def select_table_schema(self, path: DbPath) -> str: 172 | database, schema, table = self._normalize_table_path(path) 173 | 174 | info_schema_path = ["information_schema", "columns"] 175 | if database: 176 | info_schema_path.insert(0, database) 177 | 178 | return ( 179 | f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " 180 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 181 | ) 182 | 183 | def _normalize_table_path(self, path: DbPath) -> DbPath: 184 | if len(path) == 1: 185 | return None, self.default_schema, path[0] 186 | elif len(path) == 2: 187 | return None, path[0], path[1] 188 | elif len(path) == 3: 189 | return path 190 | 191 | raise ValueError( 192 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" 193 | ) 194 | -------------------------------------------------------------------------------- /sqeleton/databases/mssql.py: -------------------------------------------------------------------------------- 1 | # class MsSQL(ThreadedDatabase): 2 | # "AKA sql-server" 3 | 4 | # def __init__(self, host, port, user, password, *, database, thread_count, **kw): 5 | # args = dict(server=host, port=port, database=database, user=user, password=password, **kw) 6 | # self._args = {k: v for k, v in args.items() if v is not None} 7 | 8 | # super().__init__(thread_count=thread_count) 9 | 10 | # def create_connection(self): 11 | # mssql = import_mssql() 12 | # try: 13 | # return mssql.connect(**self._args) 14 | # except mssql.Error as e: 15 | # raise ConnectError(*e.args) from e 16 | 17 | # def quote(self, s: str): 18 | # return f"[{s}]" 19 | 20 | # def md5_as_int(self, s: str) -> str: 21 | # return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" 22 | # # return f"CONVERT(bigint, (CHECKSUM({s})))" 23 | 24 | # def to_string(self, s: str): 25 | # return f"CONVERT(varchar, {s})" 26 | -------------------------------------------------------------------------------- /sqeleton/databases/mysql.py: -------------------------------------------------------------------------------- 1 | from ..abcs.database_types import ( 2 | Datetime, 3 | Timestamp, 4 | Float, 5 | Decimal, 6 | Integer, 7 | Text, 8 | TemporalType, 9 | FractionalType, 10 | ColType_UUID, 11 | Boolean, 12 | Date, 13 | ) 14 | from ..abcs.mixins import ( 15 | AbstractMixin_MD5, 16 | AbstractMixin_NormalizeValue, 17 | AbstractMixin_Regex, 18 | AbstractMixin_RandomSample, 19 | ) 20 | from .base import Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable 21 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema, Mixin_RandomSample 22 | from ..queries.ast_classes import BinBoolOp 23 | 24 | 25 | @import_helper("mysql") 26 | def import_mysql(): 27 | import mysql.connector 28 | 29 | return mysql.connector 30 | 31 | 32 | class Mixin_MD5(AbstractMixin_MD5): 33 | def md5_as_int(self, s: str) -> str: 34 | return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" 35 | 36 | 37 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 38 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 39 | if coltype.rounds: 40 | return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") 41 | 42 | s = self.to_string(f"cast({value} as datetime(6))") 43 | return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" 44 | 45 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 46 | return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") 47 | 48 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 49 | return f"TRIM(CAST({value} AS char))" 50 | 51 | 52 | class Mixin_Regex(AbstractMixin_Regex): 53 | def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: 54 | return BinBoolOp("REGEXP", [string, pattern]) 55 | 56 | 57 | class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): 58 | name = "MySQL" 59 | ROUNDS_ON_PREC_LOSS = True 60 | SUPPORTS_PRIMARY_KEY = True 61 | SUPPORTS_INDEXES = True 62 | 63 | TYPE_CLASSES = { 64 | # Dates 65 | "datetime": Datetime, 66 | "timestamp": Timestamp, 67 | "date": 
Date, 68 | # Numbers 69 | "double": Float, 70 | "float": Float, 71 | "decimal": Decimal, 72 | "int": Integer, 73 | "bigint": Integer, 74 | "smallint": Integer, 75 | "tinyint": Integer, 76 | "mediumint": Integer, 77 | # Text 78 | "varchar": Text, 79 | "char": Text, 80 | "varbinary": Text, 81 | "binary": Text, 82 | "text": Text, 83 | "mediumtext": Text, 84 | "longtext": Text, 85 | "tinytext": Text, 86 | # Boolean 87 | "boolean": Boolean, 88 | } 89 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 90 | 91 | def quote(self, s: str): 92 | return f"`{s}`" 93 | 94 | def to_string(self, s: str): 95 | return f"cast({s} as char)" 96 | 97 | def is_distinct_from(self, a: str, b: str) -> str: 98 | return f"not ({a} <=> {b})" 99 | 100 | def random(self) -> str: 101 | return "RAND()" 102 | 103 | def type_repr(self, t) -> str: 104 | if isinstance(t, type): 105 | try: 106 | return { 107 | str: "VARCHAR(1024)", 108 | }[t] 109 | except KeyError: 110 | pass 111 | return super().type_repr(t) 112 | 113 | def explain_as_text(self, query: str) -> str: 114 | return f"EXPLAIN FORMAT=TREE {query}" 115 | 116 | def optimizer_hints(self, s: str): 117 | return f"/*+ {s} */ " 118 | 119 | def set_timezone_to_utc(self) -> str: 120 | return "SET @@session.time_zone='+00:00'" 121 | 122 | 123 | class MySQL(ThreadedDatabase): 124 | dialect = Dialect() 125 | SUPPORTS_ALPHANUMS = False 126 | SUPPORTS_UNIQUE_CONSTAINT = True 127 | CONNECT_URI_HELP = "mysql://:@/" 128 | CONNECT_URI_PARAMS = ["database?"] 129 | 130 | def __init__(self, *, thread_count, **kw): 131 | self._args = kw 132 | 133 | super().__init__(thread_count=thread_count) 134 | 135 | # In MySQL schema and database are synonymous 136 | try: 137 | self.default_schema = kw["database"] 138 | except KeyError: 139 | raise ValueError("MySQL URL must specify a database") 140 | 141 | def create_connection(self): 142 | mysql = import_mysql() 143 | try: 144 | return mysql.connect(charset="utf8", use_unicode=True, **self._args) 145 | except mysql.Error as e: 146 | if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR: 147 | raise ConnectError("Bad user name or password") from e 148 | elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR: 149 | raise ConnectError("Database does not exist") from e 150 | raise ConnectError(*e.args) from e 151 | -------------------------------------------------------------------------------- /sqeleton/databases/oracle.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | from ..utils import match_regexps 4 | from ..abcs.database_types import ( 5 | Decimal, 6 | Float, 7 | Text, 8 | DbPath, 9 | TemporalType, 10 | ColType, 11 | DbTime, 12 | ColType_UUID, 13 | Timestamp, 14 | TimestampTZ, 15 | FractionalType, 16 | ) 17 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema 18 | from ..abcs import Compilable 19 | from ..queries import this, table, SKIP 20 | from .base import ( 21 | BaseDialect, 22 | Mixin_OptimizerHints, 23 | ThreadedDatabase, 24 | import_helper, 25 | ConnectError, 26 | QueryError, 27 | Mixin_RandomSample, 28 | ) 29 | from .base import TIMESTAMP_PRECISION_POS 30 | 31 | SESSION_TIME_ZONE = None # Changed by the tests 32 | 33 | 34 | @import_helper("oracle") 35 | def import_oracle(): 36 | import cx_Oracle 37 | 38 | return cx_Oracle 39 | 40 | 41 | class Mixin_MD5(AbstractMixin_MD5): 42 | def md5_as_int(self, s: str) -> str: 43 | # standard_hash is faster than DBMS_CRYPTO.Hash 44 | # TODO: Find a 
way to use UTL_RAW.CAST_TO_BINARY_INTEGER ? 45 | return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" 46 | 47 | 48 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 49 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 50 | # Cast is necessary for correct MD5 (trimming not enough) 51 | return f"CAST(TRIM({value}) AS VARCHAR(36))" 52 | 53 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 54 | if coltype.rounds: 55 | return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" 56 | 57 | if coltype.precision > 0: 58 | truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')" 59 | else: 60 | truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')" 61 | return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')" 62 | 63 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 64 | # FM999.9990 65 | format_str = "FM" + "9" * (38 - coltype.precision) 66 | if coltype.precision: 67 | format_str += "0." + "9" * (coltype.precision - 1) + "0" 68 | return f"to_char({value}, '{format_str}')" 69 | 70 | 71 | class Mixin_Schema(AbstractMixin_Schema): 72 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 73 | return ( 74 | table("ALL_TABLES") 75 | .where( 76 | this.OWNER == table_schema, 77 | this.TABLE_NAME.like(like) if like is not None else SKIP, 78 | ) 79 | .select(table_name=this.TABLE_NAME) 80 | ) 81 | 82 | 83 | class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): 84 | name = "Oracle" 85 | SUPPORTS_PRIMARY_KEY = True 86 | SUPPORTS_INDEXES = True 87 | TYPE_CLASSES: Dict[str, type] = { 88 | "NUMBER": Decimal, 89 | "FLOAT": Float, 90 | # Text 91 | "CHAR": Text, 92 | "NCHAR": Text, 93 | "NVARCHAR2": Text, 94 | "VARCHAR2": Text, 95 | "DATE": Timestamp, 96 | } 97 | ROUNDS_ON_PREC_LOSS = True 98 | PLACEHOLDER_TABLE = "DUAL" 99 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 100 | 101 | def quote(self, s: str): 102 | return f'"{s}"' 103 | 104 | def to_string(self, s: str): 105 | return f"cast({s} as varchar(1024))" 106 | 107 | def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): 108 | if offset: 109 | raise NotImplementedError("No support for OFFSET in query") 110 | 111 | return f"FETCH NEXT {limit} ROWS ONLY" 112 | 113 | def concat(self, items: List[str]) -> str: 114 | joined_exprs = " || ".join(items) 115 | return f"({joined_exprs})" 116 | 117 | def timestamp_value(self, t: DbTime) -> str: 118 | return "timestamp '%s'" % t.isoformat(" ") 119 | 120 | def random(self) -> str: 121 | return "dbms_random.value" 122 | 123 | def is_distinct_from(self, a: str, b: str) -> str: 124 | return f"DECODE({a}, {b}, 1, 0) = 0" 125 | 126 | def type_repr(self, t) -> str: 127 | try: 128 | return { 129 | str: "VARCHAR(1024)", 130 | }[t] 131 | except KeyError: 132 | return super().type_repr(t) 133 | 134 | def constant_values(self, rows) -> str: 135 | return " UNION ALL ".join("SELECT %s FROM DUAL" % ", ".join(row) for row in rows) 136 | 137 | def explain_as_text(self, query: str) -> str: 138 | raise NotImplementedError("Explain not yet implemented in Oracle") 139 | 140 | def parse_type( 141 | self, 142 | table_path: DbPath, 143 | col_name: str, 144 | type_repr: str, 145 | datetime_precision: int = None, 146 | numeric_precision: int = None, 147 | numeric_scale: int = None, 148 | ) -> ColType: 149 | regexps = { 150 | r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, 151 | 
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ, 152 | r"TIMESTAMP\((\d)\)": Timestamp, 153 | } 154 | 155 | for m, t_cls in match_regexps(regexps, type_repr): 156 | precision = int(m.group(1)) 157 | return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) 158 | 159 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) 160 | 161 | def set_timezone_to_utc(self) -> str: 162 | return "ALTER SESSION SET TIME_ZONE = 'UTC'" 163 | 164 | def current_timestamp(self) -> str: 165 | return "LOCALTIMESTAMP" 166 | 167 | 168 | class Oracle(ThreadedDatabase): 169 | dialect = Dialect() 170 | CONNECT_URI_HELP = "oracle://:@/" 171 | CONNECT_URI_PARAMS = ["database?"] 172 | 173 | def __init__(self, *, host, database, thread_count, **kw): 174 | 175 | self.kwargs = kw 176 | 177 | # Build dsn if not present 178 | if "dsn" not in kw: 179 | self.kwargs["dsn"] = f"{host}/{database}" if database else host 180 | 181 | self.default_schema = kw.get("user").upper() 182 | 183 | super().__init__(thread_count=thread_count) 184 | 185 | def create_connection(self): 186 | self._oracle = import_oracle() 187 | try: 188 | c = self._oracle.connect(**self.kwargs) 189 | if SESSION_TIME_ZONE: 190 | c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'") 191 | return c 192 | except Exception as e: 193 | raise ConnectError(*e.args) from e 194 | 195 | def _query_cursor(self, c, sql_code: str): 196 | try: 197 | return super()._query_cursor(c, sql_code) 198 | except self._oracle.DatabaseError as e: 199 | raise QueryError(e) 200 | 201 | def select_table_schema(self, path: DbPath) -> str: 202 | schema, name = self._normalize_table_path(path) 203 | 204 | return ( 205 | f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" 206 | f" FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'" 207 | ) 208 | -------------------------------------------------------------------------------- /sqeleton/databases/postgresql.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from ..abcs.database_types import ( 3 | DbPath, 4 | Timestamp, 5 | TimestampTZ, 6 | Float, 7 | Decimal, 8 | Integer, 9 | TemporalType, 10 | Native_UUID, 11 | Text, 12 | FractionalType, 13 | Boolean, 14 | Date, 15 | ) 16 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 17 | from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema 18 | from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS, Mixin_RandomSample 19 | from ..schema import _Field 20 | 21 | SESSION_TIME_ZONE = None # Changed by the tests 22 | 23 | 24 | @import_helper("postgresql") 25 | def import_postgresql(): 26 | import psycopg2 27 | import psycopg2.extras 28 | 29 | psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) 30 | return psycopg2 31 | 32 | 33 | class Mixin_MD5(AbstractMixin_MD5): 34 | def md5_as_int(self, s: str) -> str: 35 | return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" 36 | 37 | 38 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 39 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 40 | if coltype.rounds: 41 | return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" 42 | 43 | timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd 
HH24:MI:SS.US')" 44 | return ( 45 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 46 | ) 47 | 48 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 49 | return self.to_string(f"{value}::decimal(38, {coltype.precision})") 50 | 51 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 52 | return self.to_string(f"{value}::int") 53 | 54 | 55 | class PostgresqlDialect(BaseDialect, Mixin_Schema): 56 | name = "PostgreSQL" 57 | ROUNDS_ON_PREC_LOSS = True 58 | SUPPORTS_PRIMARY_KEY = True 59 | SUPPORTS_INDEXES = True 60 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 61 | 62 | TYPE_CLASSES = { 63 | # Timestamps 64 | "timestamp with time zone": TimestampTZ, 65 | "timestamp without time zone": Timestamp, 66 | "timestamp": Timestamp, 67 | "date": Date, 68 | # Numbers 69 | "double precision": Float, 70 | "real": Float, 71 | "decimal": Decimal, 72 | "smallint": Integer, 73 | "integer": Integer, 74 | "numeric": Decimal, 75 | "bigint": Integer, 76 | # Text 77 | "character": Text, 78 | "character varying": Text, 79 | "varchar": Text, 80 | "text": Text, 81 | # UUID 82 | "uuid": Native_UUID, 83 | # Boolean 84 | "boolean": Boolean, 85 | } 86 | 87 | def quote(self, s: str): 88 | return f'"{s}"' 89 | 90 | def to_string(self, s: str): 91 | return f"{s}::varchar" 92 | 93 | def concat(self, items: List[str]) -> str: 94 | # Postgres does not allow functions that have more that have more than 100 arguments. 95 | # When using the concat function, this limits comparisons to less than 50 fields. 96 | # Using || for concat fixes this. 97 | joined_exprs = " || ".join(items) 98 | return f"({joined_exprs})" 99 | 100 | def _convert_db_precision_to_digits(self, p: int) -> int: 101 | # Subtracting 2 due to wierd precision issues in PostgreSQL 102 | return super()._convert_db_precision_to_digits(p) - 2 103 | 104 | def set_timezone_to_utc(self) -> str: 105 | return "SET TIME ZONE 'UTC'" 106 | 107 | def current_timestamp(self) -> str: 108 | return "current_timestamp" 109 | 110 | def type_repr(self, t) -> str: 111 | if isinstance(t, _Field): 112 | if t.options.auto: 113 | assert t.type is int 114 | return "SERIAL" 115 | else: 116 | t = t.type 117 | 118 | if isinstance(t, TimestampTZ): 119 | return f"timestamp ({t.precision}) with time zone" 120 | elif t in (dict, list): 121 | return "json" 122 | return super().type_repr(t) 123 | 124 | 125 | class PostgreSQL(ThreadedDatabase): 126 | dialect = PostgresqlDialect() 127 | SUPPORTS_UNIQUE_CONSTAINT = True 128 | CONNECT_URI_HELP = "postgresql://:@/" 129 | CONNECT_URI_PARAMS = ["database?"] 130 | 131 | default_schema = "public" 132 | 133 | def __init__(self, *, thread_count, **kw): 134 | self._args = kw 135 | 136 | super().__init__(thread_count=thread_count) 137 | 138 | def create_connection(self): 139 | if not self._args: 140 | self._args["host"] = None # psycopg2 requires 1+ arguments 141 | 142 | pg = import_postgresql() 143 | try: 144 | c = pg.connect(**self._args) 145 | if SESSION_TIME_ZONE: 146 | c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'") 147 | return c 148 | except pg.OperationalError as e: 149 | raise ConnectError(*e.args) from e 150 | 151 | def select_table_schema(self, path: DbPath) -> str: 152 | database, schema, table = self._normalize_table_path(path) 153 | 154 | info_schema_path = ["information_schema", "columns"] 155 | if database: 156 | info_schema_path.insert(0, database) 157 | 158 | return ( 159 | f"SELECT column_name, 
data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " 160 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 161 | ) 162 | 163 | def select_table_unique_columns(self, path: DbPath) -> str: 164 | database, schema, table = self._normalize_table_path(path) 165 | 166 | info_schema_path = ["information_schema", "key_column_usage"] 167 | if database: 168 | info_schema_path.insert(0, database) 169 | 170 | return ( 171 | "SELECT column_name " 172 | f"FROM {'.'.join(info_schema_path)} " 173 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 174 | ) 175 | 176 | def _normalize_table_path(self, path: DbPath) -> DbPath: 177 | if len(path) == 1: 178 | return None, self.default_schema, path[0] 179 | elif len(path) == 2: 180 | return None, path[0], path[1] 181 | elif len(path) == 3: 182 | return path 183 | 184 | raise ValueError( 185 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" 186 | ) 187 | -------------------------------------------------------------------------------- /sqeleton/databases/presto.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import re 3 | from typing import Optional 4 | 5 | from sqeleton.queries.ast_classes import ForeignKey 6 | 7 | from ..utils import match_regexps 8 | 9 | from ..abcs.database_types import ( 10 | Timestamp, 11 | TimestampTZ, 12 | Integer, 13 | Float, 14 | Text, 15 | FractionalType, 16 | DbPath, 17 | DbTime, 18 | Decimal, 19 | ColType, 20 | ColType_UUID, 21 | TemporalType, 22 | Boolean, 23 | Native_UUID, 24 | ) 25 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue 26 | from .base import BaseDialect, Database, QueryResult, import_helper, ThreadLocalInterpreter, Mixin_Schema, Mixin_RandomSample, SqlCode, logger 27 | from .base import ( 28 | MD5_HEXDIGITS, 29 | CHECKSUM_HEXDIGITS, 30 | TIMESTAMP_PRECISION_POS, 31 | ) 32 | from ..queries.compiler import CompiledCode 33 | 34 | 35 | def query_cursor(c, sql_code: CompiledCode) -> Optional[QueryResult]: 36 | logger.debug(f"[Presto] Executing SQL: {sql_code.code} || {sql_code.args}") 37 | c.execute(sql_code.code, sql_code.args) 38 | 39 | columns = c.description and [col[0] for col in c.description] 40 | return QueryResult(c.fetchall(), columns) 41 | 42 | 43 | @import_helper("presto") 44 | def import_presto(): 45 | import prestodb 46 | 47 | return prestodb 48 | 49 | 50 | class Mixin_MD5(AbstractMixin_MD5): 51 | def md5_as_int(self, s: str) -> str: 52 | return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" 53 | 54 | 55 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 56 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 57 | # Trim doesn't work on CHAR type 58 | return f"TRIM(CAST({value} AS VARCHAR))" 59 | 60 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 61 | # TODO rounds 62 | if coltype.rounds: 63 | s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" 64 | else: 65 | s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" 66 | 67 | return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" 68 | 69 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 70 | return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") 71 | 72 | def 
normalize_boolean(self, value: str, _coltype: Boolean) -> str: 73 | return self.to_string(f"cast ({value} as int)") 74 | 75 | 76 | class Dialect(BaseDialect, Mixin_Schema): 77 | name = "Presto" 78 | ROUNDS_ON_PREC_LOSS = True 79 | ARG_SYMBOL = None # Not implemented by Presto 80 | TYPE_CLASSES = { 81 | # Timestamps 82 | "timestamp with time zone": TimestampTZ, 83 | "timestamp without time zone": Timestamp, 84 | "timestamp": Timestamp, 85 | # Numbers 86 | "integer": Integer, 87 | "bigint": Integer, 88 | "real": Float, 89 | "double": Float, 90 | # Text 91 | "varchar": Text, 92 | # Boolean 93 | "boolean": Boolean, 94 | # UUID 95 | "uuid": Native_UUID, 96 | } 97 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 98 | 99 | def explain_as_text(self, query: str) -> str: 100 | return f"EXPLAIN (FORMAT TEXT) {query}" 101 | 102 | def timestamp_value(self, t: DbTime) -> str: 103 | return f"timestamp '{t.isoformat(' ')}'" 104 | 105 | def quote(self, s: str): 106 | return f'"{s}"' 107 | 108 | def to_string(self, s: str): 109 | return f"cast({s} as varchar)" 110 | 111 | def parse_type( 112 | self, 113 | table_path: DbPath, 114 | col_name: str, 115 | type_repr: str, 116 | datetime_precision: int = None, 117 | numeric_precision: int = None, 118 | _numeric_scale: int = None, 119 | ) -> ColType: 120 | timestamp_regexps = { 121 | r"timestamp\((\d)\)": Timestamp, 122 | r"timestamp\((\d)\) with time zone": TimestampTZ, 123 | } 124 | for m, t_cls in match_regexps(timestamp_regexps, type_repr): 125 | precision = int(m.group(1)) 126 | return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) 127 | 128 | number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} 129 | for m, n_cls in match_regexps(number_regexps, type_repr): 130 | _prec, scale = map(int, m.groups()) 131 | return n_cls(scale) 132 | 133 | string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text} 134 | for m, n_cls in match_regexps(string_regexps, type_repr): 135 | return n_cls() 136 | 137 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) 138 | 139 | def set_timezone_to_utc(self) -> str: 140 | # return "SET TIME ZONE '+00:00'" 141 | raise NotImplementedError("TODO") 142 | 143 | def current_timestamp(self) -> str: 144 | return "current_timestamp" 145 | 146 | def type_repr(self, t) -> str: 147 | if isinstance(t, TimestampTZ): 148 | return f"timestamp with time zone" 149 | elif isinstance(t, ForeignKey): 150 | return self.type_repr(t.type) 151 | try: 152 | return {float: "REAL"}[t] 153 | except KeyError: 154 | pass 155 | 156 | return super().type_repr(t) 157 | 158 | 159 | class Presto(Database): 160 | dialect = Dialect() 161 | CONNECT_URI_HELP = "presto://@//" 162 | CONNECT_URI_PARAMS = ["catalog", "schema"] 163 | 164 | default_schema = "public" 165 | 166 | def __init__(self, **kw): 167 | prestodb = import_presto() 168 | 169 | if kw.get("schema"): 170 | self.default_schema = kw.get("schema") 171 | 172 | if kw.get("auth") == "basic": # if auth=basic, add basic authenticator for Presto 173 | kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) 174 | 175 | if "cert" in kw: # if a certificate was specified in URI, verify session with cert 176 | cert = kw.pop("cert") 177 | self._conn = prestodb.dbapi.connect(**kw) 178 | self._conn._http_session.verify = cert 179 | else: 180 | self._conn = prestodb.dbapi.connect(**kw) 181 | 182 | def _query(self, sql_code: SqlCode) -> Optional[QueryResult]: 183 | "Uses the standard SQL cursor interface" 
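        # The dispatch below accepts three input shapes:
        #   - ThreadLocalInterpreter: a series of queries, applied one by one to this cursor
        #   - str: raw SQL with no bound arguments, wrapped into CompiledCode just below
        #   - CompiledCode: already-compiled SQL plus its argument list, passed to query_cursor()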
184 | c = self._conn.cursor() 185 | 186 | if isinstance(sql_code, ThreadLocalInterpreter): 187 | return sql_code.apply_queries(partial(query_cursor, c)) 188 | elif isinstance(sql_code, str): 189 | sql_code = CompiledCode(sql_code, [], None) # Unknown type. #TODO: Should we guess? 190 | 191 | return query_cursor(c, sql_code) 192 | 193 | def close(self): 194 | super().close() 195 | self._conn.close() 196 | 197 | def select_table_schema(self, path: DbPath) -> str: 198 | schema, table = self._normalize_table_path(path) 199 | 200 | return ( 201 | "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale " 202 | "FROM INFORMATION_SCHEMA.COLUMNS " 203 | f"WHERE table_name = '{table}' AND table_schema = '{schema}'" 204 | ) 205 | 206 | @property 207 | def is_autocommit(self) -> bool: 208 | return False 209 | -------------------------------------------------------------------------------- /sqeleton/databases/redshift.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath, TimestampTZ 3 | from ..abcs.mixins import AbstractMixin_MD5 4 | from .postgresql import ( 5 | PostgreSQL, 6 | MD5_HEXDIGITS, 7 | CHECKSUM_HEXDIGITS, 8 | TIMESTAMP_PRECISION_POS, 9 | PostgresqlDialect, 10 | Mixin_NormalizeValue, 11 | ) 12 | 13 | 14 | class Mixin_MD5(AbstractMixin_MD5): 15 | def md5_as_int(self, s: str) -> str: 16 | return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" 17 | 18 | 19 | class Mixin_NormalizeValue(Mixin_NormalizeValue): 20 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 21 | if coltype.rounds: 22 | timestamp = f"{value}::timestamp(6)" 23 | # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. 24 | secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" 25 | # Get the milliseconds from timestamp. 26 | ms = f"extract(ms from {timestamp})" 27 | # Get the microseconds from timestamp, without the milliseconds! 28 | us = f"extract(us from {timestamp})" 29 | # epoch = Total time since epoch in microseconds. 
30 | epoch = f"{secs}*1000000 + {ms}*1000 + {us}" 31 | timestamp6 = ( 32 | f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" 33 | ) 34 | else: 35 | timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" 36 | return ( 37 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" 38 | ) 39 | 40 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 41 | return self.to_string(f"{value}::decimal(38,{coltype.precision})") 42 | 43 | 44 | class Dialect(PostgresqlDialect): 45 | name = "Redshift" 46 | TYPE_CLASSES = { 47 | **PostgresqlDialect.TYPE_CLASSES, 48 | "double": Float, 49 | "real": Float, 50 | } 51 | SUPPORTS_INDEXES = False 52 | 53 | def concat(self, items: List[str]) -> str: 54 | joined_exprs = " || ".join(items) 55 | return f"({joined_exprs})" 56 | 57 | def is_distinct_from(self, a: str, b: str) -> str: 58 | return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})" 59 | 60 | def type_repr(self, t) -> str: 61 | if isinstance(t, TimestampTZ): 62 | return f"timestamptz" 63 | return super().type_repr(t) 64 | 65 | 66 | class Redshift(PostgreSQL): 67 | dialect = Dialect() 68 | CONNECT_URI_HELP = "redshift://:@/" 69 | CONNECT_URI_PARAMS = ["database?"] 70 | 71 | def select_table_schema(self, path: DbPath) -> str: 72 | database, schema, table = self._normalize_table_path(path) 73 | 74 | info_schema_path = ["information_schema", "columns"] 75 | if database: 76 | info_schema_path.insert(0, database) 77 | 78 | return ( 79 | f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} " 80 | f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" 81 | ) 82 | 83 | def select_external_table_schema(self, path: DbPath) -> str: 84 | database, schema, table = self._normalize_table_path(path) 85 | 86 | db_clause = "" 87 | if database: 88 | db_clause = f" AND redshift_database_name = '{database.lower()}'" 89 | 90 | return ( 91 | f"""SELECT 92 | columnname AS column_name 93 | , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type 94 | , NULL AS datetime_precision 95 | , NULL AS numeric_precision 96 | , NULL AS numeric_scale 97 | FROM svv_external_columns 98 | WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}' 99 | """ 100 | + db_clause 101 | ) 102 | 103 | def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]: 104 | rows = self.query(self.select_external_table_schema(path), list) 105 | if not rows: 106 | raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") 107 | 108 | d = {r[0]: r for r in rows} 109 | assert len(d) == len(rows) 110 | return d 111 | 112 | def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: 113 | try: 114 | return super().query_table_schema(path) 115 | except RuntimeError: 116 | return self.query_external_table_schema(path) 117 | 118 | def _normalize_table_path(self, path: DbPath) -> DbPath: 119 | if len(path) == 1: 120 | return None, self.default_schema, path[0] 121 | elif len(path) == 2: 122 | return None, path[0], path[1] 123 | elif len(path) == 3: 124 | return path 125 | 126 | raise ValueError( 127 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected format: table, schema.table, or database.schema.table" 128 | ) 129 | -------------------------------------------------------------------------------- /sqeleton/databases/snowflake.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import logging 3 | 4 | from ..abcs.database_types import ( 5 | Timestamp, 6 | TimestampTZ, 7 | Decimal, 8 | Float, 9 | Text, 10 | FractionalType, 11 | TemporalType, 12 | DbPath, 13 | Boolean, 14 | Date, 15 | ) 16 | from ..abcs.mixins import ( 17 | AbstractMixin_MD5, 18 | AbstractMixin_NormalizeValue, 19 | AbstractMixin_Schema, 20 | AbstractMixin_TimeTravel, 21 | ) 22 | from ..abcs import Compilable 23 | from sqeleton.queries import table, this, SKIP, code 24 | from .base import ( 25 | BaseDialect, 26 | ConnectError, 27 | Database, 28 | import_helper, 29 | CHECKSUM_MASK, 30 | ThreadLocalInterpreter, 31 | Mixin_RandomSample, 32 | ) 33 | 34 | 35 | @import_helper("snowflake") 36 | def import_snowflake(): 37 | import snowflake.connector 38 | from cryptography.hazmat.primitives import serialization 39 | from cryptography.hazmat.backends import default_backend 40 | 41 | return snowflake, serialization, default_backend 42 | 43 | 44 | class Mixin_MD5(AbstractMixin_MD5): 45 | def md5_as_int(self, s: str) -> str: 46 | return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" 47 | 48 | 49 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 50 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 51 | if coltype.rounds: 52 | timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" 53 | else: 54 | timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" 55 | 56 | return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" 57 | 58 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 59 | return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") 60 | 61 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 62 | return self.to_string(f"{value}::int") 63 | 64 | 65 | class Mixin_Schema(AbstractMixin_Schema): 66 | def table_information(self) -> Compilable: 67 | return table("INFORMATION_SCHEMA", "TABLES") 68 | 69 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 70 | return ( 71 | self.table_information() 72 | .where( 73 | this.TABLE_SCHEMA == table_schema, 74 | this.TABLE_NAME.like(like) if like is not None else SKIP, 75 | this.TABLE_TYPE == "BASE TABLE", 76 | ) 77 | .select(table_name=this.TABLE_NAME) 78 | ) 79 | 80 | 81 | class Mixin_TimeTravel(AbstractMixin_TimeTravel): 82 | def time_travel( 83 | self, 84 | table: Compilable, 85 | before: bool = False, 86 | timestamp: Compilable = None, 87 | offset: Compilable = None, 88 | statement: Compilable = None, 89 | ) -> Compilable: 90 | at_or_before = "AT" if before else "BEFORE" 91 | if timestamp is not None: 92 | assert offset is None and statement is None 93 | key = "timestamp" 94 | value = timestamp 95 | elif offset is not None: 96 | assert statement is None 97 | key = "offset" 98 | value = offset 99 | else: 100 | assert statement is not None 101 | key = "statement" 102 | value = statement 103 | 104 | return code(f"{{table}} {at_or_before}({key} => {{value}})", table=table, value=value) 105 | 106 | 107 | class Dialect(BaseDialect, Mixin_Schema): 108 | name = "Snowflake" 109 | ROUNDS_ON_PREC_LOSS = False 110 | TYPE_CLASSES 
= { 111 | # Timestamps 112 | "TIMESTAMP_NTZ": Timestamp, 113 | "TIMESTAMP_LTZ": Timestamp, 114 | "TIMESTAMP_TZ": TimestampTZ, 115 | "DATE": Date, 116 | # Numbers 117 | "NUMBER": Decimal, 118 | "FLOAT": Float, 119 | # Text 120 | "TEXT": Text, 121 | # Boolean 122 | "BOOLEAN": Boolean, 123 | } 124 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample} 125 | 126 | def explain_as_text(self, query: str) -> str: 127 | return f"EXPLAIN USING TEXT {query}" 128 | 129 | def quote(self, s: str): 130 | return f'"{s}"' 131 | 132 | def to_string(self, s: str): 133 | return f"cast({s} as string)" 134 | 135 | def table_information(self) -> Compilable: 136 | return table("INFORMATION_SCHEMA", "TABLES") 137 | 138 | def set_timezone_to_utc(self) -> str: 139 | return "ALTER SESSION SET TIMEZONE = 'UTC'" 140 | 141 | def optimizer_hints(self, hints: str) -> str: 142 | raise NotImplementedError("Optimizer hints not yet implemented in snowflake") 143 | 144 | def type_repr(self, t) -> str: 145 | if isinstance(t, TimestampTZ): 146 | return f"timestamp_tz({t.precision})" 147 | return super().type_repr(t) 148 | 149 | 150 | class Snowflake(Database): 151 | dialect = Dialect() 152 | CONNECT_URI_HELP = "snowflake://:@//?warehouse=" 153 | CONNECT_URI_PARAMS = ["database", "schema"] 154 | CONNECT_URI_KWPARAMS = ["warehouse"] 155 | 156 | def __init__(self, *, schema: str, **kw): 157 | snowflake, serialization, default_backend = import_snowflake() 158 | logging.getLogger("snowflake.connector").setLevel(logging.WARNING) 159 | 160 | # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state 161 | # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145 162 | logging.getLogger("snowflake.connector.network").disabled = True 163 | 164 | assert '"' not in schema, "Schema name should not contain quotes!" 165 | # If a private key is used, read it from the specified path and pass it as "private_key" to the connector. 
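        # The branch below implements key-pair authentication: the PEM key file named by
        # kw["key"] is read (decrypted with kw["private_key_passphrase"] when given, and
        # mutually exclusive with a password), then re-encoded as unencrypted DER/PKCS8
        # before being handed to the connector as "private_key".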
166 | if "key" in kw: 167 | with open(kw.get("key"), "rb") as key: 168 | if "password" in kw: 169 | raise ConnectError("Cannot use password and key at the same time") 170 | if kw.get("private_key_passphrase"): 171 | encoded_passphrase = kw.get("private_key_passphrase").encode() 172 | else: 173 | encoded_passphrase = None 174 | p_key = serialization.load_pem_private_key( 175 | key.read(), 176 | password=encoded_passphrase, 177 | backend=default_backend(), 178 | ) 179 | 180 | kw["private_key"] = p_key.private_bytes( 181 | encoding=serialization.Encoding.DER, 182 | format=serialization.PrivateFormat.PKCS8, 183 | encryption_algorithm=serialization.NoEncryption(), 184 | ) 185 | 186 | self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw) 187 | self._conn.telemetry_enabled = False 188 | self.default_schema = schema 189 | 190 | def close(self): 191 | super().close() 192 | self._conn.close() 193 | 194 | def _query(self, sql_code: Union[str, ThreadLocalInterpreter]): 195 | "Uses the standard SQL cursor interface" 196 | return self._query_conn(self._conn, sql_code) 197 | 198 | def select_table_schema(self, path: DbPath) -> str: 199 | """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" 200 | database, schema, name = self._normalize_table_path(path) 201 | info_schema_path = ["information_schema", "columns"] 202 | if database: 203 | info_schema_path.insert(0, database) 204 | 205 | return ( 206 | "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " 207 | f"FROM {'.'.join(info_schema_path)} " 208 | f"WHERE table_name = '{name}' AND table_schema = '{schema}'" 209 | ) 210 | 211 | def _normalize_table_path(self, path: DbPath) -> DbPath: 212 | if len(path) == 1: 213 | return None, self.default_schema, path[0] 214 | elif len(path) == 2: 215 | return None, path[0], path[1] 216 | elif len(path) == 3: 217 | return path 218 | 219 | raise ValueError( 220 | f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" 221 | ) 222 | 223 | @property 224 | def is_autocommit(self) -> bool: 225 | return True 226 | 227 | def query_table_unique_columns(self, path: DbPath) -> List[str]: 228 | return [] 229 | -------------------------------------------------------------------------------- /sqeleton/databases/trino.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from ..abcs.database_types import TemporalType, ColType_UUID, String_UUID 4 | from . 
import presto 5 | from .base import import_helper 6 | from .base import TIMESTAMP_PRECISION_POS 7 | 8 | 9 | @import_helper("trino") 10 | def import_trino(): 11 | import trino 12 | 13 | return trino 14 | 15 | 16 | Mixin_MD5 = presto.Mixin_MD5 17 | 18 | 19 | class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): 20 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 21 | if coltype.rounds: 22 | s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')" 23 | else: 24 | s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" 25 | 26 | return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" 27 | 28 | def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: 29 | if isinstance(coltype, String_UUID): 30 | return f"TRIM({value})" 31 | return f"CAST({value} AS VARCHAR)" 32 | 33 | 34 | class Dialect(presto.Dialect): 35 | name = "Trino" 36 | ARG_SYMBOL = "?" 37 | 38 | def set_timezone_to_utc(self) -> str: 39 | return "SET TIME ZONE '+00:00'" 40 | 41 | def uuid_value(self, u: uuid.UUID) -> str: 42 | return f"CAST('{u}' AS UUID)" 43 | 44 | 45 | class Trino(presto.Presto): 46 | dialect = Dialect() 47 | CONNECT_URI_HELP = "trino://@//" 48 | CONNECT_URI_PARAMS = ["catalog", "schema"] 49 | 50 | def __init__(self, **kw): 51 | 52 | trino = import_trino() 53 | 54 | if kw.get("schema"): 55 | self.default_schema = kw.get("schema") 56 | 57 | if kw.get("password"): 58 | kw["auth"] = trino.auth.BasicAuthentication( 59 | kw.pop("user"), kw.pop("password") 60 | ) 61 | kw["http_scheme"] = "https" 62 | 63 | cert = kw.pop("cert", None) 64 | self._conn = trino.dbapi.connect(**kw) 65 | if cert is not None: 66 | self._conn._http_session.verify = cert 67 | 68 | 69 | @property 70 | def is_autocommit(self) -> bool: 71 | return True -------------------------------------------------------------------------------- /sqeleton/databases/vertica.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils import match_regexps 4 | from .base import ( 5 | CHECKSUM_HEXDIGITS, 6 | MD5_HEXDIGITS, 7 | TIMESTAMP_PRECISION_POS, 8 | BaseDialect, 9 | ConnectError, 10 | DbPath, 11 | ColType, 12 | ThreadedDatabase, 13 | import_helper, 14 | Mixin_RandomSample, 15 | ) 16 | from ..abcs.database_types import ( 17 | Decimal, 18 | Float, 19 | FractionalType, 20 | Integer, 21 | TemporalType, 22 | Text, 23 | Timestamp, 24 | TimestampTZ, 25 | Boolean, 26 | ColType_UUID, 27 | ) 28 | from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema 29 | from ..abcs import Compilable 30 | from ..queries import table, this, SKIP 31 | 32 | 33 | @import_helper("vertica") 34 | def import_vertica(): 35 | import vertica_python 36 | 37 | return vertica_python 38 | 39 | 40 | class Mixin_MD5(AbstractMixin_MD5): 41 | def md5_as_int(self, s: str) -> str: 42 | return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" 43 | 44 | 45 | class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): 46 | def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: 47 | if coltype.rounds: 48 | return f"TO_CHAR({value}::TIMESTAMP({coltype.precision}), 'YYYY-MM-DD HH24:MI:SS.US')" 49 | 50 | timestamp6 = f"TO_CHAR({value}::TIMESTAMP(6), 'YYYY-MM-DD HH24:MI:SS.US')" 51 | return ( 52 | f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), 
{TIMESTAMP_PRECISION_POS+6}, '0')" 53 | ) 54 | 55 | def normalize_number(self, value: str, coltype: FractionalType) -> str: 56 | return self.to_string(f"CAST({value} AS NUMERIC(38, {coltype.precision}))") 57 | 58 | def normalize_uuid(self, value: str, _coltype: ColType_UUID) -> str: 59 | # Trim doesn't work on CHAR type 60 | return f"TRIM(CAST({value} AS VARCHAR))" 61 | 62 | def normalize_boolean(self, value: str, _coltype: Boolean) -> str: 63 | return self.to_string(f"cast ({value} as int)") 64 | 65 | 66 | class Mixin_Schema(AbstractMixin_Schema): 67 | def table_information(self) -> Compilable: 68 | return table("v_catalog", "tables") 69 | 70 | def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: 71 | return ( 72 | self.table_information() 73 | .where( 74 | this.table_schema == table_schema, 75 | this.table_name.like(like) if like is not None else SKIP, 76 | ) 77 | .select(this.table_name) 78 | ) 79 | 80 | 81 | class Dialect(BaseDialect, Mixin_Schema): 82 | name = "Vertica" 83 | ROUNDS_ON_PREC_LOSS = True 84 | 85 | TYPE_CLASSES = { 86 | # Timestamps 87 | "timestamp": Timestamp, 88 | "timestamptz": TimestampTZ, 89 | # Numbers 90 | "numeric": Decimal, 91 | "int": Integer, 92 | "float": Float, 93 | # Text 94 | "char": Text, 95 | "varchar": Text, 96 | # Boolean 97 | "boolean": Boolean, 98 | } 99 | MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} 100 | 101 | def quote(self, s: str): 102 | return f'"{s}"' 103 | 104 | def concat(self, items: List[str]) -> str: 105 | return " || ".join(items) 106 | 107 | def to_string(self, s: str) -> str: 108 | return f"CAST({s} AS VARCHAR)" 109 | 110 | def is_distinct_from(self, a: str, b: str) -> str: 111 | return f"not ({a} <=> {b})" 112 | 113 | def parse_type( 114 | self, 115 | table_path: DbPath, 116 | col_name: str, 117 | type_repr: str, 118 | datetime_precision: int = None, 119 | numeric_precision: int = None, 120 | numeric_scale: int = None, 121 | ) -> ColType: 122 | timestamp_regexps = { 123 | r"timestamp\(?(\d?)\)?": Timestamp, 124 | r"timestamptz\(?(\d?)\)?": TimestampTZ, 125 | } 126 | for m, t_cls in match_regexps(timestamp_regexps, type_repr): 127 | precision = int(m.group(1)) if m.group(1) else 6 128 | return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS) 129 | 130 | number_regexps = { 131 | r"numeric\((\d+),(\d+)\)": Decimal, 132 | } 133 | for m, n_cls in match_regexps(number_regexps, type_repr): 134 | _prec, scale = map(int, m.groups()) 135 | return n_cls(scale) 136 | 137 | string_regexps = { 138 | r"varchar\((\d+)\)": Text, 139 | r"char\((\d+)\)": Text, 140 | } 141 | for m, n_cls in match_regexps(string_regexps, type_repr): 142 | return n_cls() 143 | 144 | return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) 145 | 146 | def set_timezone_to_utc(self) -> str: 147 | return "SET TIME ZONE TO 'UTC'" 148 | 149 | def current_timestamp(self) -> str: 150 | return "current_timestamp(6)" 151 | 152 | 153 | class Vertica(ThreadedDatabase): 154 | dialect = Dialect() 155 | CONNECT_URI_HELP = "vertica://:@/" 156 | CONNECT_URI_PARAMS = ["database?"] 157 | 158 | default_schema = "public" 159 | 160 | def __init__(self, *, thread_count, **kw): 161 | self._args = kw 162 | self._args["AUTOCOMMIT"] = False 163 | 164 | super().__init__(thread_count=thread_count) 165 | 166 | def create_connection(self): 167 | vertica = import_vertica() 168 | try: 169 | c = vertica.connect(**self._args) 170 | return c 171 | except vertica.errors.ConnectionError as e: 172 | 
raise ConnectError(*e.args) from e 173 | 174 | def select_table_schema(self, path: DbPath) -> str: 175 | schema, name = self._normalize_table_path(path) 176 | 177 | return ( 178 | "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " 179 | "FROM V_CATALOG.COLUMNS " 180 | f"WHERE table_name = '{name}' AND table_schema = '{schema}'" 181 | ) 182 | -------------------------------------------------------------------------------- /sqeleton/queries/__init__.py: -------------------------------------------------------------------------------- 1 | from .compiler import CompileError, Compiler 2 | from .base import SKIP, T_SKIP 3 | from .api import ( 4 | this, 5 | join, 6 | outerjoin, 7 | table, 8 | sum_, 9 | avg, 10 | min_, 11 | max_, 12 | cte, 13 | commit, 14 | when, 15 | coalesce, 16 | and_, 17 | if_, 18 | or_, 19 | leftjoin, 20 | rightjoin, 21 | current_timestamp, 22 | code, 23 | ) 24 | from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column, ITable, ForeignKey 25 | from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString 26 | -------------------------------------------------------------------------------- /sqeleton/queries/api.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .ast_classes import * 4 | from .base import args_as_tuple 5 | from ..schema import SchemaInput, _Schema 6 | 7 | 8 | this = This() 9 | 10 | 11 | def join(*tables: ITable) -> Join: 12 | """Inner-join a sequence of table expressions" 13 | 14 | When joining, it's recommended to use explicit tables names, instead of `this`, in order to avoid potential name collisions. 15 | 16 | Example: 17 | :: 18 | 19 | person = table('person') 20 | city = table('city') 21 | 22 | name_and_city = ( 23 | join(person, city) 24 | .on(person['city_id'] == city['id']) 25 | .select(person['id'], city['name']) 26 | ) 27 | """ 28 | return Join(tables) 29 | 30 | 31 | def leftjoin(*tables: ITable) -> Join: 32 | """Left-joins a sequence of table expressions. 33 | 34 | See Also: ``join()`` 35 | """ 36 | return Join(tables, "LEFT") 37 | 38 | 39 | def rightjoin(*tables: ITable): 40 | """Right-joins a sequence of table expressions. 41 | 42 | See Also: ``join()`` 43 | """ 44 | return Join(tables, "RIGHT") 45 | 46 | 47 | def outerjoin(*tables: ITable): 48 | """Outer-joins a sequence of table expressions. 49 | 50 | See Also: ``join()`` 51 | """ 52 | return Join(tables, "FULL OUTER") 53 | 54 | 55 | def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None): 56 | """Define a CTE""" 57 | return Cte(expr, name, params) 58 | 59 | 60 | def table(*path: str, schema: SchemaInput = None) -> TablePath: 61 | """Defines a table with a path (dotted name), and optionally a schema. 62 | 63 | Parameters: 64 | path: A list of names that make up the path to the table. 65 | schema: a dictionary of {name: type} 66 | """ 67 | if len(path) == 1 and isinstance(path[0], tuple): 68 | (path,) = path 69 | if len(path) == 1 and isinstance(path[0], TablePath): 70 | path = path[0].path 71 | if not all(isinstance(i, str) for i in path): 72 | raise TypeError(f"All elements of table path must be of type 'str'. 
Got: {path}") 73 | 74 | if schema: 75 | schema = _Schema.make(schema) 76 | return TablePath(path, schema) 77 | 78 | 79 | def table_from_sqlmodel(cls: type) -> TablePath: 80 | import sqlmodel 81 | 82 | assert issubclass(cls, sqlmodel.SQLModel) 83 | table_name = cls.__name__.lower() 84 | schema = cls.__annotations__ 85 | return table(table_name, schema=schema) 86 | 87 | 88 | def or_(*exprs: Expr): 89 | """Apply OR between a sequence of boolean expressions""" 90 | exprs = args_as_tuple(exprs) 91 | if len(exprs) == 1: 92 | return exprs[0] 93 | return BinBoolOp("OR", exprs) 94 | 95 | 96 | def and_(*exprs: Expr): 97 | """Apply AND between a sequence of boolean expressions""" 98 | exprs = args_as_tuple(exprs) 99 | if len(exprs) == 1: 100 | return exprs[0] 101 | return BinBoolOp("AND", exprs) 102 | 103 | 104 | def sum_(expr: Expr): 105 | """Call SUM(expr)""" 106 | return Func("sum", [expr]) 107 | 108 | 109 | def avg(expr: Expr): 110 | """Call AVG(expr)""" 111 | return Func("avg", [expr]) 112 | 113 | 114 | def min_(expr: Expr): 115 | """Call MIN(expr)""" 116 | return Func("min", [expr]) 117 | 118 | 119 | def max_(expr: Expr): 120 | """Call MAX(expr)""" 121 | return Func("max", [expr]) 122 | 123 | 124 | def exists(expr: Expr): 125 | """Call EXISTS(expr)""" 126 | return Exists(expr) 127 | 128 | 129 | def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None): 130 | """Conditional expression, shortcut to when-then-else. 131 | 132 | Example: 133 | :: 134 | 135 | # SELECT CASE WHEN b THEN c ELSE d END FROM foo 136 | table('foo').select(if_(b, c, d)) 137 | """ 138 | return when(cond).then(then).else_(else_) 139 | 140 | 141 | def when(*when_exprs: Expr): 142 | """Start a when-then expression 143 | 144 | Example: 145 | :: 146 | 147 | # SELECT CASE 148 | # WHEN (type = 'text') THEN text 149 | # WHEN (type = 'number') THEN number 150 | # ELSE 'unknown type' END 151 | # FROM foo 152 | rows = table('foo').select( 153 | when(this.type == 'text').then(this.text) 154 | .when(this.type == 'number').then(this.number) 155 | .else_('unknown type') 156 | ) 157 | """ 158 | return CaseWhen([]).when(*when_exprs) 159 | 160 | 161 | def coalesce(*exprs): 162 | "Returns a call to COALESCE" 163 | exprs = args_as_tuple(exprs) 164 | return Func("COALESCE", exprs) 165 | 166 | 167 | def insert_rows_in_batches(db, tbl: TablePath, rows, *, columns=None, batch_size=1024 * 8): 168 | assert batch_size > 0 169 | rows = list(rows) 170 | 171 | while rows: 172 | batch, rows = rows[:batch_size], rows[batch_size:] 173 | db.query(tbl.insert_rows(batch, columns=columns)) 174 | 175 | 176 | def current_timestamp(): 177 | """Returns CURRENT_TIMESTAMP() or NOW()""" 178 | return CurrentTimestamp() 179 | 180 | 181 | def code(code: str, **kw: Dict[str, Expr]) -> Code: 182 | """Inline raw SQL code. 183 | 184 | It allows users to use features and syntax that Sqeleton doesn't yet support. 185 | 186 | It's the user's responsibility to make sure the contents of the string given to `code()` are correct and safe for execution. 187 | 188 | Strings given to `code()` are actually templates, and can embed query expressions given as arguments: 189 | 190 | Parameters: 191 | code: template string of SQL code. Templated variables are signified with '{var}'. 192 | kw: optional parameters for SQL template. 
193 | 194 | Examples: 195 | :: 196 | 197 | # SELECT b, FROM tmp WHERE 198 | table('tmp').select(this.b, code("")).where(code("")) 199 | 200 | :: 201 | 202 | def tablesample(tbl, size): 203 | return code("SELECT * FROM {tbl} TABLESAMPLE BERNOULLI ({size})", tbl=tbl, size=size) 204 | 205 | nonzero = table('points').where(this.x > 0, this.y > 0) 206 | 207 | # SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10) 208 | sample_expr = tablesample(nonzero) 209 | """ 210 | return Code(code, kw) 211 | 212 | 213 | commit = Commit() 214 | -------------------------------------------------------------------------------- /sqeleton/queries/base.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | from ..abcs import DbPath, DbKey 4 | from ..schema import Schema 5 | 6 | 7 | class T_SKIP: 8 | def __repr__(self): 9 | return "SKIP" 10 | 11 | def __deepcopy__(self, memo): 12 | return self 13 | 14 | 15 | SKIP = T_SKIP() 16 | 17 | 18 | class SqeletonError(Exception): 19 | pass 20 | 21 | 22 | def args_as_tuple(exprs): 23 | if len(exprs) == 1: 24 | (e,) = exprs 25 | if isinstance(e, Generator): 26 | return tuple(e) 27 | return exprs 28 | -------------------------------------------------------------------------------- /sqeleton/queries/extras.py: -------------------------------------------------------------------------------- 1 | "Useful AST classes that don't quite fall within the scope of regular SQL" 2 | 3 | from typing import Callable, Sequence 4 | from runtype import dataclass 5 | 6 | from ..abcs.database_types import ColType, Native_UUID 7 | 8 | from .compiler import Compiler, md 9 | from .ast_classes import Expr, ExprNode, Concat, Code 10 | 11 | 12 | 13 | 14 | @dataclass 15 | class NormalizeAsString(ExprNode): 16 | expr: ExprNode 17 | expr_type: ColType = None 18 | type = str 19 | 20 | 21 | @dataclass 22 | class ApplyFuncAndNormalizeAsString(ExprNode): 23 | expr: ExprNode 24 | apply_func: Callable = None 25 | 26 | 27 | @dataclass 28 | class Checksum(ExprNode): 29 | exprs: Sequence[Expr] 30 | 31 | 32 | class Compiler(Compiler): 33 | @md 34 | def compile_node(c: Compiler, n: NormalizeAsString) -> str: 35 | expr = c.compile(n.expr) 36 | return c.dialect.normalize_value_by_type(expr, n.expr_type or n.expr.type) 37 | 38 | 39 | @md 40 | def compile_node(c: Compiler, n: ApplyFuncAndNormalizeAsString) -> str: 41 | expr = n.expr 42 | expr_type = expr.type 43 | 44 | if isinstance(expr_type, Native_UUID): 45 | # Normalize first, apply template after (for uuids) 46 | # Needed because min/max(uuid) fails in postgresql 47 | expr = NormalizeAsString(expr, expr_type) 48 | if n.apply_func is not None: 49 | expr = n.apply_func(expr) # Apply template using Python's string formatting 50 | 51 | else: 52 | # Apply template before normalizing (for ints) 53 | if n.apply_func is not None: 54 | expr = n.apply_func(expr) # Apply template using Python's string formatting 55 | expr = NormalizeAsString(expr, expr_type) 56 | 57 | return c.compile(expr) 58 | 59 | 60 | @md 61 | def compile_node(c: Compiler, n: Checksum) -> str: 62 | if len(n.exprs) > 1: 63 | exprs = [Code(f"coalesce({c.compile(expr)}, '')") for expr in n.exprs] 64 | # exprs = [c.compile(e) for e in exprs] 65 | expr = Concat(exprs, "|") 66 | else: 67 | # No need to coalesce - safe to assume that key cannot be null 68 | (expr,) = n.exprs 69 | expr = c.compile(expr) 70 | md5 = c.dialect.md5_as_int(expr) 71 | return f"sum({md5})" 72 | 
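# A minimal usage sketch of Checksum (hypothetical table/column names; assumes a
# connected `db` whose dialect implements md5_as_int):
#
#     from sqeleton.queries import table, this, Checksum
#
#     items = table("items")
#     q = items.select(Checksum([this.id, this.name]))
#     # db.query(q)  # one row: the sum of per-row MD5 checksums over id|name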
-------------------------------------------------------------------------------- /sqeleton/query_utils.py: -------------------------------------------------------------------------------- 1 | "Module for query utilities that didn't make it into the query-builder (yet)" 2 | 3 | from contextlib import suppress 4 | 5 | from sqeleton.databases import DbPath, QueryError, Oracle 6 | from sqeleton.queries import table, commit, Expr 7 | 8 | 9 | def _drop_table_oracle(name: DbPath): 10 | t = table(name) 11 | # Experience shows double drop is necessary 12 | with suppress(QueryError): 13 | yield t.drop() 14 | yield t.drop() 15 | yield commit 16 | 17 | 18 | def _drop_table(name: DbPath): 19 | t = table(name) 20 | yield t.drop(if_exists=True) 21 | yield commit 22 | 23 | 24 | def drop_table(db, tbl): 25 | if isinstance(db, Oracle): 26 | db.query(_drop_table_oracle(tbl)) 27 | else: 28 | db.query(_drop_table(tbl)) 29 | 30 | 31 | def _append_to_table_oracle(path: DbPath, expr: Expr): 32 | """See append_to_table""" 33 | assert expr.schema, expr 34 | t = table(path, schema=expr.schema) 35 | with suppress(QueryError): 36 | yield t.create() # uses expr.schema 37 | yield commit 38 | yield t.insert_expr(expr) 39 | yield commit 40 | 41 | 42 | def _append_to_table(path: DbPath, expr: Expr): 43 | """Append to table""" 44 | assert expr.schema, expr 45 | t = table(path, schema=expr.schema) 46 | yield t.create(if_not_exists=True) # uses expr.schema 47 | yield commit 48 | yield t.insert_expr(expr) 49 | yield commit 50 | 51 | 52 | def append_to_table(db, path, expr): 53 | f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table 54 | db.query(f(path, expr)) 55 | -------------------------------------------------------------------------------- /sqeleton/repl.py: -------------------------------------------------------------------------------- 1 | import rich.table 2 | import logging 3 | 4 | from pathlib import Path 5 | from time import time 6 | 7 | ### XXX Fix for Python 3.8 bug (https://github.com/prompt-toolkit/python-prompt-toolkit/issues/1023) 8 | import asyncio 9 | import selectors 10 | 11 | selector = selectors.SelectSelector() 12 | loop = asyncio.SelectorEventLoop(selector) 13 | asyncio.set_event_loop(loop) 14 | ### XXX End of fix 15 | 16 | from pygments.lexers.sql import SqlLexer 17 | from pygments.styles import get_style_by_name 18 | from prompt_toolkit import PromptSession 19 | from prompt_toolkit.lexers import PygmentsLexer 20 | from prompt_toolkit.filters import Condition 21 | from prompt_toolkit.application.current import get_app 22 | from prompt_toolkit.history import FileHistory 23 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory 24 | from prompt_toolkit.output.color_depth import ColorDepth 25 | from prompt_toolkit.completion import WordCompleter 26 | from prompt_toolkit.styles.pygments import style_from_pygments_cls 27 | 28 | from . 
import __version__ 29 | 30 | STYLE = style_from_pygments_cls(get_style_by_name("dracula")) 31 | 32 | 33 | sql_keywords = [ 34 | "abort", 35 | "action", 36 | "add", 37 | "after", 38 | "all", 39 | "alter", 40 | "analyze", 41 | "and", 42 | "as", 43 | "asc", 44 | "attach", 45 | "autoincrement", 46 | "before", 47 | "begin", 48 | "between", 49 | "by", 50 | "cascade", 51 | "case", 52 | "cast", 53 | "check", 54 | "collate", 55 | "column", 56 | "commit", 57 | "conflict", 58 | "constraint", 59 | "create", 60 | "cross", 61 | "current_date", 62 | "current_time", 63 | "current_timestamp", 64 | "database", 65 | "default", 66 | "deferrable", 67 | "deferred", 68 | "delete", 69 | "desc", 70 | "detach", 71 | "distinct", 72 | "drop", 73 | "each", 74 | "else", 75 | "end", 76 | "escape", 77 | "except", 78 | "exclusive", 79 | "exists", 80 | "explain", 81 | "fail", 82 | "for", 83 | "foreign", 84 | "from", 85 | "full", 86 | "glob", 87 | "group", 88 | "having", 89 | "if", 90 | "ignore", 91 | "immediate", 92 | "in", 93 | "index", 94 | "indexed", 95 | "initially", 96 | "inner", 97 | "insert", 98 | "instead", 99 | "intersect", 100 | "into", 101 | "is", 102 | "isnull", 103 | "join", 104 | "key", 105 | "left", 106 | "like", 107 | "limit", 108 | "match", 109 | "natural", 110 | "no", 111 | "not", 112 | "notnull", 113 | "null", 114 | "of", 115 | "offset", 116 | "on", 117 | "or", 118 | "order", 119 | "outer", 120 | "plan", 121 | "pragma", 122 | "primary", 123 | "query", 124 | "raise", 125 | "recursive", 126 | "references", 127 | "regexp", 128 | "reindex", 129 | "release", 130 | "rename", 131 | "replace", 132 | "restrict", 133 | "right", 134 | "rollback", 135 | "row", 136 | "savepoint", 137 | "select", 138 | "set", 139 | "table", 140 | "temp", 141 | "temporary", 142 | "then", 143 | "to", 144 | "transaction", 145 | "trigger", 146 | "union", 147 | "unique", 148 | "update", 149 | "using", 150 | "vacuum", 151 | "values", 152 | "view", 153 | "virtual", 154 | "when", 155 | "where", 156 | "with", 157 | "without", 158 | ] 159 | sql_completer = WordCompleter(sql_keywords, ignore_case=True) 160 | 161 | 162 | def add_keywords(new_keywords): 163 | global sql_keywords 164 | new = set(new_keywords) - set(sql_keywords) 165 | sql_keywords += new 166 | 167 | 168 | def _code_is_valid(code: str): 169 | if code: 170 | if code[0].isalnum(): 171 | return code[-1] in ";\n" 172 | return True 173 | 174 | 175 | def repl(uri, prompt=" >> "): 176 | rich.print(f"[purple]Sqeleton {__version__} interactive prompt. Enter '?' 
for help[/purple]") 177 | 178 | db = connect(uri) 179 | db_name = db.name 180 | 181 | try: 182 | session = PromptSession( 183 | lexer=PygmentsLexer(SqlLexer), 184 | completer=sql_completer, 185 | # key_bindings=kb 186 | history=FileHistory(str(Path.home() / ".sqeleton_repl_history")), 187 | auto_suggest=AutoSuggestFromHistory(), 188 | color_depth=ColorDepth.TRUE_COLOR, 189 | style=STYLE, 190 | ) 191 | 192 | @Condition 193 | def multiline_filter(): 194 | text = get_app().layout.get_buffer_by_name("DEFAULT_BUFFER").text 195 | return not _code_is_valid(text) 196 | 197 | prompt = f"{db_name}> " 198 | while True: 199 | # Read 200 | try: 201 | q = session.prompt(prompt, multiline=multiline_filter) 202 | if not q.strip(): 203 | continue 204 | 205 | start_time = time() 206 | run_command(db, q) 207 | 208 | duration = time() - start_time 209 | if duration > 1: 210 | rich.print("(Query took %.2f seconds)" % duration) 211 | 212 | except KeyboardInterrupt: 213 | rich.print("Interrupted (Ctrl+C)") 214 | 215 | except (KeyboardInterrupt, EOFError): 216 | rich.print("Exiting Sqeleton interaction") 217 | 218 | 219 | from . import connect 220 | 221 | import sys 222 | 223 | 224 | def print_table(rows, schema, table_name=""): 225 | # Print rows in a rich table 226 | t = rich.table.Table(title=table_name, caption=f"{len(rows)} rows") 227 | for col in schema: 228 | t.add_column(col) 229 | for r in rows: 230 | t.add_row(*map(str, r)) 231 | rich.print(t) 232 | 233 | 234 | def help(): 235 | rich.print("Commands:") 236 | rich.print(" ?mytable - shows schema of table 'mytable'") 237 | rich.print(" * - shows list of all tables") 238 | rich.print(" *pattern - shows list of all tables with name like pattern") 239 | rich.print("Otherwise, runs regular SQL query") 240 | 241 | 242 | def run_command(db, q): 243 | if q.startswith("*"): 244 | pattern = q[1:] 245 | try: 246 | names = db.query(db.dialect.list_tables(db.default_schema, like=f"%{pattern}%" if pattern else None)) 247 | except Exception as e: 248 | logging.exception(e) 249 | else: 250 | print_table(names, ["name"], "List of tables") 251 | add_keywords([".".join(n) for n in names]) 252 | elif q.startswith("?"): 253 | table_name = q[1:] 254 | if not table_name: 255 | help() 256 | return 257 | try: 258 | path = db.parse_table_name(table_name) 259 | print("->", path) 260 | schema = db.query_table_schema(path) 261 | except Exception as e: 262 | logging.error(e) 263 | else: 264 | print_table([(k, v[1]) for k, v in schema.items()], ["name", "type"], f"Table '{table_name}'") 265 | add_keywords(schema.keys()) 266 | else: 267 | # Normal SQL query 268 | try: 269 | res = db.query(q) 270 | except Exception as e: 271 | logging.error(e) 272 | else: 273 | if res: 274 | print_table(res.rows, res.columns, None) 275 | add_keywords(res.columns) 276 | 277 | 278 | def main(): 279 | uri = sys.argv[1] 280 | return repl(uri) 281 | 282 | 283 | if __name__ == "__main__": 284 | main() 285 | -------------------------------------------------------------------------------- /sqeleton/schema.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Union, Type 3 | 4 | from runtype import dataclass 5 | 6 | from .utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict 7 | from .abcs import AbstractDatabase, DbPath 8 | 9 | logger = logging.getLogger("schema") 10 | 11 | Schema = CaseAwareMapping 12 | 13 | class TableType: 14 | pass 15 | # TODO: This should replace the current Schema type 16 | 17 | @classmethod 18 | 
def is_superclass(cls, t): 19 | return isinstance(t, type) and issubclass(t, cls) 20 | 21 | 22 | SchemaInput = Union[Type[TableType], Schema, dict] 23 | 24 | @dataclass 25 | class Options: 26 | default: Any = None 27 | primary_key: bool = False 28 | auto: bool = False 29 | # TODO: foreign_key, unique 30 | # TODO: index? 31 | 32 | @dataclass 33 | class _Field: 34 | type: type 35 | options: Options 36 | 37 | class _Schema(CaseAwareMapping[Union[type, _Field]]): 38 | pass 39 | 40 | @classmethod 41 | def make(cls, schema: SchemaInput): 42 | assert schema 43 | if TableType.is_superclass(schema): 44 | def _make_field(k: str, v: type): 45 | field = getattr(schema, k) 46 | if field: 47 | if not isinstance(field, Options): 48 | field = Options(default=v) 49 | return _Field(v, field) 50 | return v 51 | 52 | schema = CaseSensitiveDict({k:_make_field(k, v) for k,v in schema.__annotations__.items()}) 53 | 54 | elif isinstance(schema, CaseAwareMapping): 55 | pass 56 | else: 57 | assert isinstance(schema, dict), schema 58 | schema = CaseSensitiveDict(schema) 59 | 60 | return schema 61 | 62 | def options(**kw) -> Any: # Any, so that type-checking doesn't complain 63 | return Options(**kw) 64 | 65 | 66 | def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: 67 | if case_sensitive: 68 | return CaseSensitiveDict(schema) 69 | 70 | if len({k.lower() for k in schema}) < len(schema): 71 | logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') 72 | # logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).") 73 | return CaseInsensitiveDict(schema) 74 | -------------------------------------------------------------------------------- /sqeleton/utils.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Iterable, 3 | Iterator, 4 | MutableMapping, 5 | Union, 6 | Any, 7 | Sequence, 8 | Dict, 9 | Hashable, 10 | TypeVar, 11 | Generator, 12 | List, 13 | ) 14 | from abc import ABC, abstractmethod 15 | from weakref import ref 16 | import math 17 | import string 18 | import re 19 | from uuid import UUID 20 | from urllib.parse import urlparse 21 | 22 | # -- Common -- 23 | 24 | try: 25 | from typing import Self 26 | except ImportError: 27 | Self = Any 28 | 29 | 30 | class WeakCache: 31 | def __init__(self): 32 | self._cache = {} 33 | 34 | def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable: 35 | if isinstance(k, dict): 36 | return tuple(k.items()) 37 | return k 38 | 39 | def add(self, key: Union[dict, Hashable], value: Any): 40 | key = self._hashable_key(key) 41 | self._cache[key] = ref(value) 42 | 43 | def get(self, key: Union[dict, Hashable]) -> Any: 44 | key = self._hashable_key(key) 45 | 46 | value = self._cache[key]() 47 | if value is None: 48 | del self._cache[key] 49 | raise KeyError(f"Key {key} not found, or no longer a valid reference") 50 | 51 | return value 52 | 53 | 54 | def join_iter(joiner: Any, iterable: Iterable) -> Iterable: 55 | it = iter(iterable) 56 | try: 57 | yield next(it) 58 | except StopIteration: 59 | return 60 | for i in it: 61 | yield joiner 62 | yield i 63 | 64 | 65 | def safezip(*args): 66 | "zip but makes sure all sequences are the same length" 67 | lens = list(map(len, args)) 68 | if len(set(lens)) != 1: 69 | raise ValueError(f"Mismatching lengths in arguments to safezip: {lens}") 70 | return zip(*args) 71 | 72 | 73 | def is_uuid(u): 74 | try: 75 | UUID(u) 76 | except ValueError: 77 | 
return False 78 | return True 79 | 80 | 81 | def match_regexps(regexps: Dict[str, Any], s: str) -> Generator[tuple, None, None]: 82 | for regexp, v in regexps.items(): 83 | m = re.match(regexp + "$", s) 84 | if m: 85 | yield m, v 86 | 87 | 88 | # -- Schema -- 89 | 90 | V = TypeVar("V") 91 | 92 | 93 | class CaseAwareMapping(MutableMapping[str, V]): 94 | @abstractmethod 95 | def get_key(self, key: str) -> str: 96 | ... 97 | 98 | @abstractmethod 99 | def __init__(self, initial): 100 | ... 101 | 102 | def new(self, initial=()): 103 | return type(self)(initial) 104 | 105 | 106 | class CaseInsensitiveDict(CaseAwareMapping[V]): 107 | def __init__(self, initial): 108 | self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} 109 | 110 | def __getitem__(self, key: str) -> V: 111 | return self._dict[key.lower()][1] 112 | 113 | def __iter__(self) -> Iterator[str]: 114 | return iter(self._dict) 115 | 116 | def __len__(self) -> int: 117 | return len(self._dict) 118 | 119 | def __setitem__(self, key: str, value): 120 | k = key.lower() 121 | if k in self._dict: 122 | key = self._dict[k][0] 123 | self._dict[k] = key, value 124 | 125 | def __delitem__(self, key: str): 126 | del self._dict[key.lower()] 127 | 128 | def get_key(self, key: str) -> str: 129 | return self._dict[key.lower()][0] 130 | 131 | def __repr__(self) -> str: 132 | return repr(dict(self.items())) 133 | 134 | 135 | class CaseSensitiveDict(dict, CaseAwareMapping): 136 | def get_key(self, key) -> str: 137 | self[key] # Throw KeyError if key doesn't exist 138 | return key 139 | 140 | def as_insensitive(self): 141 | return CaseInsensitiveDict(self) 142 | 143 | 144 | # -- Alphanumerics -- 145 | 146 | alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase 147 | 148 | 149 | class ArithString(ABC): 150 | @classmethod 151 | def new(cls, *args, **kw): 152 | return cls(*args, **kw) 153 | 154 | def range(self, other: "ArithString", count: int): 155 | assert isinstance(other, ArithString) 156 | checkpoints = split_space(self.int, other.int, count) 157 | return [self.new(int=i) for i in checkpoints] 158 | 159 | # @property 160 | # @abstractmethod 161 | # def int(self): 162 | # ... 
163 | 164 | 165 | class ArithUUID(UUID, ArithString): 166 | "A UUID that supports basic arithmetic (add, sub)" 167 | 168 | def __int__(self): 169 | return self.int 170 | 171 | def __add__(self, other: int): 172 | if isinstance(other, int): 173 | return self.new(int=self.int + other) 174 | return NotImplemented 175 | 176 | def __sub__(self, other: Union[UUID, int]): 177 | if isinstance(other, int): 178 | return self.new(int=self.int - other) 179 | elif isinstance(other, UUID): 180 | return self.int - other.int 181 | return NotImplemented 182 | 183 | 184 | def numberToAlphanum(num: int, base: str = alphanums) -> str: 185 | digits = [] 186 | while num > 0: 187 | num, remainder = divmod(num, len(base)) 188 | digits.append(remainder) 189 | return "".join(base[i] for i in digits[::-1]) 190 | 191 | 192 | def alphanumToNumber(alphanum: str, base: str = alphanums) -> int: 193 | num = 0 194 | for c in alphanum: 195 | num = num * len(base) + base.index(c) 196 | return num 197 | 198 | 199 | def justify_alphanums(s1: str, s2: str): 200 | max_len = max(len(s1), len(s2)) 201 | s1 = s1.ljust(max_len) 202 | s2 = s2.ljust(max_len) 203 | return s1, s2 204 | 205 | 206 | def alphanums_to_numbers(s1: str, s2: str): 207 | s1, s2 = justify_alphanums(s1, s2) 208 | n1 = alphanumToNumber(s1) 209 | n2 = alphanumToNumber(s2) 210 | return n1, n2 211 | 212 | 213 | class ArithAlphanumeric(ArithString): 214 | def __init__(self, s: str, max_len=None): 215 | if s is None: 216 | raise ValueError("Alphanum string cannot be None") 217 | if max_len and len(s) > max_len: 218 | raise ValueError(f"Length of alphanum value '{s}' is longer than the expected {max_len}") 219 | 220 | for ch in s: 221 | if ch not in alphanums: 222 | raise ValueError(f"Unexpected character {ch} in alphanum string") 223 | 224 | self._str = s 225 | self._max_len = max_len 226 | 227 | def __str__(self): 228 | s = self._str 229 | if self._max_len: 230 | s = s.rjust(self._max_len, alphanums[0]) 231 | return s 232 | 233 | def __len__(self): 234 | return len(self._str) 235 | 236 | def __repr__(self): 237 | return f'alphanum"{self._str}"' 238 | 239 | def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric": 240 | if isinstance(other, int): 241 | if other != 1: 242 | raise NotImplementedError("not implemented for arbitrary numbers") 243 | num = alphanumToNumber(self._str) 244 | return self.new(numberToAlphanum(num + 1)) 245 | 246 | return NotImplemented 247 | 248 | def range(self, other: "ArithAlphanumeric", count: int): 249 | assert isinstance(other, ArithAlphanumeric) 250 | n1, n2 = alphanums_to_numbers(self._str, other._str) 251 | split = split_space(n1, n2, count) 252 | return [self.new(numberToAlphanum(s)) for s in split] 253 | 254 | def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> float: 255 | if isinstance(other, ArithAlphanumeric): 256 | n1, n2 = alphanums_to_numbers(self._str, other._str) 257 | return n1 - n2 258 | 259 | return NotImplemented 260 | 261 | def __ge__(self, other): 262 | if not isinstance(other, type(self)): 263 | return NotImplemented 264 | return self._str >= other._str 265 | 266 | def __lt__(self, other): 267 | if not isinstance(other, type(self)): 268 | return NotImplemented 269 | return self._str < other._str 270 | 271 | def __eq__(self, other): 272 | if not isinstance(other, type(self)): 273 | return NotImplemented 274 | return self._str == other._str 275 | 276 | def new(self, *args, **kw): 277 | return type(self)(*args, **kw, max_len=self._max_len) 278 | 279 | 280 | def number_to_human(n): 281 | 
millnames = ["", "k", "m", "b"] 282 | n = float(n) 283 | millidx = max( 284 | 0, 285 | min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))), 286 | ) 287 | 288 | return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx]) 289 | 290 | 291 | def split_space(start, end, count) -> List[int]: 292 | size = end - start 293 | assert count <= size, (count, size) 294 | return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1] 295 | 296 | 297 | def remove_passwords_in_dict(d: dict, replace_with: str = "***"): 298 | for k, v in d.items(): 299 | if k == "password": 300 | d[k] = replace_with 301 | elif isinstance(v, dict): 302 | remove_passwords_in_dict(v, replace_with) 303 | elif k.startswith("database"): 304 | d[k] = remove_password_from_url(v, replace_with) 305 | 306 | 307 | def _join_if_any(sym, args): 308 | args = list(args) 309 | if not args: 310 | return "" 311 | return sym.join(str(a) for a in args if a) 312 | 313 | 314 | def remove_password_from_url(url: str, replace_with: str = "***") -> str: 315 | parsed = urlparse(url) 316 | account = parsed.username or "" 317 | if parsed.password: 318 | account += ":" + replace_with 319 | host = _join_if_any(":", filter(None, [parsed.hostname, parsed.port])) 320 | netloc = _join_if_any("@", filter(None, [account, host])) 321 | replaced = parsed._replace(netloc=netloc) 322 | return replaced.geturl() 323 | 324 | 325 | def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]: 326 | reo = re.compile(pattern.replace("%", ".*").replace("?", ".") + "$") 327 | for s in strs: 328 | if reo.match(s): 329 | yield s 330 | 331 | 332 | class UnknownMeta(type): 333 | def __instancecheck__(self, instance): 334 | return instance is Unknown 335 | 336 | def __repr__(self): 337 | return "Unknown" 338 | 339 | 340 | class Unknown(metaclass=UnknownMeta): 341 | def __nonzero__(self): 342 | raise TypeError() 343 | 344 | def __new__(cls, *args, **kwargs): 345 | raise RuntimeError("Unknown is a singleton") 346 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erezsh/sqeleton/164f62ddb047153b0a2ac5e0220098f6801792c2/tests/__init__.py -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import string 4 | import random 5 | from typing import Callable 6 | import unittest 7 | import logging 8 | import subprocess 9 | 10 | import sqeleton 11 | from parameterized import parameterized_class 12 | 13 | from sqeleton import databases as db 14 | from sqeleton import connect 15 | from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue 16 | from sqeleton.queries import table 17 | from sqeleton.databases import Database 18 | from sqeleton.query_utils import drop_table 19 | 20 | 21 | # We write 'or None' because Github sometimes creates empty env vars for secrets 22 | TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql" 23 | TEST_POSTGRESQL_CONN_STRING: str = "postgresql://postgres:Password1@localhost/postgres" 24 | TEST_SNOWFLAKE_CONN_STRING: str = os.environ.get("SNOWFLAKE_URI") or None 25 | TEST_PRESTO_CONN_STRING: str = os.environ.get("PRESTO_URI") or None 26 | TEST_BIGQUERY_CONN_STRING: str = os.environ.get("BIGQUERY_URI") or None 27 | TEST_REDSHIFT_CONN_STRING: str = 
os.environ.get("REDSHIFT_URI") or None 28 | TEST_ORACLE_CONN_STRING: str = None 29 | TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATABRICKS_URI") 30 | TEST_TRINO_CONN_STRING: str = os.environ.get("TRINO_URI") or None 31 | # clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" 32 | TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("CLICKHOUSE_URI") 33 | # vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica" 34 | TEST_VERTICA_CONN_STRING: str = os.environ.get("VERTICA_URI") 35 | TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" 36 | 37 | 38 | DEFAULT_N_SAMPLES = 50 39 | N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES)) 40 | BENCHMARK = os.environ.get("BENCHMARK", False) 41 | N_THREADS = int(os.environ.get("N_THREADS", 1)) 42 | TEST_ACROSS_ALL_DBS = os.environ.get("TEST_ACROSS_ALL_DBS", True) # Should we run the full db<->db test suite? 43 | 44 | 45 | def get_git_revision_short_hash() -> str: 46 | return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() 47 | 48 | 49 | GIT_REVISION = get_git_revision_short_hash() 50 | 51 | level = logging.ERROR 52 | if os.environ.get("LOG_LEVEL", False): 53 | level = getattr(logging, os.environ["LOG_LEVEL"].upper()) 54 | 55 | logging.basicConfig(level=level) 56 | logging.getLogger("database").setLevel(level) 57 | 58 | try: 59 | from .local_settings import * 60 | except ImportError: 61 | pass # No local settings 62 | 63 | 64 | CONN_STRINGS = { 65 | db.BigQuery: TEST_BIGQUERY_CONN_STRING, 66 | db.MySQL: TEST_MYSQL_CONN_STRING, 67 | db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING, 68 | db.Snowflake: TEST_SNOWFLAKE_CONN_STRING, 69 | db.Redshift: TEST_REDSHIFT_CONN_STRING, 70 | db.Oracle: TEST_ORACLE_CONN_STRING, 71 | db.Presto: TEST_PRESTO_CONN_STRING, 72 | db.Databricks: TEST_DATABRICKS_CONN_STRING, 73 | db.Trino: TEST_TRINO_CONN_STRING, 74 | db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING, 75 | db.Vertica: TEST_VERTICA_CONN_STRING, 76 | db.DuckDB: TEST_DUCKDB_CONN_STRING, 77 | } 78 | 79 | _database_instances = {} 80 | 81 | 82 | def get_conn(cls: type, shared: bool = True) -> Database: 83 | if shared: 84 | if cls not in _database_instances: 85 | _database_instances[cls] = get_conn(cls, shared=False) 86 | return _database_instances[cls] 87 | 88 | con = sqeleton.connect.load_mixins(AbstractMixin_NormalizeValue) 89 | return con(CONN_STRINGS[cls], N_THREADS) 90 | 91 | 92 | def _print_used_dbs(): 93 | used = {k.__name__ for k, v in CONN_STRINGS.items() if v is not None} 94 | unused = {k.__name__ for k, v in CONN_STRINGS.items() if v is None} 95 | 96 | print(f"Testing databases: {', '.join(used)}") 97 | if unused: 98 | logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}") 99 | if TEST_ACROSS_ALL_DBS: 100 | logging.info( 101 | f"Full tests enabled (every db<->db). May take very long when many dbs are involved. 
={TEST_ACROSS_ALL_DBS}" 102 | ) 103 | 104 | 105 | _print_used_dbs() 106 | CONN_STRINGS = {k: v for k, v in CONN_STRINGS.items() if v is not None} 107 | 108 | 109 | def random_table_suffix() -> str: 110 | char_set = string.ascii_lowercase + string.digits 111 | suffix = "_" 112 | suffix += "".join(random.choice(char_set) for _ in range(5)) 113 | return suffix 114 | 115 | 116 | def str_to_checksum(str: str): 117 | # hello world 118 | # => 5eb63bbbe01eeed093cb22bb8f5acdc3 119 | # => cb22bb8f5acdc3 120 | # => 273350391345368515 121 | m = hashlib.md5() 122 | m.update(str.encode("utf-8")) # encode to binary 123 | md5 = m.hexdigest() 124 | # 0-indexed, unlike DBs which are 1-indexed here, so +1 in dbs 125 | half_pos = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS 126 | return int(md5[half_pos:], 16) 127 | 128 | 129 | class DbTestCase(unittest.TestCase): 130 | "Sets up a table for testing" 131 | db_cls = None 132 | table1_schema = None 133 | shared_connection = True 134 | 135 | def setUp(self): 136 | assert self.db_cls, self.db_cls 137 | 138 | self.connection = get_conn(self.db_cls, self.shared_connection) 139 | 140 | table_suffix = random_table_suffix() 141 | self.table1_name = f"src{table_suffix}" 142 | 143 | self.table1_path = self.connection.parse_table_name(self.table1_name) 144 | 145 | drop_table(self.connection, self.table1_path) 146 | 147 | self.src_table = table(self.table1_path, schema=self.table1_schema) 148 | if self.table1_schema: 149 | self.connection.query(self.src_table.create()) 150 | 151 | return super().setUp() 152 | 153 | def tearDown(self): 154 | drop_table(self.connection, self.table1_path) 155 | 156 | 157 | def _parameterized_class_per_conn(test_databases): 158 | test_databases = set(test_databases) 159 | names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] 160 | return parameterized_class(("name", "db_cls"), names) 161 | 162 | 163 | def make_test_each_database_in_list(databases) -> Callable: 164 | def _test_per_database(cls): 165 | return _parameterized_class_per_conn(databases)(cls) 166 | 167 | return _test_per_database 168 | -------------------------------------------------------------------------------- /tests/test_database.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from datetime import datetime 3 | from typing import Callable, List, Tuple 4 | 5 | import pytz 6 | 7 | from sqeleton import connect 8 | from sqeleton import databases as dbs 9 | from sqeleton.queries import table, current_timestamp, NormalizeAsString, ForeignKey, Compiler 10 | from .common import TEST_MYSQL_CONN_STRING 11 | from .common import str_to_checksum, make_test_each_database_in_list, get_conn, random_table_suffix 12 | from sqeleton.abcs.database_types import TimestampTZ 13 | 14 | TEST_DATABASES = { 15 | dbs.MySQL, 16 | dbs.PostgreSQL, 17 | dbs.Oracle, 18 | dbs.Redshift, 19 | dbs.Snowflake, 20 | dbs.DuckDB, 21 | dbs.BigQuery, 22 | dbs.Presto, 23 | dbs.Trino, 24 | dbs.Vertica, 25 | } 26 | 27 | test_each_database: Callable = make_test_each_database_in_list(TEST_DATABASES) 28 | 29 | 30 | class TestDatabase(unittest.TestCase): 31 | def setUp(self): 32 | self.mysql = connect(TEST_MYSQL_CONN_STRING) 33 | 34 | def test_connect_to_db(self): 35 | self.assertEqual(1, self.mysql.query("SELECT 1", int)) 36 | 37 | 38 | class TestMD5(unittest.TestCase): 39 | def test_md5_as_int(self): 40 | class MD5Dialect(dbs.mysql.Dialect, dbs.mysql.Mixin_MD5): 41 | pass 42 | 43 | self.mysql = connect(TEST_MYSQL_CONN_STRING) 44 | 
self.mysql.dialect = MD5Dialect() 45 | 46 | str = "hello world" 47 | query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str)) 48 | query = f"SELECT {query_fragment}" 49 | 50 | self.assertEqual(str_to_checksum(str), self.mysql.query(query, int)) 51 | 52 | 53 | class TestConnect(unittest.TestCase): 54 | def test_bad_uris(self): 55 | self.assertRaises(ValueError, connect, "p") 56 | self.assertRaises(ValueError, connect, "postgresql:///bla/foo") 57 | self.assertRaises(ValueError, connect, "snowflake://user:pass@foo/bar/TEST1") 58 | self.assertRaises(ValueError, connect, "snowflake://user:pass@foo/bar/TEST1?warehouse=ha&schema=dup") 59 | 60 | 61 | @test_each_database 62 | class TestSchema(unittest.TestCase): 63 | def test_table_list(self): 64 | name = "tbl_" + random_table_suffix() 65 | db = get_conn(self.db_cls) 66 | tbl = table(db.parse_table_name(name), schema={"id": int}) 67 | q = db.dialect.list_tables(db.default_schema, name) 68 | assert not db.query(q) 69 | 70 | db.query(tbl.create()) 71 | self.assertEqual(db.query(q, List[str]), [name]) 72 | 73 | db.query(tbl.drop()) 74 | assert not db.query(q) 75 | 76 | def test_type_mapping(self): 77 | name = "tbl_" + random_table_suffix() 78 | db = get_conn(self.db_cls) 79 | tbl = table( 80 | db.parse_table_name(name), 81 | schema={ 82 | "int": int, 83 | "float": float, 84 | "datetime": datetime, 85 | "str": str, 86 | "bool": bool, 87 | }, 88 | ) 89 | q = db.dialect.list_tables(db.default_schema, name) 90 | assert not db.query(q) 91 | 92 | db.query(tbl.create()) 93 | self.assertEqual(db.query(q, List[str]), [name]) 94 | 95 | db.query(tbl.drop()) 96 | assert not db.query(q) 97 | 98 | 99 | @test_each_database 100 | class TestQueries(unittest.TestCase): 101 | def test_current_timestamp(self): 102 | db = get_conn(self.db_cls) 103 | res = db.query(current_timestamp(), datetime) 104 | assert isinstance(res, datetime), (res, type(res)) 105 | 106 | def test_correct_timezone(self): 107 | name = "tbl_" + random_table_suffix() 108 | db = get_conn(self.db_cls) 109 | tbl = table(name, schema={"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)}) 110 | 111 | db.query(tbl.create()) 112 | 113 | tz = pytz.timezone("Europe/Berlin") 114 | 115 | now = datetime.now(tz) 116 | if isinstance(db, dbs.Presto): 117 | ms = now.microsecond // 1000 * 1000 # Presto max precision is 3 118 | now = now.replace(microsecond=ms) 119 | 120 | db.query(tbl.insert_row(1, now, now)) 121 | db.query(db.dialect.set_timezone_to_utc()) 122 | 123 | t = db.table(tbl).query_schema() 124 | t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision) 125 | 126 | tbl = table(name, schema=t.schema) 127 | 128 | results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple]) 129 | 130 | created_at = results[0][0] 131 | updated_at = results[0][1] 132 | 133 | utc = now.astimezone(pytz.UTC) 134 | expected = utc.__format__("%Y-%m-%d %H:%M:%S.%f") 135 | 136 | self.assertEqual(created_at, expected) 137 | self.assertEqual(updated_at, expected) 138 | 139 | db.query(tbl.drop()) 140 | 141 | def test_foreign_key(self): 142 | db = get_conn(self.db_cls) 143 | 144 | a = table("tbl1_" + random_table_suffix(), schema={"id": int}) 145 | b = table("tbl2_" + random_table_suffix(), schema={"a": ForeignKey(a, "id")}) 146 | c = Compiler(db) 147 | 148 | s = c.compile(b.create()) 149 | try: 150 | db.query([a.create(), b.create()]) 151 | 152 | print("TODO foreign key") 153 | # breakpoint() 154 | finally: 155 | 
db.query( 156 | [ 157 | a.drop(True), 158 | b.drop(True), 159 | ] 160 | ) 161 | 162 | 163 | @test_each_database 164 | class TestThreePartIds(unittest.TestCase): 165 | def test_three_part_support(self): 166 | if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB]: 167 | self.skipTest("Limited support for 3 part ids") 168 | 169 | table_name = "tbl_" + random_table_suffix() 170 | db = get_conn(self.db_cls) 171 | db_res = db.query("SELECT CURRENT_DATABASE()") 172 | schema_res = db.query("SELECT CURRENT_SCHEMA()") 173 | db_name = db_res.rows[0][0] 174 | schema_name = schema_res.rows[0][0] 175 | 176 | table_one_part = table((table_name,), schema={"id": int}) 177 | table_two_part = table((schema_name, table_name), schema={"id": int}) 178 | table_three_part = table((db_name, schema_name, table_name), schema={"id": int}) 179 | 180 | for part in (table_one_part, table_two_part, table_three_part): 181 | db.query(part.create()) 182 | d = db.query_table_schema(part.path) 183 | assert len(d) == 1 184 | db.query(part.drop()) 185 | -------------------------------------------------------------------------------- /tests/test_mixins.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sqeleton import connect 4 | 5 | from sqeleton.abcs import AbstractDialect, AbstractDatabase 6 | from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue, AbstractMixin_RandomSample, AbstractMixin_TimeTravel 7 | 8 | 9 | class TestMixins(unittest.TestCase): 10 | def test_normalize(self): 11 | # - Test sanity 12 | ddb1 = connect("duckdb://:memory:") 13 | assert not hasattr(ddb1.dialect, "normalize_boolean") 14 | 15 | # - Test abstract mixins 16 | class NewAbstractDialect(AbstractDialect, AbstractMixin_NormalizeValue, AbstractMixin_RandomSample): 17 | pass 18 | 19 | new_connect = connect.load_mixins(AbstractMixin_NormalizeValue, AbstractMixin_RandomSample) 20 | ddb2: AbstractDatabase[NewAbstractDialect] = new_connect("duckdb://:memory:") 21 | # Implementation may change; Just update the test 22 | assert ddb2.dialect.normalize_boolean("bool", None) == "bool::INTEGER::VARCHAR" 23 | assert ddb2.dialect.random_sample_n("x", 10) 24 | 25 | # - Test immutability 26 | ddb3 = connect("duckdb://:memory:") 27 | assert not hasattr(ddb3.dialect, "normalize_boolean") 28 | 29 | self.assertRaises(TypeError, connect.load_mixins, AbstractMixin_TimeTravel) 30 | 31 | new_connect = connect.for_databases("bigquery", "snowflake").load_mixins(AbstractMixin_TimeTravel) 32 | self.assertRaises(NotImplementedError, new_connect, "duckdb://:memory:") 33 | -------------------------------------------------------------------------------- /tests/test_sql.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from .common import TEST_MYSQL_CONN_STRING 4 | 5 | from sqeleton import connect 6 | from sqeleton.queries import Compiler, Count, Explain, Select, table, In, BinOp, Code 7 | 8 | 9 | class TestSQL(unittest.TestCase): 10 | def setUp(self): 11 | self.mysql = connect(TEST_MYSQL_CONN_STRING) 12 | self.compiler = Compiler(self.mysql) 13 | 14 | def test_compile_string(self): 15 | self.assertEqual("SELECT 1", self.compiler.compile(Code("SELECT 1"))) 16 | 17 | def test_compile_int(self): 18 | self.assertEqual("1", self.compiler.compile(1)) 19 | 20 | def test_compile_table_name(self): 21 | self.assertEqual( 22 | "`marine_mammals`.`walrus`", 23 | self.compiler.replace(_is_root=False).compile(table("marine_mammals", 
"walrus")), 24 | ) 25 | 26 | def test_compile_select(self): 27 | expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" 28 | self.assertEqual( 29 | expected_sql, 30 | self.compiler.compile( 31 | Select( 32 | table("marine_mammals", "walrus"), 33 | [Code("name")], 34 | ) 35 | ), 36 | ) 37 | 38 | # def test_enum(self): 39 | # expected_sql = "(SELECT *, (row_number() over (ORDER BY id)) as idx FROM `walrus` ORDER BY id) tmp" 40 | # self.assertEqual( 41 | # expected_sql, 42 | # self.compiler.compile( 43 | # Enum( 44 | # ("walrus",), 45 | # "id", 46 | # ) 47 | # ), 48 | # ) 49 | 50 | # def test_checksum(self): 51 | # expected_sql = "SELECT name, sum(cast(conv(substring(md5(concat(cast(id as char), cast(timestamp as char))), 18), 16, 10) as unsigned)) FROM `marine_mammals`.`walrus`" 52 | # self.assertEqual( 53 | # expected_sql, 54 | # self.compiler.compile( 55 | # Select( 56 | # ["name", Checksum(["id", "timestamp"])], 57 | # TableName(("marine_mammals", "walrus")), 58 | # ) 59 | # ), 60 | # ) 61 | 62 | def test_compare(self): 63 | expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id <= 1000) AND (id > 1)" 64 | self.assertEqual( 65 | expected_sql, 66 | self.compiler.compile( 67 | Select( 68 | table("marine_mammals", "walrus"), 69 | [Code("name")], 70 | [BinOp("<=", [Code("id"), Code("1000")]), BinOp(">", [Code("id"), Code("1")])], 71 | ) 72 | ), 73 | ) 74 | 75 | def test_in(self): 76 | expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 77 | self.assertEqual( 78 | expected_sql, 79 | self.compiler.compile( 80 | Select(table("marine_mammals", "walrus"), [Code("name")], [In(Code("id"), [1, 2, 3])]) 81 | ), 82 | ) 83 | 84 | def test_count(self): 85 | expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 86 | self.assertEqual( 87 | expected_sql, 88 | self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In(Code("id"), [1, 2, 3])])), 89 | ) 90 | 91 | def test_count_with_column(self): 92 | expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 93 | self.assertEqual( 94 | expected_sql, 95 | self.compiler.compile( 96 | Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])]) 97 | ), 98 | ) 99 | 100 | def test_explain(self): 101 | expected_sql = "EXPLAIN FORMAT=TREE SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" 102 | self.assertEqual( 103 | expected_sql, 104 | self.compiler.compile( 105 | Explain(Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])])) 106 | ), 107 | ) 108 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human, WeakCache 4 | 5 | 6 | class TestUtils(unittest.TestCase): 7 | def test_remove_passwords_in_dict(self): 8 | # Test replacing password value 9 | d = {"password": "mypassword"} 10 | remove_passwords_in_dict(d) 11 | assert d["password"] == "***" 12 | 13 | # Test replacing password in database URL 14 | d = {"database_url": "mysql://user:mypassword@localhost/db"} 15 | remove_passwords_in_dict(d, "$$$$") 16 | assert d["database_url"] == "mysql://user:$$$$@localhost/db" 17 | 18 | # Test replacing password in nested dictionary 19 | d = {"info": {"password": "mypassword"}} 20 | remove_passwords_in_dict(d, 
"%%") 21 | assert d["info"]["password"] == "%%" 22 | 23 | def test_match_regexps(self): 24 | def only_results(x): 25 | return [v for k, v in x] 26 | 27 | # Test with no matches 28 | regexps = {"a*": 1, "b*": 2} 29 | s = "c" 30 | assert only_results(match_regexps(regexps, s)) == [] 31 | 32 | # Test with one match 33 | regexps = {"a*": 1, "b*": 2} 34 | s = "b" 35 | assert only_results(match_regexps(regexps, s)) == [2] 36 | 37 | # Test with multiple matches 38 | regexps = {"abc": 1, "ab*c": 2, "c*": 3} 39 | s = "abc" 40 | assert only_results(match_regexps(regexps, s)) == [1, 2] 41 | 42 | # Test with regexp that doesn't match the end of the string 43 | regexps = {"a*b": 1} 44 | s = "acb" 45 | assert only_results(match_regexps(regexps, s)) == [] 46 | 47 | def test_match_like(self): 48 | strs = ["abc", "abcd", "ab", "bcd", "def"] 49 | 50 | # Test exact match 51 | pattern = "abc" 52 | result = list(match_like(pattern, strs)) 53 | assert result == ["abc"] 54 | 55 | # Test % match 56 | pattern = "a%" 57 | result = list(match_like(pattern, strs)) 58 | self.assertEqual(result, ["abc", "abcd", "ab"]) 59 | 60 | # Test ? match 61 | pattern = "a?c" 62 | result = list(match_like(pattern, strs)) 63 | self.assertEqual(result, ["abc"]) 64 | 65 | def test_number_to_human(self): 66 | # Test basic conversion 67 | assert number_to_human(1000) == "1k" 68 | assert number_to_human(1000000) == "1m" 69 | assert number_to_human(1000000000) == "1b" 70 | 71 | # Test decimal values 72 | assert number_to_human(1234) == "1k" 73 | assert number_to_human(12345) == "12k" 74 | assert number_to_human(123456) == "123k" 75 | assert number_to_human(1234567) == "1m" 76 | assert number_to_human(12345678) == "12m" 77 | assert number_to_human(123456789) == "123m" 78 | assert number_to_human(1234567890) == "1b" 79 | 80 | # Test negative values 81 | assert number_to_human(-1000) == "-1k" 82 | assert number_to_human(-1000000) == "-1m" 83 | assert number_to_human(-1000000000) == "-1b" 84 | 85 | def test_weak_cache(self): 86 | # Create cache 87 | cache = WeakCache() 88 | 89 | # Test adding and retrieving basic value 90 | o = {1, 2} 91 | cache.add("key", o) 92 | assert cache.get("key") is o 93 | 94 | # Test adding and retrieving dict value 95 | cache.add({"key": "value"}, o) 96 | assert cache.get({"key": "value"}) is o 97 | 98 | # Test deleting value when reference is lost 99 | del o 100 | try: 101 | cache.get({"key": "value"}) 102 | assert False, "KeyError should have been raised" 103 | except KeyError: 104 | pass 105 | -------------------------------------------------------------------------------- /tests/waiting_for_stack_up.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -n "$VERTICA_URI" ] 4 | then 5 | echo "Check Vertica DB running..." 6 | while true 7 | do 8 | if docker logs dd-vertica | tail -n 100 | grep -q -i "vertica is now running" 9 | then 10 | echo "Vertica DB is ready"; 11 | break; 12 | else 13 | echo "Waiting for Vertica DB starting..."; 14 | sleep 10; 15 | fi 16 | done 17 | fi 18 | --------------------------------------------------------------------------------