├── .changie.yaml ├── .devcontainer ├── README.md └── devcontainer.json ├── .dockerignore ├── .env.example ├── .github ├── dependabot.yml ├── release-drafter.yml └── workflows │ ├── constraints.txt │ ├── deploy_aws.yml │ ├── dwh-benchmark.yml │ ├── release.yml │ ├── static.yml │ └── test.yml ├── .gitignore ├── DEVELOPMENT.md ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yml ├── examples └── playground │ ├── app │ └── sqlglot │ │ ├── page.py │ │ └── requirements.txt │ └── universql_project.yml ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── resources ├── cli_demo.gif ├── cli_demo.png ├── clickbench.png ├── demo.gif ├── dwh-benchmark │ ├── clickbench.sql │ ├── index.html │ └── results.json ├── how_to_use.png ├── snowflake │ ├── favicon.ico │ ├── index.html │ └── logo.png └── snowflake_redshift_usage │ ├── 1.Download.sh │ ├── 2.IngestCost.ipynb │ └── 2.IngestCost.py ├── snowflake.aws.lambda.Dockerfile ├── ssl └── .gitignore ├── tests ├── __init__.py ├── integration │ ├── __init__.py │ ├── extract.py │ ├── load.py │ ├── object_identifiers.py │ ├── transform.py │ └── utils.py ├── plugins │ └── __init__.py ├── scratch │ ├── __init__.py │ ├── boto_tests.py │ ├── cdk_tests.py │ ├── chdb_tests.py │ ├── dbt_tests.py │ ├── duck_tests.py │ ├── jinja_test.py │ ├── pyiceberg_tests.py │ ├── sqlglot_tests.py │ ├── system_tray.py │ └── textual_test.py ├── sql_optimizer.py └── testing.ipynb ├── universql.code-workspace └── universql ├── __init__.py ├── agent ├── __init__.py └── cloudflared.py ├── catalog ├── __init__.py └── iceberg.py ├── lake ├── cloud.py └── fsspec_util.py ├── main.py ├── plugin.py ├── plugins ├── __init__.py ├── snow.py └── ui.py ├── protocol ├── __init__.py ├── lambda.py ├── session.py ├── snowflake.py └── utils.py ├── streamlit ├── .streamlit │ └── config.toml ├── app.py └── requirements.txt ├── util.py └── warehouse ├── __init__.py ├── bigquery.py ├── duckdb.py ├── redshift.py └── snowflake.py /.changie.yaml: -------------------------------------------------------------------------------- 1 | changesDir: .changes 2 | unreleasedDir: unreleased 3 | headerPath: header.tpl.md 4 | changelogPath: CHANGELOG.md 5 | versionExt: md 6 | versionFormat: '## {{.Version}} - {{.Time.Format "2006-01-02"}}' 7 | kindFormat: '### {{.Kind}}' 8 | changeFormat: '* {{.Body}}' 9 | kinds: 10 | - label: Added 11 | - label: Changed 12 | - label: Deprecated 13 | - label: Removed 14 | - label: Fixed 15 | - label: Security 16 | newlines: 17 | afterChangelogHeader: 1 18 | beforeChangelogVersion: 1 19 | endOfVersion: 1 20 | -------------------------------------------------------------------------------- /.devcontainer/README.md: -------------------------------------------------------------------------------- 1 | # Start Universql 2 | 3 | Congrats! You've successfully created a new Universql project. 
To start the project, run the following command: 4 | 5 | ```bash 6 | poetry run universql snowflake 7 | ``` -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "universql", 3 | "build": { 4 | "dockerfile": "../Dockerfile" 5 | }, 6 | "forwardPorts": [ 7 | 8084 8 | ], 9 | "postCreateCommand": "poetry install && gh codespace ports visibility 8084:public -c $CODESPACE_NAME", 10 | "portsAttributes": { 11 | "8084": { 12 | "label": "universql-app", 13 | "protocol": "https" 14 | } 15 | }, 16 | "remoteEnv": { 17 | "SERVER_HOST": "0.0.0.0" 18 | }, 19 | "customizations": { 20 | "codespaces": { 21 | "openFiles": [ 22 | ".devcontainer/README.md" 23 | ] 24 | } 25 | } 26 | } -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | tests/* 2 | resources/* 3 | .git/* -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | SNOWFLAKE_ACCOUNT=XXXXX-XXXXX 2 | SERVER_PORT=8084 3 | HOST=0.0.0.0 4 | SSL_KEYFILE_NAME=privkey_x.pem 5 | SSL_CERTFILE_NAME=fullchain_x.pem -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 2 | 3 | version: 2 4 | updates: 5 | - package-ecosystem: "pip" 6 | directory: "/" 7 | schedule: 8 | interval: "weekly" 9 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | categories: 2 | - title: ":boom: Breaking Changes" 3 | label: "breaking" 4 | - title: ":rocket: Features" 5 | label: "enhancement" 6 | - title: ":fire: Removals and Deprecations" 7 | label: "removal" 8 | - title: ":beetle: Fixes" 9 | label: "bug" 10 | - title: ":racehorse: Performance" 11 | label: "performance" 12 | - title: ":rotating_light: Testing" 13 | label: "testing" 14 | - title: ":construction_worker: Continuous Integration" 15 | label: "ci" 16 | - title: ":books: Documentation" 17 | label: "documentation" 18 | - title: ":hammer: Refactoring" 19 | label: "refactoring" 20 | - title: ":lipstick: Style" 21 | label: "style" 22 | - title: ":package: Dependencies" 23 | labels: 24 | - "dependencies" 25 | - "build" 26 | template: | 27 | ## Changes 28 | See the [CHANGELOG](https://github.com/z3z1ma/dbt-osmosis/blob/main/CHANGELOG.md) for a full list of changes. 
29 | Automatically detected changes: 30 | $CHANGES -------------------------------------------------------------------------------- /.github/workflows/constraints.txt: -------------------------------------------------------------------------------- 1 | pip==23.0.1 2 | virtualenv==20.21.0 -------------------------------------------------------------------------------- /.github/workflows/deploy_aws.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to AWS Lambda 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | jobs: 8 | 9 | build: 10 | 11 | name: Build Image and Deploy 12 | runs-on: ubuntu-latest 13 | environment: deploy 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v3 17 | 18 | - name: Configure AWS credentials 19 | uses: aws-actions/configure-aws-credentials@v4 20 | with: 21 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 22 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 23 | aws-region: us-east-1 24 | 25 | - name: Login to Amazon ECR 26 | id: login-ecr 27 | uses: aws-actions/amazon-ecr-login@v1 28 | with: 29 | registry-type: public 30 | - name: Build, tag, and push image to Amazon ECR 31 | env: 32 | ECR_REGISTRY: public.ecr.aws/m5p8n9s4/universql 33 | IMAGE_TAG: latest 34 | run: | 35 | docker build -t $ECR_REGISTRY:$IMAGE_TAG -f snowflake.aws.lambda.Dockerfile --push . 36 | - uses: int128/deploy-lambda-action@v1 37 | with: 38 | function-name: universql-server 39 | image-uri: public.ecr.aws/m5p8n9s4/universql:latest 40 | #image-uri: 730335382627.dkr.ecr.us-east-1.amazonaws.com/universql:latest -------------------------------------------------------------------------------- /.github/workflows/dwh-benchmark.yml: -------------------------------------------------------------------------------- 1 | name: deploy 2 | on: [push] 3 | jobs: 4 | version: 5 | name: "Check Snowflake CLI version" 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Checkout repo 9 | uses: actions/checkout@v4 10 | - uses: Snowflake-Labs/snowflake-cli-action@v1.5 11 | with: 12 | cli-version: "latest" 13 | - name: Test project 14 | env: 15 | SNOWFLAKE_CONNECTIONS_BASE64: ${{ secrets.SNOWFLAKE_CONNECTIONS_BASE64 }} 16 | run: | 17 | export SNOWFLAKE_CONNECTIONS=$(printf '%s' "$SNOWFLAKE_CONNECTIONS_BASE64" | base64 -d | tr '\r' '\n') 18 | snow --version 19 | snow connection test --environment ci-aws-us-east-1 20 | python run_benchmark.py /resources/dwh-benchmark/clickbench.sql > /resources/dwh-benchmark/results.json 21 | - uses: actions/upload-artifact@v4 22 | with: 23 | name: icebergbench 24 | path: /resources/dwh-benchmark -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | pip: 10 | name: Release PyPI 11 | runs-on: ubuntu-latest 12 | environment: deploy 13 | steps: 14 | - name: Check out the repository 15 | uses: actions/checkout@v3 16 | with: 17 | fetch-depth: 2 18 | - name: Set up Python 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: "3.11" 22 | - name: Upgrade pip 23 | run: | 24 | pip install --constraint=.github/workflows/constraints.txt pip 25 | pip --version 26 | - name: Install Poetry 27 | run: | 28 | pip install --constraint=.github/workflows/constraints.txt poetry 29 | poetry --version 30 | poetry install 31 | - name: Check if there is a parent commit 32 | id: 
check-parent-commit 33 | run: | 34 | echo "::set-output name=sha::$(git rev-parse --verify --quiet HEAD^)" 35 | - name: Detect and tag new version 36 | id: check-version 37 | if: steps.check-parent-commit.outputs.sha 38 | uses: salsify/action-detect-and-tag-new-version@v2.0.3 39 | with: 40 | version-command: | 41 | bash -o pipefail -c "poetry version | awk '{ print \$2 }'" 42 | - name: Bump version for developmental release 43 | if: "! steps.check-version.outputs.tag" 44 | run: | 45 | poetry version patch && 46 | version=$(poetry version | awk '{ print $2 }') && 47 | poetry version $version.dev.$(date +%s) 48 | - name: Build package 49 | run: | 50 | poetry build --ansi 51 | - name: Publish package on PyPI 52 | # if: steps.check-version.outputs.tag 53 | uses: pypa/gh-action-pypi-publish@v1.8.3 54 | with: 55 | user: __token__ 56 | password: ${{ secrets.PYPI_TOKEN }} 57 | 58 | - name: Publish package on TestPyPI 59 | if: "! steps.check-version.outputs.tag" 60 | uses: pypa/gh-action-pypi-publish@v1.8.3 61 | with: 62 | user: __token__ 63 | password: ${{ secrets.TEST_PYPI_TOKEN }} 64 | repository_url: https://test.pypi.org/legacy/ 65 | 66 | - name: Publish the release notes 67 | uses: release-drafter/release-drafter@v5.23.0 68 | with: 69 | publish: ${{ steps.check-version.outputs.tag != '' }} 70 | tag: ${{ steps.check-version.outputs.tag }} 71 | env: 72 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 73 | docker: 74 | name: Publish Docker 75 | environment: deploy 76 | runs-on: ubuntu-latest 77 | steps: 78 | - name: Check out the repository 79 | uses: actions/checkout@v3 80 | with: 81 | fetch-depth: 2 82 | - name: Set up Docker Buildx 83 | uses: docker/setup-buildx-action@v2 84 | - name: Publish to Registry 85 | uses: elgohr/Publish-Docker-Github-Action@v5 86 | with: 87 | name: buremba/universql 88 | username: ${{ secrets.DOCKER_USERNAME }} 89 | password: ${{ secrets.DOCKER_PASSWORD }} 90 | tags: "latest" 91 | dockerfile: "Dockerfile" 92 | platforms: linux/amd64,linux/arm64 93 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | # Simple workflow for deploying static content to GitHub Pages 2 | name: Deploy static content to Pages 3 | 4 | on: 5 | # Runs on pushes targeting the default branch 6 | push: 7 | branches: ["main"] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 13 | permissions: 14 | contents: read 15 | pages: write 16 | id-token: write 17 | 18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
20 | concurrency: 21 | group: "pages" 22 | cancel-in-progress: false 23 | 24 | jobs: 25 | # Single deploy job since we're just deploying 26 | deploy: 27 | environment: 28 | name: github-pages 29 | url: ${{ steps.deployment.outputs.page_url }} 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v4 34 | - name: Setup Pages 35 | uses: actions/configure-pages@v5 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v3 38 | with: 39 | # Upload entire repository 40 | path: 'resources/' 41 | - name: Deploy to GitHub Pages 42 | id: deployment 43 | uses: actions/deploy-pages@v4 44 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | 6 | jobs: 7 | pip: 8 | name: Build and Test 9 | runs-on: ubuntu-latest 10 | environment: deploy 11 | steps: 12 | - name: Check out the repository 13 | uses: actions/checkout@v3 14 | with: 15 | fetch-depth: 2 16 | - name: Set up Python 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: "3.11" 20 | - name: Upgrade pip 21 | run: | 22 | pip install --constraint=.github/workflows/constraints.txt pip 23 | pip --version 24 | - name: Install Poetry 25 | run: | 26 | pip install --constraint=.github/workflows/constraints.txt poetry 27 | poetry --version 28 | poetry install 29 | - name: Test project 30 | env: 31 | SNOWFLAKE_CONNECTIONS_BASE64: ${{ secrets.SNOWFLAKE_CONNECTIONS_BASE64 }} 32 | run: | 33 | export SNOWFLAKE_CONNECTIONS=$(printf '%s' "$SNOWFLAKE_CONNECTIONS_BASE64" | base64 -d | tr '\r' '\n') 34 | poetry run pytest tests/integration/* --log-cli-level=DEBUG -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | universql-misc/* 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 | 
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 | 
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 | 
102 | # Custom
103 | .idea/*
104 | .metabase/*
105 | .clickhouse/*
106 | .DS_STORE
107 | .certs/*
108 | .rill/*
109 | venv/*
110 | .venv/*
111 | .db/*
112 | .env
113 | universql.metadata.sqlite
114 | credentials/
115 | 
116 | one_off_tests
117 | copy_files_for_claude.py
--------------------------------------------------------------------------------
/DEVELOPMENT.md:
--------------------------------------------------------------------------------
1 | # Running GitHub Actions locally
2 | 
3 | ```shell
4 | act push --container-architecture linux/amd64 --secret-file .env.act --workflows ./.github/workflows/test.yml --insecure-secrets
5 | ```
6 | 
7 | # Update certs
8 | 
9 | 1. Start [the instance](https://console.cloud.google.com/compute/instancesDetail/zones/us-central1-a/instances/instance-20240709-162937?inv=1&invt=AbpNfQ&project=jinjat-demo)
10 | 
11 | 2. Connect and update the certs
12 | ```bash
13 | gcloud compute ssh --zone "us-central1-a" "instance-20240709-162937" --project "jinjat-demo"
14 | ```
15 | 
16 | Update the IP on DNS: https://dash.cloudflare.com/09f0f68ebdbed47d203725997d4cfbdb/localhostcomputing.com/dns/records
17 | 
18 | ```bash
19 | sudo certbot certonly --standalone --domains localhostcomputing.com
20 | ```
21 | 
22 | 3. Update [util.py](/universql/util.py)
23 | 
24 | # Plugins
25 | 
26 | If you would like to enable custom logic inside Universql, for example to support the transition between different SQL dialects, install the plugin and start Universql with it:
27 | ```bash
28 | > poetry add "universql[rewrite_create_table_as_iceberg]"
29 | > universql snowflake myaccount.aws.us-east-1 --module-paths rewrite_create_table_as_iceberg
30 | ```
31 | 
32 | 
33 | 
34 | Then implement your custom logic inside that module, as sketched below.
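Below is a minimal sketch of what such a plugin module could contain, using the `Transformer` base class described further down in this document. The file name, class name, and import locations (`universql.plugin`, `universql.warehouse.snowflake`, `universql.warehouse.duckdb`) are illustrative assumptions rather than a confirmed Universql API:

```python
# my_transformers/lowercase_tables.py -- illustrative file name
import sqlglot
from sqlglot.expressions import Expression

# The import paths below are assumptions; point them at wherever Transformer,
# SnowflakeCatalog and DuckDBCatalog actually live in the universql package.
from universql.plugin import Transformer
from universql.warehouse.snowflake import SnowflakeCatalog
from universql.warehouse.duckdb import DuckDBCatalog


class LowercaseTableNames(Transformer):
    """Example-only transformer: lowercases table names before the query runs on DuckDB."""

    def __init__(self, source_engine: SnowflakeCatalog, target_engine: DuckDBCatalog):
        super().__init__(source_engine, target_engine)

    def transform_sql(self, expression: Expression) -> Expression:
        # Rewrite each table reference to its lowercase form.
        if isinstance(expression, sqlglot.exp.Table) and expression.name:
            expression.set("this", sqlglot.exp.to_identifier(expression.name.lower()))
        return expression
```

Apart from the rewrite rule, the class follows the same shape as the `FixTimestampTypes` example shown further below.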
35 | 
36 | You can inject modules into Universql at start time by passing `--module-paths ./my_transformers` on the command line.
37 | 
38 | 
39 | 
40 | Here is the abstract class for a transformer:
41 | 
42 | ```python
43 | class Transformer:
44 |     def __init__(self,
45 |                  # allows us to call this transformer when the base catalog is Snowflake
46 |                  source_engine: SnowflakeCatalog,
47 |                  # if this is the generic Catalog type, .transform() can be invoked
48 |                  # automatically no matter which engine the query runs on
49 |                  target_engine: Catalog
50 |                  ):
51 |         self.source_engine = source_engine
52 |         self.target_engine = target_engine
53 | 
54 |     def transform_sql(self, expression: Expression) -> Expression:
55 |         return expression
56 | 
57 |     def transform_result(self, response: Response):
58 |         return response
59 | 
60 |     def transform_request(self, request: Request):
61 |         return request
62 | ```
63 | 
64 | Here is an example:
65 | 
66 | ```python
67 | # Rewrites Snowflake timestamp and variant types to their DuckDB equivalents
68 | class FixTimestampTypes(Transformer):
69 |     def __init__(self, source_engine: SnowflakeCatalog, target_engine: DuckDBCatalog):
70 |         super().__init__(source_engine, target_engine)
71 | 
72 |     def transform_sql(self, expression):
73 |         if isinstance(expression, sqlglot.exp.DataType):
74 |             if expression.this.value in ["TIMESTAMPLTZ", "TIMESTAMPTZ"]:
75 |                 return sqlglot.exp.DataType.build("TIMESTAMPTZ")
76 |             if expression.this.value in ["VARIANT"]:
77 |                 return sqlglot.exp.DataType.build("JSON")
78 | 
79 |         return expression
80 | ```
81 | 
82 | The engine inspects the `__init__` signature and calls the transformer only when the catalog types match.
83 | 
84 | Things we can move to transformers:
85 | 
86 | ```python
87 | class RewriteCreateAsIceberg(Transformer):
88 | 
89 |     def transform_sql(self, expression: Expression) -> Expression:
90 |         # rewrite CREATE TABLE as CREATE ICEBERG TABLE
91 |         return expression
92 | ```
93 | 
94 | For stage integration, something like:
95 | 
96 | ```python
97 | class SnowflakeStageTransformer(Transformer):
98 |     def __init__(self, source_engine: SnowflakeCatalog, target_engine: DuckDBCatalog):
99 |         super().__init__(source_engine, target_engine)
100 | 
101 |     def transform_sql(self, ast: Expression) -> Expression:
102 |         if isinstance(ast, sqlglot.exp.Var) and ast.name.startswith('@'):
103 |             # transform the stage reference into its full path and create a secret on duckdb
104 |             stage = self.target_engine.duckdb.sql("select * from information_schema.stages where ..")
105 |             if not stage:  # the stage is not registered on duckdb yet
106 |                 stage_info = self.source_engine.executor().execute_raw("get stage info from fs")
107 |                 self.target_engine.duckdb.sql("INSERT INTO information_schema.stages ...")  # register it using stage_info
108 |             return new_ast_with_full_path  # pseudocode: the @stage reference rewritten to the resolved path
109 |         return ast
110 | ```
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim-bullseye
2 | 
3 | RUN apt-get update && apt-get -y upgrade \
4 |     && apt-get install gcc -y \
5 |     && pip install --upgrade pip
6 | 
7 | # Set separate working directory for easier debugging.
8 | WORKDIR /app
9 | 
10 | RUN pip install 'poetry==1.8.3'
11 | COPY pyproject.toml poetry.lock ./
12 | 
13 | RUN poetry config virtualenvs.create false --local
14 | 
15 | # Copy everything. (Note: If needed, we can use .dockerignore to limit what's copied.)
16 | COPY . .
17 | 18 | RUN poetry install 19 | 20 | EXPOSE 8084 21 | ENV SERVER_HOST=0.0.0.0 22 | ENV USE_LOCALCOMPUTING_COM=1 23 | 24 | ENTRYPOINT ["poetry", "run", "universql"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | universql-server: 5 | build: 6 | context: . 
7 | dockerfile: Dockerfile 8 | container_name: universql-server 9 | ports: 10 | - "${SERVER_PORT}:${SERVER_PORT}" 11 | volumes: 12 | - ./ssl:/app/ssl 13 | environment: 14 | - SERVER_HOST=${HOST} 15 | command: snowflake --account ${SNOWFLAKE_ACCOUNT} --host ${HOST} --ssl_keyfile /app/ssl/${SSL_KEYFILE_NAME} --ssl_certfile /app/ssl/${SSL_CERTFILE_NAME} -------------------------------------------------------------------------------- /examples/playground/app/sqlglot/page.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import sqlglot 3 | from sqlglot import expressions 4 | import networkx as nx 5 | from st_link_analysis import st_link_analysis, NodeStyle 6 | 7 | example_sql = """\ 8 | CREATE VIEW sales.customer_order_total_percentage AS 9 | SELECT 10 | order.id, 11 | email, 12 | order.total_amount, 13 | total_spent, 14 | 100 * order.total_amount / total_spent AS order_total_percentage 15 | FROM sales.customer_order_summary 16 | JOIN sales.order 17 | ON customer_order_summary.order_id = order.id;""" 18 | 19 | # Set the title and favicon that appear in the Browser's tab bar. 20 | st.set_page_config( 21 | page_title="sqlglot Abstract Syntax Tree (AST) Viewer", 22 | page_icon=":deciduous_tree:", 23 | layout="wide", 24 | ) 25 | 26 | # Set the title that appears at the top of the page. 27 | """ 28 | # :deciduous_tree: sqlglot Abstract Syntax Tree (AST) Viewer 29 | 30 | View the abstract syntax tree generated by sqlglot for SQL code. 31 | 32 | If there are multiple statements, only the first statement will be visualized. 33 | """ 34 | 35 | # Add some spacing 36 | "" 37 | "" 38 | 39 | sql = st.text_area("Input SQL code here:", value=example_sql, height=300) 40 | 41 | "" 42 | "" 43 | "" 44 | 45 | parsed_sql = sqlglot.parse(sql) 46 | 47 | 48 | def is_flattenable(node, classes): 49 | return isinstance(node, classes) 50 | 51 | 52 | def ast_to_digraph(ast, flattenable_classes): 53 | """ 54 | Convert a SQLGlot AST into a NetworkX DiGraph. 55 | Args: 56 | ast (sqlglot.Expression): The root of the SQLGlot AST. 57 | Returns: 58 | nx.DiGraph: A directed graph representing the AST. 59 | """ 60 | graph = nx.DiGraph() 61 | 62 | def add_node_and_edges(node, parent=None): 63 | """ 64 | Recursively add nodes and edges to the graph based on the AST. 65 | Args: 66 | node (sqlglot.Expression): The current AST node. 67 | parent (str): The label of the parent node, if any. 
68 | """ 69 | # Create a unique identifier for each node using its type and id 70 | 71 | node_id = id(node) 72 | node_label = f"{node.__class__.__name__}" 73 | node_kind = ( 74 | f"({node.args.get('kind')})" if node.args.get("kind") else "" 75 | ) 76 | node_name = f"{node_label} {node_kind}".strip() 77 | 78 | all_children_identifiers = all( 79 | [ 80 | isinstance(x, expressions.Identifier) 81 | for x in filter(None, node.args.values()) 82 | ] 83 | ) 84 | flattenable = ( 85 | is_flattenable(node, flattenable_classes) 86 | or all_children_identifiers 87 | ) 88 | 89 | if flattenable or node.is_leaf(): 90 | node_content = node.sql() 91 | else: 92 | node_content = ( 93 | node.this if isinstance(node, expressions.Identifier) else "" 94 | ) 95 | 96 | # Add the node to the graph with its label 97 | graph.add_node( 98 | node_id, 99 | label=node_label, 100 | name=node_name, 101 | kind=node_kind, 102 | content=node_content, 103 | ) 104 | # If there is a parent node, add an edge from parent to current node 105 | if parent is not None: 106 | graph.add_edge(parent, node_id) 107 | 108 | # Recursively process child nodes 109 | for child in node.args.values(): 110 | if isinstance(child, list) and not flattenable: 111 | for sub_child in child: 112 | if isinstance(sub_child, sqlglot.Expression): 113 | add_node_and_edges(sub_child, node_id) 114 | elif isinstance(child, sqlglot.Expression) and not flattenable: 115 | add_node_and_edges(child, node_id) 116 | 117 | # Start traversal from the root of the AST 118 | add_node_and_edges(ast) 119 | 120 | return graph 121 | 122 | 123 | def graph_to_elements(graph): 124 | """ 125 | Convert a NetworkX DiGraph into the specified elements data structure. 126 | 127 | Args: 128 | graph (nx.DiGraph): The directed graph to convert. 129 | 130 | Returns: 131 | dict: A dictionary with nodes and edges formatted according to the 132 | specified structure. 
133 | """ 134 | elements = {"nodes": [], "edges": []} 135 | 136 | # Convert nodes to the desired format 137 | for node_id, node_data in graph.nodes(data=True): 138 | # Extract the node label and other attributes 139 | node_label = node_data.get("label", "NODE") 140 | node_name = node_data.get("name", "") 141 | node_content = node_data.get("content", "") 142 | 143 | # Customize the node data format based on available attributes 144 | node_entry = {"data": {"id": node_id, "label": node_label}} 145 | if node_name: 146 | node_entry["data"]["name"] = node_name 147 | if node_content: 148 | node_entry["data"]["content"] = node_content 149 | 150 | elements["nodes"].append(node_entry) 151 | 152 | # Convert edges to the desired format 153 | for edge_id, (source, target, edge_data) in enumerate( 154 | graph.edges(data=True), start=1 155 | ): 156 | edge_label = edge_data.get("label", "EDGE") 157 | 158 | # Create edge data with source, target, and attributes 159 | edge_entry = { 160 | "data": { 161 | "id": edge_id, 162 | "label": edge_label, 163 | "source": source, 164 | "target": target, 165 | } 166 | } 167 | elements["edges"].append(edge_entry) 168 | 169 | return elements 170 | 171 | 172 | def ast_to_link_analysis(ast, classes): 173 | graph = ast_to_digraph(ast, classes) 174 | elements = graph_to_elements(graph) 175 | st_link_analysis( 176 | elements, 177 | "breadthfirst", 178 | [ 179 | NodeStyle("Select", "#FF5733", "content"), 180 | NodeStyle("Create", "#33FF57", "content"), 181 | NodeStyle("With", "#3357FF", "content"), 182 | NodeStyle("Table", "#FF33A1", "content", "table"), 183 | NodeStyle("Column", "#33FFA1", "content", "view_column"), 184 | NodeStyle("Alias", "#4340A1", "content"), 185 | NodeStyle("ColumnDef", "#33FFA1", "content"), 186 | NodeStyle("ColumnConstraint", "#A133FF", "content"), 187 | NodeStyle("Delete", "#FF8C33", "content"), 188 | NodeStyle("DropCopy", "#8CFF33", "content"), 189 | NodeStyle("ForeignKey", "#33FF8C", "content"), 190 | NodeStyle("PrimaryKey", "#8C33FF", "content"), 191 | NodeStyle("Into", "#FF3333", "content"), 192 | NodeStyle("From", "#33FF33", "content"), 193 | NodeStyle("Having", "#3333FF", "content"), 194 | NodeStyle("Index", "#FF333F", "content"), 195 | NodeStyle("Insert", "#33FF3F", "content"), 196 | NodeStyle("Limit", "#3F33FF", "content"), 197 | NodeStyle("Group", "#FF33FF", "content"), 198 | NodeStyle("Join", "#33FFFF", "content"), 199 | NodeStyle("Properties", "#FFCC33", "content"), 200 | NodeStyle("Where", "#33CCFF", "content"), 201 | NodeStyle("Order", "#CC33FF", "content"), 202 | ], 203 | [], 204 | ) 205 | 206 | 207 | sql_expressions = { 208 | x.__name__: x 209 | for x in expressions.__dict__.values() 210 | if isinstance(x, expressions._Expression) 211 | } 212 | 213 | default_flattenable_classes = ["Column", "ColumnDef", "Alias"] 214 | 215 | if parsed_sql[0]: 216 | st.code(parsed_sql[0].__repr__()) 217 | 218 | flattenable_classes = st.multiselect( 219 | "Select what types of expressions you would like to flatten:", 220 | sql_expressions.keys(), 221 | default=default_flattenable_classes, 222 | ) 223 | 224 | ast_to_link_analysis( 225 | parsed_sql[0], 226 | tuple( 227 | sql_expressions.get(class_name) 228 | for class_name in flattenable_classes 229 | ), 230 | ) -------------------------------------------------------------------------------- /examples/playground/app/sqlglot/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | sqlglot 3 | duckdb 4 | networkx 5 | st_link_analysis 
-------------------------------------------------------------------------------- /examples/playground/universql_project.yml: -------------------------------------------------------------------------------- 1 | 2 | name: 'utilities' 3 | version: '2.0' 4 | 5 | profile: 'default' 6 | config-version: 2 7 | app-path: 'app' 8 | target-path: 'target' 9 | 10 | vars: 11 | example_variable: 'value' 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "universql" 3 | version = "0.1" 4 | description = "" 5 | authors = ["Burak Kabakcı "] 6 | readme = "README.md" 7 | license = "Apache 2.0" 8 | keywords = ["dbt", "server", "streamlit", "git", "refine", "data-app", "snowflake"] 9 | documentation = "https://github.com/buremba/universql" 10 | repository = "https://github.com/buremba/universql" 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.11,<3.13" 14 | duckdb = "^1.1.3" 15 | fastapi = "^0.111.0" 16 | uvicorn = "^0.30.1" 17 | snowflake-connector-python = {extras = ["pandas", "secure-local-storage"], version = "^3.12.0"} 18 | # same with fakesnow 19 | sqlglot = "~25.34.0" 20 | fsspec = "^2024.6.1" 21 | click = ">7" 22 | pyiceberg = {extras = ["glue", "sql-sqlite"], version = "^0.7.1"} 23 | #pyiceberg = { git = "https://github.com/buremba/iceberg-python", branch = "main", develop = true, extras = ["glue"] } 24 | fakesnow = "^0.9.27" 25 | humanize = "^4.10.0" 26 | mangum = "^0.19.0" 27 | pyarrow = "^17.0.0" 28 | psutil = "^6.0.0" 29 | # once we get rid of from pyiceberg.catalog.sql import SqlCatalog, remove it 30 | sqlalchemy = "^2.0.35" 31 | duckdb-engine = "^0.13.2" 32 | 33 | # GCS 34 | gcsfs = "^2024.6.1" 35 | google-cloud-bigquery = "^3.25.0" 36 | # AWS 37 | s3fs = "^2024.6.1" 38 | aws-cdk-lib = "^2.162.0" 39 | sentry-sdk = {extras = ["fastapi"], version = "^2.17.0"} 40 | aws-cdk-aws-codestar-alpha = "^2.162.1a0" 41 | marimo = "^0.10.9" 42 | 43 | 44 | [tool.poetry.dev-dependencies] 45 | pylint = ">=2.11.1" 46 | pystray = "^0.19.5" 47 | networkx = "^3.3" 48 | streamlit = "^1.38.0" 49 | st-link-analysis = "^0.3.0" 50 | shandy-sqlfmt = "^0.17.0" 51 | streamlit-ace = "^0.1.1" 52 | pytest-randomly = "^3.15.0" 53 | pytest = "^8.2.0" 54 | 55 | [build-system] 56 | requires = ["poetry-core"] 57 | build-backend = "poetry.core.masonry.api" 58 | 59 | [tool.poetry.scripts] 60 | universql = 'universql.main:cli' -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore:unclosed file .*:ResourceWarning 4 | testpaths = 5 | tests/integration -------------------------------------------------------------------------------- /resources/cli_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/cli_demo.gif -------------------------------------------------------------------------------- /resources/cli_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/cli_demo.png -------------------------------------------------------------------------------- /resources/clickbench.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/clickbench.png -------------------------------------------------------------------------------- /resources/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/demo.gif -------------------------------------------------------------------------------- /resources/dwh-benchmark/clickbench.sql: -------------------------------------------------------------------------------- 1 | CREATE TEMP TABLE hits2 2 | ( 3 | WatchID BIGINT NOT NULL, 4 | JavaEnable SMALLINT NOT NULL, 5 | Title TEXT NOT NULL, 6 | GoodEvent SMALLINT NOT NULL, 7 | EventTime TIMESTAMP NOT NULL, 8 | EventDate Date NOT NULL, 9 | CounterID INTEGER NOT NULL, 10 | ClientIP INTEGER NOT NULL, 11 | RegionID INTEGER NOT NULL, 12 | UserID BIGINT NOT NULL, 13 | CounterClass SMALLINT NOT NULL, 14 | OS SMALLINT NOT NULL, 15 | UserAgent SMALLINT NOT NULL, 16 | URL TEXT NOT NULL, 17 | Referer TEXT NOT NULL, 18 | IsRefresh SMALLINT NOT NULL, 19 | RefererCategoryID SMALLINT NOT NULL, 20 | RefererRegionID INTEGER NOT NULL, 21 | URLCategoryID SMALLINT NOT NULL, 22 | URLRegionID INTEGER NOT NULL, 23 | ResolutionWidth SMALLINT NOT NULL, 24 | ResolutionHeight SMALLINT NOT NULL, 25 | ResolutionDepth SMALLINT NOT NULL, 26 | FlashMajor SMALLINT NOT NULL, 27 | FlashMinor SMALLINT NOT NULL, 28 | FlashMinor2 TEXT NOT NULL, 29 | NetMajor SMALLINT NOT NULL, 30 | NetMinor SMALLINT NOT NULL, 31 | UserAgentMajor SMALLINT NOT NULL, 32 | UserAgentMinor VARCHAR(255) NOT NULL, 33 | CookieEnable SMALLINT NOT NULL, 34 | JavascriptEnable SMALLINT NOT NULL, 35 | IsMobile SMALLINT NOT NULL, 36 | MobilePhone SMALLINT NOT NULL, 37 | MobilePhoneModel TEXT NOT NULL, 38 | Params TEXT NOT NULL, 39 | IPNetworkID INTEGER NOT NULL, 40 | TraficSourceID SMALLINT NOT NULL, 41 | SearchEngineID SMALLINT NOT NULL, 42 | SearchPhrase TEXT NOT NULL, 43 | AdvEngineID SMALLINT NOT NULL, 44 | IsArtifical SMALLINT NOT NULL, 45 | WindowClientWidth SMALLINT NOT NULL, 46 | WindowClientHeight SMALLINT NOT NULL, 47 | ClientTimeZone SMALLINT NOT NULL, 48 | ClientEventTime TIMESTAMP NOT NULL, 49 | SilverlightVersion1 SMALLINT NOT NULL, 50 | SilverlightVersion2 SMALLINT NOT NULL, 51 | SilverlightVersion3 INTEGER NOT NULL, 52 | SilverlightVersion4 SMALLINT NOT NULL, 53 | PageCharset TEXT NOT NULL, 54 | CodeVersion INTEGER NOT NULL, 55 | IsLink SMALLINT NOT NULL, 56 | IsDownload SMALLINT NOT NULL, 57 | IsNotBounce SMALLINT NOT NULL, 58 | FUniqID BIGINT NOT NULL, 59 | OriginalURL TEXT NOT NULL, 60 | HID INTEGER NOT NULL, 61 | IsOldCounter SMALLINT NOT NULL, 62 | IsEvent SMALLINT NOT NULL, 63 | IsParameter SMALLINT NOT NULL, 64 | DontCountHits SMALLINT NOT NULL, 65 | WithHash SMALLINT NOT NULL, 66 | HitColor CHAR NOT NULL, 67 | LocalEventTime TIMESTAMP NOT NULL, 68 | Age SMALLINT NOT NULL, 69 | Sex SMALLINT NOT NULL, 70 | Income SMALLINT NOT NULL, 71 | Interests SMALLINT NOT NULL, 72 | Robotness SMALLINT NOT NULL, 73 | RemoteIP INTEGER NOT NULL, 74 | WindowName INTEGER NOT NULL, 75 | OpenerName INTEGER NOT NULL, 76 | HistoryLength SMALLINT NOT NULL, 77 | BrowserLanguage TEXT NOT NULL, 78 | BrowserCountry TEXT NOT NULL, 79 | SocialNetwork TEXT NOT NULL, 80 | SocialAction TEXT NOT NULL, 81 | HTTPError SMALLINT NOT NULL, 82 | SendTiming INTEGER NOT NULL, 83 | DNSTiming INTEGER NOT NULL, 
84 | ConnectTiming INTEGER NOT NULL, 85 | ResponseStartTiming INTEGER NOT NULL, 86 | ResponseEndTiming INTEGER NOT NULL, 87 | FetchTiming INTEGER NOT NULL, 88 | SocialSourceNetworkID SMALLINT NOT NULL, 89 | SocialSourcePage TEXT NOT NULL, 90 | ParamPrice BIGINT NOT NULL, 91 | ParamOrderID TEXT NOT NULL, 92 | ParamCurrency TEXT NOT NULL, 93 | ParamCurrencyID SMALLINT NOT NULL, 94 | OpenstatServiceName TEXT NOT NULL, 95 | OpenstatCampaignID TEXT NOT NULL, 96 | OpenstatAdID TEXT NOT NULL, 97 | OpenstatSourceID TEXT NOT NULL, 98 | UTMSource TEXT NOT NULL, 99 | UTMMedium TEXT NOT NULL, 100 | UTMCampaign TEXT NOT NULL, 101 | UTMContent TEXT NOT NULL, 102 | UTMTerm TEXT NOT NULL, 103 | FromTag TEXT NOT NULL, 104 | HasGCLID SMALLINT NOT NULL, 105 | RefererHash BIGINT NOT NULL, 106 | URLHash BIGINT NOT NULL, 107 | CLID INTEGER NOT NULL, 108 | PRIMARY KEY (CounterID, EventDate, UserID, EventTime, WatchID) 109 | ); 110 | /* 111 | { 112 | "clickbench_hide": true 113 | } 114 | */ 115 | COPY INTO hits2 FROM 's3://clickhouse-public-datasets/hits_compatible/hits.csv.gz' FILE_FORMAT = (TYPE = CSV, COMPRESSION = GZIP, FIELD_OPTIONALLY_ENCLOSED_BY = '"'); 116 | ALTER SESSION SET USE_CACHED_RESULT = true; 117 | 118 | SELECT COUNT(*) FROM hits2; 119 | SELECT COUNT(*) FROM hits2 WHERE AdvEngineID <> 0; 120 | SELECT SUM(AdvEngineID), COUNT(*), AVG(ResolutionWidth) FROM hits2; 121 | SELECT AVG(UserID) FROM hits2; 122 | SELECT COUNT(DISTINCT UserID) FROM hits2; 123 | SELECT COUNT(DISTINCT SearchPhrase) FROM hits2; 124 | SELECT MIN(EventDate), MAX(EventDate) FROM hits2; 125 | SELECT AdvEngineID, COUNT(*) FROM hits2 WHERE AdvEngineID <> 0 GROUP BY AdvEngineID ORDER BY COUNT(*) DESC; 126 | SELECT RegionID, COUNT(DISTINCT UserID) AS u FROM hits2 GROUP BY RegionID ORDER BY u DESC LIMIT 10; 127 | SELECT RegionID, SUM(AdvEngineID), COUNT(*) AS c, AVG(ResolutionWidth), COUNT(DISTINCT UserID) FROM hits2 GROUP BY RegionID ORDER BY c DESC LIMIT 10; 128 | SELECT MobilePhoneModel, COUNT(DISTINCT UserID) AS u FROM hits2 WHERE MobilePhoneModel <> '' GROUP BY MobilePhoneModel ORDER BY u DESC LIMIT 10; 129 | SELECT MobilePhone, MobilePhoneModel, COUNT(DISTINCT UserID) AS u FROM hits2 WHERE MobilePhoneModel <> '' GROUP BY MobilePhone, MobilePhoneModel ORDER BY u DESC LIMIT 10; 130 | SELECT SearchPhrase, COUNT(*) AS c FROM hits2 WHERE SearchPhrase <> '' GROUP BY SearchPhrase ORDER BY c DESC LIMIT 10; 131 | SELECT SearchPhrase, COUNT(DISTINCT UserID) AS u FROM hits2 WHERE SearchPhrase <> '' GROUP BY SearchPhrase ORDER BY u DESC LIMIT 10; 132 | SELECT SearchEngineID, SearchPhrase, COUNT(*) AS c FROM hits2 WHERE SearchPhrase <> '' GROUP BY SearchEngineID, SearchPhrase ORDER BY c DESC LIMIT 10; 133 | SELECT UserID, COUNT(*) FROM hits2 GROUP BY UserID ORDER BY COUNT(*) DESC LIMIT 10; 134 | SELECT UserID, SearchPhrase, COUNT(*) FROM hits2 GROUP BY UserID, SearchPhrase ORDER BY COUNT(*) DESC LIMIT 10; 135 | SELECT UserID, SearchPhrase, COUNT(*) FROM hits2 GROUP BY UserID, SearchPhrase LIMIT 10; 136 | SELECT UserID, extract(minute FROM EventTime) AS m, SearchPhrase, COUNT(*) FROM hits2 GROUP BY UserID, m, SearchPhrase ORDER BY COUNT(*) DESC LIMIT 10; 137 | SELECT UserID FROM hits2 WHERE UserID = 435090932899640449; 138 | SELECT COUNT(*) FROM hits2 WHERE URL LIKE '%google%'; 139 | SELECT SearchPhrase, MIN(URL), COUNT(*) AS c FROM hits2 WHERE URL LIKE '%google%' AND SearchPhrase <> '' GROUP BY SearchPhrase ORDER BY c DESC LIMIT 10; 140 | SELECT SearchPhrase, MIN(URL), MIN(Title), COUNT(*) AS c, COUNT(DISTINCT UserID) FROM hits2 WHERE Title 
LIKE '%Google%' AND URL NOT LIKE '%.google.%' AND SearchPhrase <> '' GROUP BY SearchPhrase ORDER BY c DESC LIMIT 10; 141 | SELECT * FROM hits2 WHERE URL LIKE '%google%' ORDER BY EventTime LIMIT 10; 142 | SELECT SearchPhrase FROM hits2 WHERE SearchPhrase <> '' ORDER BY EventTime LIMIT 10; 143 | SELECT SearchPhrase FROM hits2 WHERE SearchPhrase <> '' ORDER BY SearchPhrase LIMIT 10; 144 | SELECT SearchPhrase FROM hits2 WHERE SearchPhrase <> '' ORDER BY EventTime, SearchPhrase LIMIT 10; 145 | SELECT CounterID, AVG(length(URL)) AS l, COUNT(*) AS c FROM hits2 WHERE URL <> '' GROUP BY CounterID HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; 146 | SELECT REGEXP_REPLACE(Referer, '^https?://(www\.)?([^/]+)/.*$', '\2') AS k, AVG(length(Referer)) AS l, COUNT(*) AS c, MIN(Referer) FROM hits2 WHERE Referer <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; 147 | SELECT SUM(ResolutionWidth), SUM(ResolutionWidth + 1), SUM(ResolutionWidth + 2), SUM(ResolutionWidth + 3), SUM(ResolutionWidth + 4), SUM(ResolutionWidth + 5), SUM(ResolutionWidth + 6), SUM(ResolutionWidth + 7), SUM(ResolutionWidth + 8), SUM(ResolutionWidth + 9), SUM(ResolutionWidth + 10), SUM(ResolutionWidth + 11), SUM(ResolutionWidth + 12), SUM(ResolutionWidth + 13), SUM(ResolutionWidth + 14), SUM(ResolutionWidth + 15), SUM(ResolutionWidth + 16), SUM(ResolutionWidth + 17), SUM(ResolutionWidth + 18), SUM(ResolutionWidth + 19), SUM(ResolutionWidth + 20), SUM(ResolutionWidth + 21), SUM(ResolutionWidth + 22), SUM(ResolutionWidth + 23), SUM(ResolutionWidth + 24), SUM(ResolutionWidth + 25), SUM(ResolutionWidth + 26), SUM(ResolutionWidth + 27), SUM(ResolutionWidth + 28), SUM(ResolutionWidth + 29), SUM(ResolutionWidth + 30), SUM(ResolutionWidth + 31), SUM(ResolutionWidth + 32), SUM(ResolutionWidth + 33), SUM(ResolutionWidth + 34), SUM(ResolutionWidth + 35), SUM(ResolutionWidth + 36), SUM(ResolutionWidth + 37), SUM(ResolutionWidth + 38), SUM(ResolutionWidth + 39), SUM(ResolutionWidth + 40), SUM(ResolutionWidth + 41), SUM(ResolutionWidth + 42), SUM(ResolutionWidth + 43), SUM(ResolutionWidth + 44), SUM(ResolutionWidth + 45), SUM(ResolutionWidth + 46), SUM(ResolutionWidth + 47), SUM(ResolutionWidth + 48), SUM(ResolutionWidth + 49), SUM(ResolutionWidth + 50), SUM(ResolutionWidth + 51), SUM(ResolutionWidth + 52), SUM(ResolutionWidth + 53), SUM(ResolutionWidth + 54), SUM(ResolutionWidth + 55), SUM(ResolutionWidth + 56), SUM(ResolutionWidth + 57), SUM(ResolutionWidth + 58), SUM(ResolutionWidth + 59), SUM(ResolutionWidth + 60), SUM(ResolutionWidth + 61), SUM(ResolutionWidth + 62), SUM(ResolutionWidth + 63), SUM(ResolutionWidth + 64), SUM(ResolutionWidth + 65), SUM(ResolutionWidth + 66), SUM(ResolutionWidth + 67), SUM(ResolutionWidth + 68), SUM(ResolutionWidth + 69), SUM(ResolutionWidth + 70), SUM(ResolutionWidth + 71), SUM(ResolutionWidth + 72), SUM(ResolutionWidth + 73), SUM(ResolutionWidth + 74), SUM(ResolutionWidth + 75), SUM(ResolutionWidth + 76), SUM(ResolutionWidth + 77), SUM(ResolutionWidth + 78), SUM(ResolutionWidth + 79), SUM(ResolutionWidth + 80), SUM(ResolutionWidth + 81), SUM(ResolutionWidth + 82), SUM(ResolutionWidth + 83), SUM(ResolutionWidth + 84), SUM(ResolutionWidth + 85), SUM(ResolutionWidth + 86), SUM(ResolutionWidth + 87), SUM(ResolutionWidth + 88), SUM(ResolutionWidth + 89) FROM hits2; 148 | SELECT SearchEngineID, ClientIP, COUNT(*) AS c, SUM(IsRefresh), AVG(ResolutionWidth) FROM hits2 WHERE SearchPhrase <> '' GROUP BY SearchEngineID, ClientIP ORDER BY c DESC LIMIT 10; 149 | SELECT WatchID, ClientIP, COUNT(*) AS c, 
SUM(IsRefresh), AVG(ResolutionWidth) FROM hits2 WHERE SearchPhrase <> '' GROUP BY WatchID, ClientIP ORDER BY c DESC LIMIT 10; 150 | SELECT WatchID, ClientIP, COUNT(*) AS c, SUM(IsRefresh), AVG(ResolutionWidth) FROM hits2 GROUP BY WatchID, ClientIP ORDER BY c DESC LIMIT 10; 151 | SELECT URL, COUNT(*) AS c FROM hits2 GROUP BY URL ORDER BY c DESC LIMIT 10; 152 | SELECT 1, URL, COUNT(*) AS c FROM hits2 GROUP BY 1, URL ORDER BY c DESC LIMIT 10; 153 | SELECT ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3, COUNT(*) AS c FROM hits2 GROUP BY ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3 ORDER BY c DESC LIMIT 10; 154 | SELECT URL, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-01' AND EventDate <= '2013-07-31' AND DontCountHits = 0 AND IsRefresh = 0 AND URL <> '' GROUP BY URL ORDER BY PageViews DESC LIMIT 10; 155 | SELECT Title, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-01' AND EventDate <= '2013-07-31' AND DontCountHits = 0 AND IsRefresh = 0 AND Title <> '' GROUP BY Title ORDER BY PageViews DESC LIMIT 10; 156 | SELECT URL, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-01' AND EventDate <= '2013-07-31' AND IsRefresh = 0 AND IsLink <> 0 AND IsDownload = 0 GROUP BY URL ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; 157 | SELECT TraficSourceID, SearchEngineID, AdvEngineID, CASE WHEN (SearchEngineID = 0 AND AdvEngineID = 0) THEN Referer ELSE '' END AS Src, URL AS Dst, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-01' AND EventDate <= '2013-07-31' AND IsRefresh = 0 GROUP BY TraficSourceID, SearchEngineID, AdvEngineID, Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; 158 | SELECT URLHash, EventDate, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-01' AND EventDate <= '2013-07-31' AND IsRefresh = 0 AND TraficSourceID IN (-1, 6) AND RefererHash = 3594120000172545465 GROUP BY URLHash, EventDate ORDER BY PageViews DESC LIMIT 10 OFFSET 100; 159 | SELECT WindowClientWidth, WindowClientHeight, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-01' AND EventDate <= '2013-07-31' AND IsRefresh = 0 AND DontCountHits = 0 AND URLHash = 2868770270353813622 GROUP BY WindowClientWidth, WindowClientHeight ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; 160 | SELECT DATE_TRUNC('minute', EventTime) AS M, COUNT(*) AS PageViews FROM hits2 WHERE CounterID = 62 AND EventDate >= '2013-07-14' AND EventDate <= '2013-07-15' AND IsRefresh = 0 AND DontCountHits = 0 GROUP BY DATE_TRUNC('minute', EventTime) ORDER BY DATE_TRUNC('minute', EventTime) LIMIT 10 OFFSET 1000; -------------------------------------------------------------------------------- /resources/how_to_use.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/how_to_use.png -------------------------------------------------------------------------------- /resources/snowflake/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/snowflake/favicon.ico -------------------------------------------------------------------------------- /resources/snowflake/logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/resources/snowflake/logo.png -------------------------------------------------------------------------------- /resources/snowflake_redshift_usage/1.Download.sh: -------------------------------------------------------------------------------- 1 | aws s3 cp --no-sign-request s3://redshift-downloads/redset/serverless/parts/ redset/serverless/parts/ --recursive 2 | aws s3 cp --no-sign-request s3://redshift-downloads/redset/provisioned/parts/ redset/provisioned/parts/ --recursive 3 | 4 | mkdir -p snowset 5 | wget http://www.cs.cornell.edu/~midhul/snowset/snowset-main.parquet.tar.gz 6 | tar -xzvf snowset-main.parquet.tar.gz 7 | mv snowset-main.parquet snowset -------------------------------------------------------------------------------- /resources/snowflake_redshift_usage/2.IngestCost.py: -------------------------------------------------------------------------------- 1 | import marimo 2 | 3 | __generated_with = "0.1.77" 4 | app = marimo.App() 5 | 6 | 7 | @app.cell 8 | def _(): 9 | import duckdb 10 | 11 | sf_all = duckdb.query("select * from 'snowset/part.*.parquet' using sample 1 percent (bernoulli)").df() 12 | 13 | sf_all 14 | return duckdb, sf_all 15 | 16 | 17 | @app.cell 18 | def _(duckdb): 19 | query = """ 20 | select 21 | 22 | persistentReadBytesS3 > 0 as persistentReadBytesS3, 23 | persistentReadBytesCache > 0 as persistentReadBytesCache, 24 | persistentWriteBytesCache > 0 as persistentWriteBytesCache, 25 | persistentWriteBytesS3 > 0 as persistentWriteBytesS3, 26 | intDataWriteBytesLocalSSD > 0 as intDataWriteBytesLocalSSD, 27 | intDataReadBytesLocalSSD > 0 as intDataReadBytesLocalSSD, 28 | intDataWriteBytesS3 > 0 as intDataWriteBytesS3, 29 | intDataReadBytesS3 > 0 as intDataReadBytesS3, 30 | ioRemoteExternalReadBytes > 0 as ioRemoteExternalReadBytes, 31 | intDataNetReceivedBytes > 0 as intDataNetReceivedBytes, 32 | intDataNetSentBytes > 0 as intDataNetSentBytes, 33 | 34 | producedRows == 0 as producedRows0, 35 | producedRows == 1 as producedRows1, 36 | 37 | returnedRows == 0 as returnedRows0, 38 | returnedRows == 1 as returnedRows1, 39 | 40 | remoteSeqScanFileOps > 0 as remoteSeqScanFileOps, 41 | localSeqScanFileOps > 0 as localSeqScanFileOps, 42 | localWriteFileOps > 0 as localWriteFileOps, 43 | remoteWriteFileOps > 0 as remoteWriteFileOps, 44 | filesCreated > 0 as filesCreated, 45 | profPersistentReadCache > 0 as profPersistentReadCache, 46 | profPersistentWriteCache > 0 as profPersistentWriteCache, 47 | profPersistentReadS3 > 0 as profPersistentReadS3, 48 | profPersistentWriteS3 > 0 as profPersistentWriteS3, 49 | profIntDataReadLocalSSD > 0 as profIntDataReadLocalSSD, 50 | profIntDataWriteLocalSSD > 0 as profIntDataWriteLocalSSD, 51 | profIntDataReadS3 > 0 as profIntDataReadS3, 52 | profIntDataWriteS3 > 0 as profIntDataWriteS3, 53 | profRemoteExtRead > 0 as profRemoteExtRead, 54 | profRemoteExtWrite > 0 as profRemoteExtWrite, 55 | profResWriteS3 > 0 as profResWriteS3, 56 | profFsMeta > 0 as profFsMeta, 57 | profDataExchangeNet > 0 as profDataExchangeNet, 58 | profDataExchangeMsg > 0 as profDataExchangeMsg, 59 | profControlPlaneMsg > 0 as profControlPlaneMsg, 60 | profOs > 0 as profOs, 61 | profMutex > 0 as profMutex, 62 | profSetup > 0 as profSetup, 63 | profSetupMesh > 0 as profSetupMesh, 64 | profTeardown > 0 as profTeardown, 65 | profScanRso > 0 as profScanRso, 66 | profXtScanRso > 0 as profXtScanRso, 67 | profProjRso > 0 as profProjRso, 68 | profSortRso > 0 as profSortRso, 69 | 
profFilterRso > 0 as profFilterRso, 70 | profResRso > 0 as profResRso, 71 | profDmlRso > 0 as profDmlRso, 72 | profHjRso > 0 as profHjRso, 73 | profBufRso > 0 as profBufRso, 74 | profFlatRso > 0 as profFlatRso, 75 | profBloomRso > 0 as profBloomRso, 76 | profAggRso > 0 as profAggRso, 77 | profBandRso > 0 as profBandRso, 78 | 79 | from 'snowset/part.*.parquet' 80 | """ 81 | sf_features = duckdb.query(query).pl() 82 | sf_corr = sf_features.corr() 83 | sf_corr 84 | return query, sf_corr, sf_features 85 | 86 | 87 | @app.cell 88 | def _(sf_corr): 89 | import numpy as np 90 | import scipy.cluster.hierarchy as sch 91 | import seaborn as sns 92 | import matplotlib.pyplot as plt 93 | 94 | # Compute the correlation matrix 95 | corr_matrix = sf_corr.to_numpy() 96 | 97 | # Clean up numerical error 98 | for i in range(len(corr_matrix)): 99 | corr_matrix[i, i] = 1 100 | corr_matrix = (corr_matrix + corr_matrix.T) / 2 101 | 102 | # Create the distance matrix: 1 - |correlation| 103 | dist_matrix = 1 - np.abs(corr_matrix) 104 | 105 | # Perform hierarchical clustering on the symmetric distance matrix 106 | linkage = sch.linkage(sch.distance.squareform(dist_matrix), method='average') 107 | 108 | # Get the order of the features after clustering 109 | ordered_indices = sch.dendrogram(linkage, no_plot=True)['leaves'] 110 | 111 | # Reorder the matrix using the ordered indices 112 | corr_matrix = corr_matrix[np.ix_(ordered_indices, ordered_indices)] 113 | 114 | # Step 4: Plot the reordered correlation matrix 115 | plt.figure(figsize=(10, 8)) 116 | sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', vmin=-1, vmax=1, yticklabels=[sf_corr.columns[i] for i in ordered_indices]) 117 | plt.title('Reordered Correlation Matrix with Hierarchical Clustering') 118 | plt.show() 119 | return ( 120 | corr_matrix, 121 | dist_matrix, 122 | i, 123 | linkage, 124 | np, 125 | ordered_indices, 126 | plt, 127 | sch, 128 | sns, 129 | ) 130 | 131 | 132 | @app.cell 133 | def _(duckdb): 134 | import altair as alt 135 | 136 | sf_costs = duckdb.query( 137 | """ 138 | select 139 | case 140 | when profRemoteExtRead > 0 then 'Ingest' 141 | when profRemoteExtWrite > 0 then 'Export' 142 | when persistentReadBytesS3 + persistentReadBytesCache == 0 and persistentWriteBytesS3 > 0 then 'Ingest' 143 | when persistentReadBytesS3 + persistentReadBytesCache > 0 and persistentWriteBytesS3 > 0 then 'Transformation' 144 | when persistentReadBytesS3 + persistentReadBytesCache > 0 and persistentWriteBytesS3 == 0 then 'Read' 145 | else 'Other' end as "Query Type", 146 | sum(durationExec * warehouseSize) as "Cost", 147 | from 'snowset/part.*.parquet' 148 | group by all 149 | """ 150 | ).df() 151 | sf_costs['Cost'] = sf_costs['Cost'] / sf_costs['Cost'].sum() 152 | 153 | alt.Chart(sf_costs).mark_bar(color='blue').encode( 154 | x='Cost', 155 | y=alt.Y('Query Type', sort=["Ingest", "Transformation", "Read", "Export", "Other"]) 156 | ) 157 | return alt, sf_costs 158 | 159 | 160 | @app.cell 161 | def _(duckdb): 162 | redshift_sample = duckdb.query("select * from 'redset/provisioned/parts/*.parquet' using sample 1 percent (bernoulli)").df() 163 | redshift_sample 164 | return redshift_sample, 165 | 166 | 167 | @app.cell 168 | def _(duckdb): 169 | rs_sum = duckdb.query( 170 | """ 171 | select 172 | query_type, 173 | sum(if(read_table_ids is not null, execution_duration_ms * cluster_size, 0)) as reads, 174 | sum(if(write_table_ids is not null, execution_duration_ms * cluster_size, 0)) as writes, 175 | sum(execution_duration_ms * cluster_size) as cost 176 | from 
'redset/provisioned/parts/*.parquet' 177 | group by all 178 | order by cost desc 179 | """ 180 | ).df() 181 | rs_sum['reads'] = rs_sum['reads'] / rs_sum['cost'].sum() 182 | rs_sum['writes'] = rs_sum['writes'] / rs_sum['cost'].sum() 183 | rs_sum['cost'] = rs_sum['cost'] / rs_sum['cost'].sum() 184 | rs_sum 185 | return rs_sum, 186 | 187 | 188 | @app.cell 189 | def _(alt, duckdb): 190 | redshift_costs = duckdb.query( 191 | """ 192 | select 193 | case when query_type in ('insert', 'copy', 'delete', 'update') then 'Ingest' 194 | when query_type in ('ctas') then 'Transformation' 195 | when query_type in ('select') then 'Read' 196 | when query_type in ('unload') then 'Export' 197 | when query_type in ('analyze', 'vacuum', 'other') then 'Other' 198 | else error(query_type) end as "Query Type", 199 | sum(execution_duration_ms * cluster_size) as "Cost" 200 | from 'redset/provisioned/parts/*.parquet' 201 | where cluster_size is not null 202 | group by all 203 | """ 204 | ).df() 205 | redshift_costs['Cost'] = redshift_costs['Cost'] / redshift_costs['Cost'].sum() 206 | 207 | alt.Chart(redshift_costs).mark_bar(color='red').encode( 208 | x='Cost', 209 | y=alt.Y('Query Type', sort=["Ingest", "Transformation", "Read", "Export", "Other"]) 210 | ) 211 | return redshift_costs, 212 | 213 | 214 | @app.cell 215 | def _(alt, duckdb): 216 | combined_costs = duckdb.query( 217 | """ 218 | select 'Snowflake' as "Warehouse", * 219 | from sf_costs 220 | union all 221 | select 'Redshift' as "Warehouse", * 222 | from redshift_costs 223 | """ 224 | ).df() 225 | 226 | alt.Chart(combined_costs).mark_bar().encode( 227 | color=alt.Color('Warehouse').scale(scheme='paired'), 228 | yOffset='Warehouse', 229 | x='Cost', 230 | y=alt.Y('Query Type', sort=["Ingest", "Transformation", "Read", "Export", "Other"]) 231 | ).properties( 232 | width=200, 233 | height=200, 234 | ) 235 | return combined_costs, 236 | 237 | 238 | @app.cell 239 | def _(duckdb): 240 | 241 | sf_sizes = duckdb.query( 242 | """ 243 | select warehouseSize * perServerCores as vCPU, count(*) as count 244 | from 'snowset/part.*.parquet' 245 | group by all 246 | having vCPU >= 8 247 | order by all""" 248 | ).df() 249 | 250 | alt.Chart(sf_sizes).mark_bar().encode(x='count:Q', y='vCPU:N') 251 | return alt, sf_sizes 252 | 253 | 254 | @app.cell 255 | def _(alt, duckdb): 256 | sf_scanned = duckdb.query( 257 | """ 258 | select floor(log2((persistentReadBytesS3 + persistentReadBytesCache) / 1024 / 1024)) as log_mb_scanned, 2**log_mb_scanned as mb_scanned, count(*) as count 259 | from 'snowset/part.*.parquet' 260 | where persistentWriteBytesS3 = 0 261 | and profRemoteExtRead == 0 262 | and profRemoteExtWrite == 0 263 | and persistentReadBytesS3 + persistentReadBytesCache >= 1024*1024 264 | group by all 265 | order by all 266 | """ 267 | ).df() 268 | sf_scanned['p'] = sf_scanned['count'] / sf_scanned['count'].sum() 269 | 270 | alt.Chart(sf_scanned).mark_bar().encode(x='mb_scanned:O', y='p:Q') 271 | return sf_scanned, 272 | 273 | 274 | @app.cell 275 | def _(alt, duckdb): 276 | rs_scanned = duckdb.query( 277 | """ 278 | select floor(log2(mbytes_scanned)) as log_mb_scanned, 2**log_mb_scanned as mb_scanned, count(*) as count 279 | from 'redset/provisioned/parts/*.parquet' 280 | where query_type = 'select' 281 | and mbytes_scanned >= 1 282 | and num_permanent_tables_accessed > 0 283 | group by all 284 | order by all 285 | """ 286 | ).df() 287 | rs_scanned['p'] = rs_scanned['count'] / rs_scanned['count'].sum() 288 | 289 | alt.Chart(rs_scanned).mark_bar().encode(x='mb_scanned:O', 
y='p:Q') 290 | return rs_scanned, 291 | 292 | 293 | @app.cell 294 | def _(alt, duckdb): 295 | combined_scanned = duckdb.query( 296 | """ 297 | with combined as ( 298 | select 'Snowflake' as warehouse, * from sf_scanned 299 | union all 300 | select 'Redshift' as warehouse, * from rs_scanned 301 | ) 302 | select 303 | * exclude (mb_scanned), 304 | least(mb_scanned, 2**20) as mb_scanned, 305 | case when mb_scanned < 1024 then format('{:d} MB', mb_scanned::int) 306 | when mb_scanned < 1024*1024 then format('{:d} GB', (mb_scanned/1024)::int) 307 | else '1 TB+' end as label 308 | from combined 309 | """ 310 | ).df() 311 | labels = {2**i: f"{2**i} MB" for i in range(0, 10)} | {2**i: f"{2**(i-10)} GB" for i in range(10,20)} | {2**20: f"1 TB+"} 312 | 313 | alt.Chart(combined_scanned).mark_bar().encode( 314 | color=alt.Color('warehouse').scale(scheme='paired').title('System'), 315 | xOffset='warehouse', 316 | x=alt.X('mb_scanned:O').title('Data Scanned').axis(labelExpr='datum.value < 1024 ? datum.value + " MB" : datum.value < 1024*1024 ? datum.value/1024 + " GB" : "1 TB+"'), 317 | y=alt.Y('p:Q').title('Query Fraction') 318 | ).properties( 319 | width=400, 320 | height=200, 321 | ) 322 | return combined_scanned, labels 323 | 324 | 325 | @app.cell 326 | def _(duckdb): 327 | duckdb.query( 328 | """ 329 | select 'Redshift' as warehouse, quantile_cont(mbytes_scanned, [0.5, 0.999]) as q 330 | from 'redset/provisioned/parts/*.parquet' 331 | where query_type = 'select' 332 | and mbytes_scanned >= 1 333 | and num_permanent_tables_accessed > 0 334 | 335 | union all 336 | 337 | select 'Snowflake' as warehouse, quantile_cont((persistentReadBytesS3 + persistentReadBytesCache) / 1024 / 1024, [0.5, 0.999]) as q 338 | from 'snowset/part.*.parquet' 339 | where persistentWriteBytesS3 = 0 340 | and profRemoteExtRead == 0 341 | and profRemoteExtWrite == 0 342 | and persistentReadBytesS3 + persistentReadBytesCache >= 1024*1024 343 | """ 344 | ).df() 345 | return 346 | 347 | 348 | if __name__ == "__main__": 349 | app.run() 350 | -------------------------------------------------------------------------------- /snowflake.aws.lambda.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 2 | 3 | RUN pip install 'poetry==1.8.3' 4 | # Copy everything. (Note: If needed, we can use .dockerignore to limit what's copied.) 5 | COPY . 
${LAMBDA_TASK_ROOT} 6 | RUN poetry config virtualenvs.create false --local && poetry install --no-interaction --no-ansi --no-root --only main --no-cache 7 | 8 | CMD [ "universql.protocol.lambda.snowflake" ] 9 | 10 | 11 | -------------------------------------------------------------------------------- /ssl/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/extract.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import pytest 4 | from snowflake.connector import ProgrammingError 5 | 6 | from tests.integration.utils import execute_query, universql_connection, SIMPLE_QUERY, ALL_COLUMNS_QUERY 7 | 8 | class TestConnectivity: 9 | def test_invalid_auth(self): 10 | # we don't have many tests here because snowflake blocks in case you try invalid pass for 3 times 11 | with universql_connection(password="invalidPass", warehouse=None) as conn: 12 | with pytest.raises(ProgrammingError, match="Incorrect username or password was specified"): 13 | execute_query(conn, "SELECT 1") 14 | 15 | 16 | class TestSelect: 17 | def test_simple_select(self): 18 | with universql_connection(warehouse=None) as conn: 19 | universql_result = execute_query(conn, SIMPLE_QUERY) 20 | assert universql_result.num_rows == 1 21 | 22 | def test_complex_select(self): 23 | with universql_connection(warehouse=None) as conn: 24 | universql_result = execute_query(conn, ALL_COLUMNS_QUERY) 25 | assert universql_result.num_rows == 1 26 | 27 | def test_switch_schema(self): 28 | with universql_connection(warehouse=None) as conn: 29 | execute_query(conn, "USE DATABASE snowflake") 30 | universql_result = execute_query(conn, "SHOW SCHEMAS") 31 | assert universql_result.num_rows > 0, f"The query did not return any rows!" 32 | 33 | execute_query(conn, "USE SCHEMA snowflake.account_usage") 34 | universql_result = execute_query(conn, "SHOW SCHEMAS") 35 | assert universql_result.num_rows > 0, f"The query did not return any rows!" 36 | 37 | execute_query(conn, "USE snowflake") 38 | universql_result = execute_query(conn, "SHOW SCHEMAS") 39 | assert universql_result.num_rows > 0, f"The query did not return any rows!" 40 | 41 | execute_query(conn, "USE snowflake.account_usage") 42 | universql_result = execute_query(conn, "SHOW SCHEMAS") 43 | assert universql_result.num_rows > 0, f"The query did not return any rows!" 
44 | 45 | def test_success_after_failure(self): 46 | with universql_connection(warehouse=None) as conn: 47 | with pytest.raises(ProgrammingError): 48 | execute_query(conn, "select * from not_exists") 49 | result = execute_query(conn, "select 1") 50 | assert result.num_rows == 1 51 | 52 | def test_union(self): 53 | with universql_connection(warehouse=None) as conn: 54 | result = execute_query(conn, "select 1 union all select 2") 55 | assert result.num_rows == 2 56 | 57 | # TODO: add the create stage for landing_stage in the pre_hook 58 | def test_copy_into_for_ryan(self): 59 | with universql_connection(warehouse=None) as conn: 60 | result = execute_query(conn, """ 61 | CREATE OR REPLACE TEMPORARY TABLE DEVICE_METADATA_REF ( 62 | device_id VARCHAR, 63 | device_name VARCHAR, 64 | device_type VARCHAR, 65 | manufacturer VARCHAR, 66 | model_number VARCHAR, 67 | firmware_version VARCHAR, 68 | installation_date DATE, 69 | location_id VARCHAR, 70 | location_name VARCHAR, 71 | facility_zone VARCHAR, 72 | is_active BOOLEAN, 73 | expected_lifetime_months INT, 74 | maintenance_interval_days INT, 75 | last_maintenance_date DATE 76 | ); 77 | 78 | COPY INTO DEVICE_METADATA_REF 79 | FROM @landing_stage/initial_objects/device_metadata.csv 80 | FILE_FORMAT = (SKIP_HEADER = 1); 81 | """) 82 | assert result.num_rows != 0 83 | 84 | # def test_clickbench(self): 85 | # with universql_connection(warehouse=None) as conn: 86 | # result = execute_query(conn, """ 87 | # CREATE TEMP TABLE hits2 AS SELECT 88 | # CAST(WatchID AS BIGINT) AS WatchID, 89 | # CAST(JavaEnable AS SMALLINT) AS JavaEnable, 90 | # CAST(Title AS TEXT) AS Title, 91 | # CAST(GoodEvent AS SMALLINT) AS GoodEvent, 92 | # epoch_ms(EventTime * 1000) AS EventTime, 93 | # DATE '1970-01-01' + INTERVAL (EventDate) DAYS AS EventDate, 94 | # CAST(CounterID AS INTEGER) AS CounterID, 95 | # CAST(ClientIP AS INTEGER) AS ClientIP, 96 | # CAST(RegionID AS INTEGER) AS RegionID, 97 | # CAST(UserID AS BIGINT) AS UserID, 98 | # CAST(CounterClass AS SMALLINT) AS CounterClass, 99 | # CAST(OS AS SMALLINT) AS OS, 100 | # CAST(UserAgent AS SMALLINT) AS UserAgent, 101 | # CAST(URL AS TEXT) AS URL, 102 | # CAST(Referer AS TEXT) AS Referer, 103 | # CAST(IsRefresh AS SMALLINT) AS IsRefresh, 104 | # CAST(RefererCategoryID AS SMALLINT) AS RefererCategoryID, 105 | # CAST(RefererRegionID AS INTEGER) AS RefererRegionID, 106 | # CAST(URLCategoryID AS SMALLINT) AS URLCategoryID, 107 | # CAST(URLRegionID AS INTEGER) AS URLRegionID, 108 | # CAST(ResolutionWidth AS SMALLINT) AS ResolutionWidth, 109 | # CAST(ResolutionHeight AS SMALLINT) AS ResolutionHeight, 110 | # CAST(ResolutionDepth AS SMALLINT) AS ResolutionDepth, 111 | # CAST(FlashMajor AS SMALLINT) AS FlashMajor, 112 | # CAST(FlashMinor AS SMALLINT) AS FlashMinor, 113 | # CAST(FlashMinor2 AS TEXT) AS FlashMinor2, 114 | # CAST(NetMajor AS SMALLINT) AS NetMajor, 115 | # CAST(NetMinor AS SMALLINT) AS NetMinor, 116 | # CAST(UserAgentMajor AS SMALLINT) AS UserAgentMajor, 117 | # CAST(UserAgentMinor AS VARCHAR(255)) AS UserAgentMinor, 118 | # CAST(CookieEnable AS SMALLINT) AS CookieEnable, 119 | # CAST(JavascriptEnable AS SMALLINT) AS JavascriptEnable, 120 | # CAST(IsMobile AS SMALLINT) AS IsMobile, 121 | # CAST(MobilePhone AS SMALLINT) AS MobilePhone, 122 | # CAST(MobilePhoneModel AS TEXT) AS MobilePhoneModel, 123 | # CAST(Params AS TEXT) AS Params, 124 | # CAST(IPNetworkID AS INTEGER) AS IPNetworkID, 125 | # CAST(TraficSourceID AS SMALLINT) AS TraficSourceID, 126 | # CAST(SearchEngineID AS SMALLINT) AS SearchEngineID, 127 | # 
CAST(SearchPhrase AS TEXT) AS SearchPhrase, 128 | # CAST(AdvEngineID AS SMALLINT) AS AdvEngineID, 129 | # CAST(IsArtifical AS SMALLINT) AS IsArtifical, 130 | # CAST(WindowClientWidth AS SMALLINT) AS WindowClientWidth, 131 | # CAST(WindowClientHeight AS SMALLINT) AS WindowClientHeight, 132 | # CAST(ClientTimeZone AS SMALLINT) AS ClientTimeZone, 133 | # epoch_ms(ClientEventTime * 1000) AS ClientEventTime, 134 | # CAST(SilverlightVersion1 AS SMALLINT) AS SilverlightVersion1, 135 | # CAST(SilverlightVersion2 AS SMALLINT) AS SilverlightVersion2, 136 | # CAST(SilverlightVersion3 AS INTEGER) AS SilverlightVersion3, 137 | # CAST(SilverlightVersion4 AS SMALLINT) AS SilverlightVersion4, 138 | # CAST(PageCharset AS TEXT) AS PageCharset, 139 | # CAST(CodeVersion AS INTEGER) AS CodeVersion, 140 | # CAST(IsLink AS SMALLINT) AS IsLink, 141 | # CAST(IsDownload AS SMALLINT) AS IsDownload, 142 | # CAST(IsNotBounce AS SMALLINT) AS IsNotBounce, 143 | # CAST(FUniqID AS BIGINT) AS FUniqID, 144 | # CAST(OriginalURL AS TEXT) AS OriginalURL, 145 | # CAST(HID AS INTEGER) AS HID, 146 | # CAST(IsOldCounter AS SMALLINT) AS IsOldCounter, 147 | # CAST(IsEvent AS SMALLINT) AS IsEvent, 148 | # CAST(IsParameter AS SMALLINT) AS IsParameter, 149 | # CAST(DontCountHits AS SMALLINT) AS DontCountHits, 150 | # CAST(WithHash AS SMALLINT) AS WithHash, 151 | # CAST(HitColor AS CHAR) AS HitColor, 152 | # epoch_ms(LocalEventTime * 1000) AS LocalEventTime, 153 | # CAST(Age AS SMALLINT) AS Age, 154 | # CAST(Sex AS SMALLINT) AS Sex, 155 | # CAST(Income AS SMALLINT) AS Income, 156 | # CAST(Interests AS SMALLINT) AS Interests, 157 | # CAST(Robotness AS SMALLINT) AS Robotness, 158 | # CAST(RemoteIP AS INTEGER) AS RemoteIP, 159 | # CAST(WindowName AS INTEGER) AS WindowName, 160 | # CAST(OpenerName AS INTEGER) AS OpenerName, 161 | # CAST(HistoryLength AS SMALLINT) AS HistoryLength, 162 | # CAST(BrowserLanguage AS TEXT) AS BrowserLanguage, 163 | # CAST(BrowserCountry AS TEXT) AS BrowserCountry, 164 | # CAST(SocialNetwork AS TEXT) AS SocialNetwork, 165 | # CAST(SocialAction AS TEXT) AS SocialAction, 166 | # CAST(HTTPError AS SMALLINT) AS HTTPError, 167 | # CAST(SendTiming AS INTEGER) AS SendTiming, 168 | # CAST(DNSTiming AS INTEGER) AS DNSTiming, 169 | # CAST(ConnectTiming AS INTEGER) AS ConnectTiming, 170 | # CAST(ResponseStartTiming AS INTEGER) AS ResponseStartTiming, 171 | # CAST(ResponseEndTiming AS INTEGER) AS ResponseEndTiming, 172 | # CAST(FetchTiming AS INTEGER) AS FetchTiming, 173 | # CAST(SocialSourceNetworkID AS SMALLINT) AS SocialSourceNetworkID, 174 | # CAST(SocialSourcePage AS TEXT) AS SocialSourcePage, 175 | # CAST(ParamPrice AS BIGINT) AS ParamPrice, 176 | # CAST(ParamOrderID AS TEXT) AS ParamOrderID, 177 | # CAST(ParamCurrency AS TEXT) AS ParamCurrency, 178 | # CAST(ParamCurrencyID AS SMALLINT) AS ParamCurrencyID, 179 | # CAST(OpenstatServiceName AS TEXT) AS OpenstatServiceName, 180 | # CAST(OpenstatCampaignID AS TEXT) AS OpenstatCampaignID, 181 | # CAST(OpenstatAdID AS TEXT) AS OpenstatAdID, 182 | # CAST(OpenstatSourceID AS TEXT) AS OpenstatSourceID, 183 | # CAST(UTMSource AS TEXT) AS UTMSource, 184 | # CAST(UTMMedium AS TEXT) AS UTMMedium, 185 | # CAST(UTMCampaign AS TEXT) AS UTMCampaign, 186 | # CAST(UTMContent AS TEXT) AS UTMContent, 187 | # CAST(UTMTerm AS TEXT) AS UTMTerm, 188 | # CAST(FromTag AS TEXT) AS FromTag, 189 | # CAST(HasGCLID AS SMALLINT) AS HasGCLID, 190 | # CAST(RefererHash AS BIGINT) AS RefererHash, 191 | # CAST(URLHash AS BIGINT) AS URLHash, 192 | # CAST(CLID AS INTEGER) AS CLID 193 | # FROM 
read_parquet('s3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.*') limit 10; 194 | # -- COPY hits2 FROM 's3://clickhouse-public-datasets/hits_compatible/athena_partitioned/*' (FORMAT PARQUET) 195 | # -- COPY INTO test.public.hits2 FROM 's3://clickhouse-public-datasets/hits_compatible/hits.csv.gz' FILE_FORMAT = (TYPE = CSV, COMPRESSION = GZIP, FIELD_OPTIONALLY_ENCLOSED_BY = '"') 196 | # """) 197 | # 198 | # result = execute_query(conn, "select count(*) from hits2") 199 | # assert result.num_rows == 1 -------------------------------------------------------------------------------- /tests/integration/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from snowflake.connector import ProgrammingError 5 | 6 | from tests.integration.utils import execute_query, universql_connection, SIMPLE_QUERY 7 | 8 | 9 | class TestCreate: 10 | def test_create_iceberg_table(self): 11 | external_volume = os.getenv("PYTEST_EXTERNAL_VOLUME") 12 | if external_volume is None: 13 | pytest.skip("No external volume provided, set PYTEST_EXTERNAL_VOLUME") 14 | 15 | with universql_connection(warehouse=None) as conn: 16 | execute_query(conn, f""" 17 | CREATE OR REPLACE ICEBERG TABLE test_iceberg_table 18 | external_volume = {external_volume} 19 | catalog = 'SNOWFLAKE' 20 | BASE_LOCATION = 'test_iceberg_table' 21 | AS {SIMPLE_QUERY} 22 | """) 23 | universql_result = execute_query(conn, f"SELECT * FROM test_iceberg_table LIMIT 1") 24 | assert universql_result.num_rows == 1 25 | 26 | def test_create_temp_table(self): 27 | with universql_connection(warehouse=None) as conn: 28 | execute_query(conn, f"CREATE TEMP TABLE test_temp_table AS {SIMPLE_QUERY}") 29 | universql_result = execute_query(conn, "SELECT * FROM test_temp_table") 30 | assert universql_result.num_rows == 1 31 | 32 | # we can potentially run {SIMPLE_QUERY} on DuckDB and then CREATE TABLE on Snowflake with PyArrow (data upload, no processing) 33 | # but we need a mechanism to analyze the query and make sure it's WORTH running query locally as we need a running warehouse anyways 34 | def test_create_native_table(self): 35 | with universql_connection(warehouse=None) as conn: 36 | with pytest.raises(ProgrammingError, match="DuckDB can't create native Snowflake tables"): 37 | execute_query(conn, f"CREATE TABLE test_native_table AS {SIMPLE_QUERY}") -------------------------------------------------------------------------------- /tests/integration/object_identifiers.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import pytest 4 | 5 | from tests.integration.utils import execute_query, universql_connection 6 | import os 7 | 8 | 9 | def generate_name_variants(name): 10 | lowercase = name.lower() 11 | uppercase = name.upper() 12 | mixed_case = name.capitalize() 13 | in_quotes = '"' + name.upper() + '"' 14 | return [lowercase, uppercase, mixed_case, in_quotes] 15 | 16 | def generate_select_statement_combos(sets_of_identifiers, connected_db=None, connected_schema=None): 17 | select_statements = [] 18 | for set in sets_of_identifiers: 19 | set_of_select_statements = [] 20 | database = set.get("database") 21 | schema = set.get("schema") 22 | table = set.get("table") 23 | if table is not None: 24 | table_variants = generate_name_variants(table) 25 | if database == connected_db and schema == connected_schema: 26 | for table_variant in table_variants: 27 | set_of_select_statements.append(f"SELECT * 
FROM {table_variant}") 28 | else: 29 | raise Exception("No table name provided for a select statement combo.") 30 | 31 | if schema is not None: 32 | schema_variants = generate_name_variants(schema) 33 | if database == connected_db: 34 | object_name_combos = product(schema_variants, table_variants) 35 | for schema_name, table_name in object_name_combos: 36 | set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}") 37 | elif database is not None: 38 | raise Exception("You must provide a schema name if you provide a database name.") 39 | 40 | if database is not None: 41 | database_variants = generate_name_variants(database) 42 | object_name_combos = product(database_variants, schema_variants, table_variants) 43 | for db_name, schema_name, table_name in object_name_combos: 44 | set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}") 45 | select_statements = select_statements + set_of_select_statements 46 | 47 | return select_statements 48 | 49 | class TestObjectIdentifiers: 50 | def test_querying_in_connected_db_and_schema(self): 51 | external_volume = os.getenv("PYTEST_EXTERNAL_VOLUME") 52 | if external_volume is None: 53 | pytest.skip("No external volume provided, set PYTEST_EXTERNAL_VOLUME") 54 | 55 | connected_db = "universql1" 56 | connected_schema = "same_schema" 57 | 58 | combos = [ 59 | { 60 | "database": "universql1", 61 | "schema": "same_schema", 62 | "table": "dim_devices" 63 | }, 64 | { 65 | "database": "universql1", 66 | "schema": "different_schema", 67 | "table": "different_dim_devices" 68 | }, 69 | { 70 | "database": "universql2", 71 | "schema": "another_schema", 72 | "table": "another_dim_devices" 73 | }, 74 | ] 75 | 76 | select_statements = generate_select_statement_combos(combos, connected_db, connected_schema) 77 | successful_queries = [] 78 | failed_queries = [] 79 | with universql_connection(database=connected_db, schema=connected_schema) as conn: 80 | execute_query(conn, f""" 81 | CREATE DATABASE IF NOT EXISTS universql1; 82 | CREATE DATABASE IF NOT EXISTS universql2; 83 | CREATE SCHEMA IF NOT EXISTS universql1.same_schema; 84 | CREATE SCHEMA IF NOT EXISTS universql1.different_schema; 85 | CREATE SCHEMA IF NOT EXISTS universql2.another_schema; 86 | 87 | CREATE ICEBERG TABLE IF NOT EXISTS universql1.same_schema.dim_devices("1" int) 88 | external_volume = {external_volume} 89 | catalog = 'SNOWFLAKE' 90 | BASE_LOCATION = 'universql1.same_schema.dim_devices' 91 | AS select 1; 92 | 93 | CREATE ICEBERG TABLE IF NOT EXISTS universql1.different_schema.different_dim_devices("1" int) 94 | external_volume = {external_volume} 95 | catalog = 'SNOWFLAKE' 96 | BASE_LOCATION = 'universql1.different_schema.different_dim_devices' 97 | AS select 1; 98 | 99 | CREATE ICEBERG TABLE IF NOT EXISTS universql2.another_schema.another_dim_devices("1" int) 100 | external_volume = {external_volume} 101 | catalog = 'SNOWFLAKE' 102 | BASE_LOCATION = 'universql2.another_schema.another_dim_devices' 103 | AS select 1; 104 | """) 105 | 106 | for query in select_statements: 107 | try: 108 | execute_query(conn, query) 109 | successful_queries.append(query) 110 | continue 111 | except Exception as e: 112 | failed_queries.append(f"{query} | FAILED - {str(e)}") 113 | if len(failed_queries) > 0: 114 | error_message = f"The following {len(failed_queries)} queries failed:" 115 | for query in failed_queries: 116 | error_message = f"{error_message}\n{query}" 117 | pytest.fail(error_message) 
-------------------------------------------------------------------------------- /tests/integration/transform.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.integration.utils import execute_query, universql_connection, SIMPLE_QUERY 4 | 5 | 6 | class TestTransform: 7 | @pytest.mark.skip(reason="not implemented") 8 | def test_insert_table(self): 9 | pass 10 | 11 | @pytest.mark.skip(reason="not implemented") 12 | def test_delete_table(self): 13 | pass 14 | 15 | @pytest.mark.skip(reason="not implemented") 16 | def test_update_table(self): 17 | pass 18 | 19 | @pytest.mark.skip(reason="not implemented") 20 | def test_merge_table(self): 21 | pass 22 | 23 | @pytest.mark.skip(reason="not implemented") 24 | def test_drop_table(self): 25 | pass 26 | -------------------------------------------------------------------------------- /tests/integration/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socketserver 4 | import threading 5 | from contextlib import contextmanager 6 | from typing import Generator, Optional 7 | 8 | import pyarrow 9 | import pytest 10 | from click.testing import CliRunner 11 | from dotenv import load_dotenv 12 | from snowflake.connector import connect as snowflake_connect, SnowflakeConnection 13 | from snowflake.connector.config_manager import CONFIG_MANAGER 14 | from snowflake.connector.constants import CONNECTIONS_FILE 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | load_dotenv() 19 | 20 | from universql.util import LOCALHOSTCOMPUTING_COM 21 | 22 | # Configuration using separate connection strings for direct and proxy connections 23 | # export SNOWFLAKE_CONNECTION_STRING="account=xxx;user=xxx;password=xxx;warehouse=xxx;database=xxx;schema=xxx" 24 | # export UNIVERSQL_CONNECTION_STRING="warehouse=xxx" 25 | SNOWFLAKE_CONNECTION_NAME = os.getenv("SNOWFLAKE_CONNECTION_NAME") or "default" 26 | logging.getLogger("snowflake.connector").setLevel(logging.INFO) 27 | 28 | # Allow Universql to start 29 | os.environ["MAX_CON_RETRY_ATTEMPTS"] = "15" 30 | 31 | SIMPLE_QUERY = """ 32 | SELECT 1 as test 33 | """ 34 | 35 | ALL_COLUMNS_QUERY = """ 36 | SELECT 37 | -- Numeric data types 38 | 12345678901234567890123456789012345678::NUMBER AS sample_number, 39 | 123.45::DECIMAL AS sample_decimal, 40 | 6789::INT AS sample_int, 41 | 9876543210::BIGINT AS sample_bigint, 42 | 123::SMALLINT AS sample_smallint, 43 | 42::TINYINT AS sample_tinyint, 44 | 255::BYTEINT AS sample_byteint, 45 | 12345.6789::FLOAT AS sample_float, 46 | 123456789.123456789::DOUBLE AS sample_double, 47 | 48 | -- String & binary data types 49 | 'Sample text'::VARCHAR AS sample_varchar, 50 | 'C'::CHAR AS sample_char, 51 | 'Another sample text'::STRING AS sample_string, 52 | 'More text'::TEXT AS sample_text, 53 | cast('307834' as binary) AS sample_binary, 54 | cast('307834' as varbinary) AS sample_varbinary, 55 | 56 | -- Logical data types 57 | TRUE::BOOLEAN AS sample_boolean, 58 | 59 | -- Date & time data types 60 | '2023-01-01'::DATE AS sample_date, 61 | -- '12:34:56'::TIME AS sample_time, # somehow python is broken but java sdk works 62 | 63 | '2023-01-01 10:34:56'::DATETIME AS sample_datetime, 64 | '2023-01-01 11:34:56'::TIMESTAMP AS sample_timestamp, 65 | -- no support for duckdb 66 | '2023-01-01 12:34:56'::TIMESTAMP_LTZ AS sample_timestamp_ltz, 67 | '2023-01-01 13:34:56'::TIMESTAMP_NTZ AS sample_timestamp_ntz, 68 | 69 | -- no support for snowflake + duckdb 70 | 
'2024-08-03 22:51:25.595+01'::TIMESTAMP_TZ AS sample_timestamp_tz, 71 | 72 | -- Semi-structured data types 73 | PARSE_JSON('{"key":"value"}')::VARIANT AS sample_variant, 74 | OBJECT_CONSTRUCT('foo', 1234567, 'distinct_province', (SELECT 1)) AS sample_object, 75 | ARRAY_CONSTRUCT(1, 2, 3, 4) AS sample_array, 76 | 77 | -- no support for 78 | -- Geospatial data types 79 | -- TO_GEOGRAPHY('LINESTRING(30 10, 10 30, 40 40)') AS sample_geometry, 80 | 81 | -- no support for 82 | -- Vector data types 83 | -- [1.1,2.2,3]::VECTOR(FLOAT,3) AS sample_vector 84 | """ 85 | 86 | server_cache = {} 87 | 88 | 89 | @contextmanager 90 | def snowflake_connection(**properties) -> Generator: 91 | print(f"Reading {CONNECTIONS_FILE} with {properties}") 92 | snowflake_connection_name = _set_connection_name(properties) 93 | conn = snowflake_connect(connection_name=snowflake_connection_name, **properties) 94 | try: 95 | yield conn 96 | finally: 97 | conn.close() 98 | 99 | 100 | @contextmanager 101 | def universql_connection(**properties) -> SnowflakeConnection: 102 | # https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#connecting-using-the-connections-toml-file 103 | print(f"Reading {CONNECTIONS_FILE} with {properties}") 104 | connections = CONFIG_MANAGER["connections"] 105 | snowflake_connection_name = _set_connection_name(properties) 106 | if snowflake_connection_name not in connections: 107 | raise pytest.fail(f"Snowflake connection '{snowflake_connection_name}' not found in config") 108 | connection = connections[snowflake_connection_name] 109 | account = connection.get('account') 110 | if account in server_cache: 111 | uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": server_cache[account]} | properties 112 | print(f"Reusing existing server running on port {server_cache[account]} for account {account}") 113 | else: 114 | from universql.main import snowflake 115 | with socketserver.TCPServer(("127.0.0.1", 0), None) as s: 116 | free_port = s.server_address[1] 117 | 118 | def start_universql(): 119 | runner = CliRunner() 120 | invoke = runner.invoke(snowflake, 121 | [ 122 | '--account', account, 123 | '--port', free_port, '--catalog', 'snowflake', 124 | # AWS_DEFAULT_PROFILE env can be used to pass AWS profile 125 | ], 126 | catch_exceptions=False 127 | ) 128 | if invoke.exit_code != 0: 129 | raise Exception("Unable to start Universql") 130 | 131 | print(f"Starting running on port {free_port} for account {account}") 132 | thread = threading.Thread(target=start_universql) 133 | thread.daemon = True 134 | thread.start() 135 | server_cache[account] = free_port 136 | uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties 137 | 138 | connect = None 139 | try: 140 | print(snowflake_connection_name, uni_string) 141 | connect = snowflake_connect(connection_name=snowflake_connection_name, **uni_string) 142 | yield connect 143 | finally: 144 | if connect is not None: 145 | connect.close() 146 | 147 | 148 | def execute_query(conn, query: str) -> pyarrow.Table: 149 | cur = conn.cursor() 150 | try: 151 | cur.execute(query) 152 | return cur.fetch_arrow_all() 153 | finally: 154 | cur.close() 155 | 156 | 157 | def compare_results(snowflake_result: pyarrow.Table, universql_result: pyarrow.Table): 158 | # Compare schemas 159 | if snowflake_result.schema != universql_result.schema: 160 | schema_diff = [] 161 | for field1, field2 in zip(snowflake_result.schema, universql_result.schema): 162 | if field1.name != field2.name and field1.type != field2.type: 163 | 
schema_diff.append(f"Expected field {field1}, but got {field2}") 164 | if len(snowflake_result.schema) != len(universql_result.schema): 165 | schema_diff.append(f"Schema lengths differ: " 166 | f"Snowflake={len(snowflake_result.schema)} " 167 | f"Universql={len(universql_result.schema)}") 168 | if len(schema_diff) > 0: 169 | raise pytest.fail("Schema mismatch:\n" + "\n".join(schema_diff)) 170 | 171 | # Compare row counts 172 | if snowflake_result.num_rows != universql_result.num_rows: 173 | raise pytest.fail(f"Row count mismatch: Snowflake={snowflake_result.num_rows} " 174 | f"Universql={universql_result.num_rows}") 175 | 176 | # Compare data row by row and column by column 177 | data_diff = [] 178 | for row_index in range(snowflake_result.num_rows): 179 | for col_index in range(snowflake_result.num_columns): 180 | value1 = snowflake_result.column(col_index)[row_index].as_py() 181 | value2 = universql_result.column(col_index)[row_index].as_py() 182 | if value1 != value2: 183 | data_diff.append(f"Row {row_index}, Column {col_index}: " 184 | f"Snowflake={value1}, Universql={value2}") 185 | 186 | if data_diff: 187 | raise pytest.fail("Data mismatch:\n" + "\n".join(data_diff)) 188 | 189 | print("Results match perfectly!") 190 | 191 | 192 | def _set_connection_name(connection_dict : Optional[dict]): 193 | if connection_dict is None: 194 | return None 195 | snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME) 196 | logger.info(f"Using the {snowflake_connection_name} connection") 197 | return snowflake_connection_name 198 | -------------------------------------------------------------------------------- /tests/plugins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/tests/plugins/__init__.py -------------------------------------------------------------------------------- /tests/scratch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/tests/scratch/__init__.py -------------------------------------------------------------------------------- /tests/scratch/cdk_tests.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import docker 3 | import base64 4 | import json 5 | 6 | 7 | # Initialize boto3 ECR client 8 | ecr = boto3.client('ecr') 9 | repository_uri = event['ResourceProperties']['RepositoryUri'] 10 | public_image_uri = event['ResourceProperties']['PublicImageUri'] 11 | region = context.invoked_function_arn.split(':')[3] 12 | registry = repository_uri.split('/')[0] 13 | 14 | # Get ECR authorization token 15 | auth_response = ecr.get_authorization_token() 16 | token = auth_response['authorizationData'][0]['authorizationToken'] 17 | username, password = base64.b64decode(token).decode().split(':') 18 | 19 | # Initialize Docker client 20 | docker_client = docker.from_env() 21 | 22 | # Login to ECR 23 | docker_client.login( 24 | username=username, 25 | password=password, 26 | registry=registry 27 | ) 28 | 29 | # Pull the public image 30 | print(f"Pulling image: {public_image_uri}") 31 | docker_client.images.pull(public_image_uri) 32 | 33 | # Tag the image for our repository 34 | source_image = docker_client.images.get(public_image_uri) 35 | source_image.tag(repository_uri, tag='latest') 36 | 37 | # Push to private repository 38 
| print(f"Pushing image to: {repository_uri}") 39 | for line in docker_client.images.push(repository_uri, tag='latest', stream=True, decode=True): 40 | print(json.dumps(line)) 41 | 42 | # Get the image digest 43 | describe_images = ecr.describe_images( 44 | repositoryName=event['ResourceProperties']['RepositoryName'], 45 | imageIds=[{'imageTag': 'latest'}] 46 | ) 47 | image_digest = describe_images['imageDetails'][0]['imageDigest'] 48 | full_uri = f"{repository_uri}@{image_digest}" -------------------------------------------------------------------------------- /tests/scratch/chdb_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import chdb 3 | from chdb.udf import chdb_udf 4 | 5 | 6 | @chdb_udf(return_type="Int32") 7 | def sumn(n): 8 | n = int(n) 9 | return n*(n+1)//2 10 | 11 | 12 | ret = chdb.query( 13 | """ 14 | CREATE TABLE iceberg_table ENGINE=Iceberg(iceberg_conf, filename = 'config') 15 | 16 | """, 17 | path="./clickhouse", 18 | udf_path="./clickhouse/udf", 19 | ) 20 | print(ret) 21 | -------------------------------------------------------------------------------- /tests/scratch/dbt_tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dbt.cli.main import dbtRunner 4 | 5 | os.chdir("/Users/bkabak/Code/jinjat/snowflake_admin") 6 | dbt = dbtRunner() 7 | 8 | cli_args = ["show", "--inline", "select 1"] 9 | dbt.invoke(cli_args) -------------------------------------------------------------------------------- /tests/scratch/duck_tests.py: -------------------------------------------------------------------------------- 1 | from string import Template 2 | 3 | import duckdb 4 | 5 | 6 | connect = duckdb.connect(":memory:") 7 | connect.execute("ATTACH 'md:'") 8 | connect.execute("select * from MD_ALL_DATABASES()") 9 | 10 | connect = duckdb.connect(":memory:") 11 | connect.execute("ATTACH 'md:'") 12 | fetchdf = connect.sql("SHOW DATABASES").fetchdf() 13 | print(fetchdf) -------------------------------------------------------------------------------- /tests/scratch/jinja_test.py: -------------------------------------------------------------------------------- 1 | from jinja2 import Environment, DictLoader 2 | 3 | 4 | # Mock transformation functions 5 | def sample_data(size): 6 | return {'name': 'sample_data', 'args': {'size': size}} 7 | 8 | 9 | def duckdb(machine): 10 | return {'name': 'duckdb', 'args': {'machine': machine}} 11 | 12 | 13 | def default_create_table_as_iceberg(external_volume, catalog, base_location): 14 | return { 15 | 'name': 'default_create_table_as_iceberg', 16 | 'args': { 17 | 'external_volume': external_volume, 18 | 'catalog': catalog, 19 | 'base_location': base_location, 20 | } 21 | } 22 | 23 | 24 | def snowflake(warehouse): 25 | return {'name': 'snowflake', 'args': {'warehouse': warehouse}} 26 | 27 | 28 | # Jinja2 environment with mock functions 29 | env = Environment() 30 | 31 | # Add mock functions to the environment 32 | env.globals['sample_data'] = sample_data 33 | env.globals['duckdb'] = duckdb 34 | env.globals['default_create_table_as_iceberg'] = default_create_table_as_iceberg 35 | env.globals['snowflake'] = snowflake 36 | 37 | 38 | # Parsing function to extract transformations 39 | def parse_jinja_template(template_str): 40 | # Remove the surrounding {{ and }} 41 | template_str = template_str.strip("{}").strip() 42 | 43 | # Create the template in Jinja2 44 | template = env.from_string(template_str) 45 | 46 | # Render the template and capture the 
transformations 47 | result = template.render() 48 | 49 | return result 50 | 51 | 52 | # Example usage: 53 | input_str = "{{ transpile.sample_data('1000 rows') | execute.duckdb(machine='local'), transpile.default_create_table_as_iceberg(external_volume='iceberg_jinjat', catalog='snowflake', base_location='') | snowflake(warehouse='COMPUTE_WH') }}" 54 | parsed_transforms = parse_jinja_template(input_str) 55 | print(parsed_transforms) -------------------------------------------------------------------------------- /tests/scratch/pyiceberg_tests.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Union, Optional, Set, List 3 | 4 | import duckdb 5 | import pyarrow 6 | import sqlglot 7 | from duckdb.duckdb import NotImplementedException 8 | from pyiceberg.catalog import WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary, MetastoreCatalog 9 | from pyiceberg.catalog.sql import SqlCatalog 10 | from pyiceberg.exceptions import NoSuchNamespaceError, TableAlreadyExistsError, NoSuchIcebergTableError 11 | from pyiceberg.io import PY_IO_IMPL, load_file_io 12 | from pyiceberg.partitioning import PartitionSpec, UNPARTITIONED_PARTITION_SPEC 13 | from pyiceberg.schema import Schema 14 | from pyiceberg.table import SortOrder, UNSORTED_SORT_ORDER, CommitTableRequest, CommitTableResponse, Table, \ 15 | CreateTableTransaction, StaticTable 16 | from pyiceberg.table.metadata import new_table_metadata 17 | from pyiceberg.typedef import Identifier, Properties, EMPTY_DICT 18 | from sqlglot.expressions import Column 19 | from sqlglot.optimizer import build_scope 20 | 21 | from universql.lake.cloud import CACHE_DIRECTORY_KEY, s3 22 | from universql.util import QueryError, load_computes 23 | 24 | computes = load_computes() 25 | # def get_iceberg_table_from_data_lake(metadata_file_path: str, cache_directory): 26 | # from_metadata = StaticTable.from_metadata(metadata_file_path, { 27 | # PY_IO_IMPL: "universql.lake.cloud.iceberg", 28 | # CACHE_DIRECTORY_KEY: cache_directory, 29 | # }) 30 | # return from_metadata 31 | # test = get_iceberg_table_from_data_lake( 32 | # "gs://my-iceberg-data/custom-events/customer_iceberg/metadata/v1719882827064000000.metadata.json", '') 33 | # to_arrow = test.scan().to_arrow() 34 | # test = test.append(to_arrow) 35 | 36 | # catalog = load_catalog( 37 | # name="polaris_aws" 38 | # ) 39 | # catalog.properties[PY_IO_IMPL] = "universql.lake.cloud.iceberg" 40 | # 41 | # class HeadlessCatalog(NoopCatalog): 42 | # def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse: 43 | # return CommitTableResponse( 44 | # metadata=table_request, metadata_location=updated_staged_table.metadata_location 45 | # ) 46 | 47 | 48 | connect = duckdb.connect(":memory:") 49 | connect.register_filesystem(s3({"cache_directory": "~/.universql/cache"})) 50 | 51 | # db_catalog = DuckDBIcebergCatalog(connect, **{"namespace": None}) 52 | # db_catalog.create_namespace("MY_CUSTOM_APP.PUBLIC", {}) 53 | # db_catalog.register_table("memory.main.table", 54 | # "s3://universql-us-east-1/glue_tables6/ICEBERG_TESTS/PUBLIC/ttt/metadata/00001-48648dfd-9355-4808-ab2f-c9065c6ef691.metadata.json") 55 | # catalog_load_table = db_catalog.load_table("memory.main.table") 56 | 57 | # db_catalog.create_table((None, "memory.main.newtable"), schema=pyarrow.schema([])) 58 | # 59 | # db_catalog.create_namespace("MY_CUSTOM_APP", {}) 60 | 61 | sql_catalog = SqlCatalog("ducky", **{ 62 | PY_IO_IMPL: "universql.lake.cloud.iceberg", 63 | CACHE_DIRECTORY_KEY: 
'./', 64 | "uri": "duckdb:///:memory:", 65 | "echo": "true" 66 | }) 67 | # sql_catalog.create_namespace("MY_CUSTOM_APP.PUBLIC", {}) 68 | create_table = sql_catalog.create_table('MY_CUSTOM_APP.PUBLIC.testd', schema=pyarrow.schema([])) 69 | 70 | load_table = sql_catalog.load_table("public.taxi_dataset") 71 | 72 | arrow = table.scan().to_arrow() 73 | table.current_snapshot() 74 | 75 | # df = pq.read_table("/tmp/yellow_tripdata_2023-01.parquet") 76 | 77 | # table = catalog.create_table( 78 | # "public.taxi_dataset", 79 | # schema=df.schema, 80 | # ) 81 | 82 | # table.append(df) 83 | # len(table.scan().to_arrow()) 84 | 85 | # df = df.append_column("tip_per_mile", pc.divide(df["tip_amount"], df["trip_distance"])) 86 | # with table.update_schema() as update_schema: 87 | # update_schema.union_by_name(df.schema) 88 | 89 | # table.overwrite(df) 90 | # print(table.scan().to_arrow()) 91 | 92 | df = table.scan(row_filter="tip_per_mile > 0").to_arrow() 93 | len(df) 94 | -------------------------------------------------------------------------------- /tests/scratch/sqlglot_tests.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import abstractmethod 3 | 4 | import pyarrow as pa 5 | 6 | import duckdb 7 | import sqlglot 8 | from fastapi.openapi.models import Response 9 | from google.cloud import bigquery 10 | from google.cloud.bigquery import QueryJobConfig 11 | from sqlglot import Expression 12 | from sqlglot.expressions import TransientProperty, TemporaryProperty, Properties, IcebergProperty 13 | from starlette.requests import Request 14 | 15 | from universql.plugin import UniversqlPlugin, Executor 16 | from universql.warehouse.duckdb import DuckDBCatalog, DuckDBExecutor 17 | from universql.warehouse.snowflake import SnowflakeCatalog 18 | 19 | # SELECT ascii(t.$1), ascii(t.$2) FROM 's3://fullpath' (file_format_for_duckdb => myformat) t; 20 | 21 | one = sqlglot.parse_one("SELECT ascii(t.$1), ascii(t.$2) FROM @mystage1 (file_format => myformat) t;", read="snowflake") 22 | two = sqlglot.parse_one("""COPY INTO stg_device_metadata 23 | FROM @iceberg_db.public.landing_stage/initial_objects/ 24 | --FILES = ('device_metadata.csv', 'file2.csv') 25 | FILE_FORMAT = (TYPE = CSV SKIP_HEADER = 1);""", read="snowflake") 26 | 27 | 28 | class FixTimestampTypes(UniversqlPlugin): 29 | 30 | def transform_sql(self, ast, target_executor: Executor): 31 | def fix_timestamp_types(expression): 32 | if isinstance(target_executor, DuckDBExecutor) and isinstance(expression, sqlglot.exp.DataType): 33 | if expression.this.value in ["TIMESTAMPLTZ", "TIMESTAMPTZ"]: 34 | return sqlglot.exp.DataType.build("TIMESTAMPTZ") 35 | if expression.this.value in ["VARIANT"]: 36 | return sqlglot.exp.DataType.build("JSON") 37 | 38 | return ast.transform(fix_timestamp_types) 39 | 40 | 41 | class RewriteCreateAsIceberg(UniversqlPlugin): 42 | 43 | def transform_sql(self, expression: Expression, target_executor: Executor) -> Expression: 44 | prefix = self.catalog.iceberg_catalog.properties.get("location") 45 | 46 | if isinstance(expression, sqlglot.exp.Create): 47 | if expression.kind == 'TABLE': 48 | properties = expression.args.get('properties') or Properties() 49 | is_transient = TransientProperty() in properties.expressions 50 | is_temp = TemporaryProperty() in properties.expressions 51 | is_iceberg = IcebergProperty() in properties.expressions 52 | if is_transient or len(properties.expressions) == 0: 53 | properties__set = Properties() 54 | external_volume = 
Property(this=Var(this='EXTERNAL_VOLUME'), 55 | value=Literal.string( 56 | self.catalog.context.get('snowflake_iceberg_volume'))) 57 | snowflake_catalog = self.catalog.iceberg_catalog or "snowflake" 58 | catalog = Property(this=Var(this='CATALOG'), value=Literal.string(snowflake_catalog)) 59 | if snowflake_catalog == 'snowflake': 60 | base_location = Property(this=Var(this='BASE_LOCATION'), 61 | value=Literal.string(location.metadata_location[len(prefix):])) 62 | elif snowflake_catalog == 'glue': 63 | base_location = Property(this=Var(this='CATALOG_TABLE_NAME'), 64 | value=Literal.string(expression.this.sql())) 65 | create_table_props = [IcebergProperty(), external_volume, catalog, base_location] 66 | properties__set.set('expressions', create_table_props) 67 | 68 | metadata_query = expression.expression.sql(dialect="snowflake") 69 | try: 70 | self.catalog.cursor().describe(metadata_query) 71 | except Exception as e: 72 | logger.error(f"Unable fetching schema for metadata query {e.args} \n" + metadata_query) 73 | return expression 74 | columns = [(column.name, FIELD_TYPES[column.type_code]) for column in 75 | self.catalog.cursor().description] 76 | unsupported_columns = [(column[0], column[1].name) for column in columns if column[1].name not in ( 77 | 'BOOLEAN', 'TIME', 'BINARY', 'TIMESTAMP_TZ', 'TIMESTAMP_NTZ', 'TIMESTAMP_LTZ', 'TIMESTAMP', 78 | 'DATE', 'FIXED', 79 | 'TEXT', 'REAL')] 80 | if len(unsupported_columns) > 0: 81 | logger.error( 82 | f"Unsupported columns {unsupported_columns} in {expression.expression.sql(dialect='snowflake')}") 83 | return expression 84 | 85 | column_definitions = [ColumnDef( 86 | this=sqlglot.exp.parse_identifier(column[0]), 87 | kind=DataType.build(self._convert_snowflake_to_iceberg_type(column[1]), dialect="snowflake")) 88 | for 89 | column in 90 | columns] 91 | schema = Schema() 92 | schema.set('this', expression.this) 93 | schema.set('expressions', column_definitions) 94 | expression.set('this', schema) 95 | select = Select().from_(Subquery(this=expression.expression)) 96 | for column in columns: 97 | col_ast = Column(this=parse_identifier(column[0])) 98 | if column[1].name in ('ARRAY', 'OBJECT'): 99 | alias = Alias(this=Anonymous(this="to_variant", expressions=[col_ast]), 100 | alias=parse_identifier(column[0])) 101 | select = select.select(alias) 102 | else: 103 | select = select.select(col_ast) 104 | 105 | expression.set('expression', select) 106 | expression.set('properties', properties__set) 107 | return expression 108 | 109 | 110 | # one = sqlglot.parse_one("create table if not exists test as select 1", read="snowflake") 111 | one = sqlglot.parse_one("select * from @test", read="snowflake") 112 | # one = sqlglot.parse_one("select * from 's3://test'", read="snowflake") 113 | # one = sqlglot.parse_one("select to_variant(test) as test from (select 1)", read="snowflake") 114 | # one = sqlglot.parse_one("create table test as select to_variant(test) as test from (select 1)", read="snowflake") -------------------------------------------------------------------------------- /tests/scratch/system_tray.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import sys 3 | 4 | from PIL import Image 5 | from pystray import Icon, Menu, MenuItem 6 | 7 | import webview 8 | 9 | if sys.platform == 'darwin': 10 | ctx = multiprocessing.get_context('spawn') 11 | Process = ctx.Process 12 | Queue = ctx.Queue 13 | else: 14 | Process = multiprocessing.Process 15 | Queue = multiprocessing.Queue 16 | 17 | 
webview_process = None 18 | 19 | def run_webview(): 20 | window = webview.create_window('Webview', 'https://pywebview.flowrl.com/hello') 21 | webview.start() 22 | 23 | if __name__ == '__main__': 24 | 25 | def start_webview_process(): 26 | global webview_process 27 | webview_process = Process(target=run_webview) 28 | webview_process.start() 29 | 30 | def on_open(icon, item): 31 | global webview_process 32 | if not webview_process.is_alive(): 33 | start_webview_process() 34 | 35 | def on_exit(icon, item): 36 | icon.stop() 37 | 38 | start_webview_process() 39 | 40 | image = Image.open('/Users/bkabak/Documents/My Tableau Repository/Shapes/Ratings/0.png') 41 | menu = Menu(MenuItem('Open', on_open), MenuItem('Exit', on_exit)) 42 | icon = Icon('Pystray', image, menu=menu) 43 | icon.run() 44 | 45 | webview_process.terminate() -------------------------------------------------------------------------------- /tests/scratch/textual_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterable 3 | 4 | from textual.app import App, ComposeResult 5 | from textual.widgets import DirectoryTree 6 | 7 | 8 | class FilteredDirectoryTree(DirectoryTree): 9 | def filter_paths(self, paths: Iterable[Path]) -> Iterable[Path]: 10 | return [path for path in paths if not path.name.startswith(".")] 11 | 12 | 13 | class DirectoryTreeApp(App): 14 | def compose(self) -> ComposeResult: 15 | yield FilteredDirectoryTree(Path.home() / ".universql") 16 | 17 | 18 | if __name__ == "__main__": 19 | app = DirectoryTreeApp() 20 | app.run() 21 | 22 | -------------------------------------------------------------------------------- /tests/sql_optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | SQL Optimizer App 3 | 4 | This Streamlit app allows users to optimize SQL queries using predefined optimization rules. 5 | Users can select optimization rules, choose to remove common table expressions (CTEs), and 6 | optionally format the optimized query using sqlfmt. The original and optimized queries are 7 | displayed side by side for comparison. Additionally, users can view the source code on GitHub. 8 | 9 | Author: [Your Name] 10 | 11 | """ 12 | from typing import Callable, Dict, Sequence 13 | 14 | from sqlfmt.api import Mode, format_string 15 | from sqlglot import parse_one 16 | from sqlglot.expressions import Select 17 | from sqlglot.optimizer import RULES, optimize 18 | import streamlit as st 19 | from streamlit_ace import st_ace 20 | 21 | RULE_MAPPING: Dict[str, Callable] = {rule.__name__: rule for rule in RULES} 22 | SAMPLE_QUERY: str = """WITH users AS ( 23 | SELECT * 24 | FROM users_table), 25 | orders AS ( 26 | SELECT * 27 | FROM orders_table), 28 | combined AS ( 29 | SELECT users.id, users.name, orders.order_id, orders.total 30 | FROM users 31 | JOIN orders ON users.id = orders.user_id) 32 | SELECT combined.id, combined.name, combined.order_id, combined.total 33 | FROM combined 34 | """ 35 | 36 | 37 | def _generate_ast(query: str) -> Select: 38 | """ 39 | Generate an AST from a query. 40 | """ 41 | ast = parse_one(query) 42 | return ast 43 | 44 | 45 | def apply_optimizations( 46 | query: str, rules: Sequence[Callable] = RULES, remove_ctes: bool = False 47 | ) -> Select: 48 | """ 49 | Apply optimizations to an AST. 
50 | """ 51 | ast = _generate_ast(query) 52 | if remove_ctes: 53 | return optimize(ast, rules=rules) 54 | else: 55 | return optimize(ast, rules=rules, leave_tables_isolated=True) 56 | 57 | 58 | def format_sql_with_sqlfmt(query: str) -> str: 59 | """ 60 | Format a query using sqlfmt. 61 | """ 62 | mode = Mode() 63 | return format_string(query, mode) 64 | 65 | # Set custom Streamlit page configuration 66 | st.set_page_config( 67 | page_title="SQL Optimizer", 68 | page_icon=":bar_chart:", 69 | layout="wide", 70 | initial_sidebar_state="expanded", 71 | ) 72 | 73 | # Hide Streamlit default menu and footer 74 | hide_st_style = """ 75 | 80 | """ 81 | st.markdown(hide_st_style, unsafe_allow_html=True) 82 | 83 | # Custom CSS styling 84 | st.markdown( 85 | """ 86 | 153 | """, 154 | unsafe_allow_html=True, 155 | ) 156 | 157 | # Title container 158 | st.markdown( 159 | """ 160 |
161 | SQL Optimizer
162 |
163 | """, 164 | unsafe_allow_html=True, 165 | ) 166 | 167 | # Rule selector 168 | selected_rules = st.multiselect( 169 | "Optimization rules:", 170 | list(RULE_MAPPING.keys()), 171 | default=list(RULE_MAPPING.keys()), 172 | key="rules_multiselect", 173 | ) 174 | 175 | # Checkboxes 176 | cols = st.columns(2) 177 | remove_ctes = cols[0].checkbox( 178 | "Remove CTEs", on_change=None, key="remove_ctes_checkbox" 179 | ) 180 | format_with_sqlfmt = cols[1].checkbox( 181 | "Lint with sqlfmt", on_change=None, key="format_with_sqlfmt_checkbox" 182 | ) 183 | 184 | # Initialize session state 185 | if "new_query" not in st.session_state: 186 | st.session_state.new_query = "" 187 | if "state" not in st.session_state: 188 | st.session_state.state = 0 189 | 190 | 191 | # Input editor 192 | def _generate_editor_widget(value: str, **kwargs) -> str: 193 | return st_ace( 194 | value=value, 195 | height=300, 196 | theme="twilight", 197 | language="sql", 198 | font_size=16, 199 | wrap=True, 200 | auto_update=True, 201 | **kwargs, 202 | ) 203 | 204 | 205 | left, right = st.columns(2) 206 | 207 | with left: 208 | sql_input = _generate_editor_widget(SAMPLE_QUERY, key="input_editor") 209 | 210 | # Optimize and lint query 211 | if st.button("Optimize SQL", key="optimize_button"): 212 | try: 213 | rules = [RULE_MAPPING[rule] for rule in selected_rules] 214 | new_query = apply_optimizations(sql_input, rules, remove_ctes).sql(pretty=True) 215 | if format_with_sqlfmt: 216 | new_query = format_sql_with_sqlfmt(new_query) 217 | st.session_state.new_query = new_query 218 | st.session_state.state += 1 219 | st.success("SQL query optimized successfully!") 220 | 221 | except Exception as e: 222 | st.error(f"Error: {e}") 223 | 224 | # CSS for the button 225 | css = """ 226 | 240 | """ 241 | 242 | # Add the HTML button for optimization with CSS 243 | st.markdown(css, unsafe_allow_html=True) 244 | 245 | # Output editor 246 | with right: 247 | _generate_editor_widget( 248 | st.session_state.new_query, readonly=True, key=f"ace-{st.session_state.state}" 249 | ) 250 | 251 | # Include Font Awesome CSS 252 | st.markdown( 253 | """ 254 | 255 | """, 256 | unsafe_allow_html=True, 257 | ) 258 | 259 | # GitHub link 260 | st.markdown( 261 | """ 262 | 267 | """, 268 | unsafe_allow_html=True, 269 | ) -------------------------------------------------------------------------------- /universql.code-workspace: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [ 3 | { 4 | "path": "." 
5 | } 6 | ], 7 | "settings": {} 8 | } -------------------------------------------------------------------------------- /universql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/universql/__init__.py -------------------------------------------------------------------------------- /universql/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/universql/agent/__init__.py -------------------------------------------------------------------------------- /universql/agent/cloudflared.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import logging 3 | import sys 4 | 5 | import click 6 | import requests 7 | import subprocess 8 | import tarfile 9 | import tempfile 10 | import shutil 11 | import os 12 | import platform 13 | import time 14 | import re 15 | from pathlib import Path 16 | 17 | CLOUDFLARED_CONFIG = { 18 | ('Windows', 'AMD64'): { 19 | 'command': 'cloudflared-windows-amd64.exe', 20 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-windows-amd64.exe' 21 | }, 22 | ('Windows', 'x86'): { 23 | 'command': 'cloudflared-windows-386.exe', 24 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-windows-386.exe' 25 | }, 26 | ('Linux', 'x86_64'): { 27 | 'command': 'cloudflared-linux-amd64', 28 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64' 29 | }, 30 | ('Linux', 'i386'): { 31 | 'command': 'cloudflared-linux-386', 32 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-386' 33 | }, 34 | ('Linux', 'arm'): { 35 | 'command': 'cloudflared-linux-arm', 36 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-arm' 37 | }, 38 | ('Linux', 'arm64'): { 39 | 'command': 'cloudflared-linux-arm64', 40 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-arm64' 41 | }, 42 | ('Linux', 'aarch64'): { 43 | 'command': 'cloudflared-linux-arm64', 44 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-arm64' 45 | }, 46 | ('Darwin', 'x86_64'): { 47 | 'command': 'cloudflared', 48 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-darwin-amd64.tgz' 49 | }, 50 | ('Darwin', 'arm64'): { 51 | 'command': 'cloudflared', 52 | 'url': 'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-darwin-amd64.tgz' 53 | } 54 | } 55 | 56 | 57 | def _get_command(system, machine): 58 | try: 59 | return CLOUDFLARED_CONFIG[(system, machine)]['command'] 60 | except KeyError: 61 | raise Exception(f"{machine} is not supported on {system}") 62 | 63 | 64 | def _get_url(system, machine): 65 | try: 66 | return CLOUDFLARED_CONFIG[(system, machine)]['url'] 67 | except KeyError: 68 | raise Exception(f"{machine} is not supported on {system}") 69 | 70 | 71 | # Needed for the darwin package 72 | def _extract_tarball(tar_path, filename): 73 | tar = tarfile.open(tar_path + '/' + filename, 'r') 74 | for item in tar: 75 | tar.extract(item, tar_path) 76 | if item.name.find(".tgz") != -1 or item.name.find(".tar") != -1: 77 | tar.extract(item.name, "./" + 
item.name[:item.name.rfind('/')]) 78 | 79 | 80 | def _download_cloudflared(cloudflared_path, command): 81 | system, machine = platform.system(), platform.machine() 82 | if Path(cloudflared_path, command).exists(): 83 | executable = (cloudflared_path + '/' + 'cloudflared') if ( 84 | system == "Darwin" and machine in ["x86_64", "arm64"]) else (cloudflared_path + '/' + command) 85 | update_cloudflared = subprocess.Popen([executable, 'update'], stdout=subprocess.DEVNULL, 86 | stderr=subprocess.STDOUT) 87 | return 88 | print(f" * Downloading cloudflared for {system} {machine}...") 89 | url = _get_url(system, machine) 90 | _download_file(url) 91 | 92 | 93 | def _download_file(url): 94 | local_filename = url.split('/')[-1] 95 | r = requests.get(url, stream=True) 96 | download_path = str(Path(tempfile.gettempdir(), local_filename)) 97 | with open(download_path, 'wb') as f: 98 | shutil.copyfileobj(r.raw, f) 99 | return download_path 100 | 101 | 102 | def start_cloudflared(port, metrics_port, tunnel_id=None, config_path=None): 103 | system, machine = platform.system(), platform.machine() 104 | command = _get_command(system, machine) 105 | tempdir_for_app = tempfile.gettempdir() 106 | cloudflared_path = str(Path(tempdir_for_app)) 107 | if system == "Darwin": 108 | _download_cloudflared(cloudflared_path, "cloudflared-darwin-amd64.tgz") 109 | _extract_tarball(cloudflared_path, "cloudflared-darwin-amd64.tgz") 110 | else: 111 | _download_cloudflared(cloudflared_path, command) 112 | 113 | executable = str(Path(cloudflared_path, command)) 114 | os.chmod(executable, 0o777) 115 | 116 | cloudflared_command = [executable, 'tunnel', '--metrics', f'127.0.0.1:{metrics_port}', '--logfile', 117 | str(Path(tempdir_for_app, 'cloudflared.log'))] 118 | if config_path: 119 | cloudflared_command += ['--config', config_path, 'run'] 120 | elif tunnel_id: 121 | cloudflared_command += ['--url', f'http://127.0.0.1:{port}', 'run', tunnel_id] 122 | else: 123 | cloudflared_command += ['--url', f'http://127.0.0.1:{port}'] 124 | 125 | if system == "Darwin" and machine == "arm64": 126 | x___cloudflared_command = ['arch', '-x86_64'] + cloudflared_command 127 | logging.debug('Running command: '+' '.join(x___cloudflared_command)) 128 | cloudflared = subprocess.Popen(x___cloudflared_command, stdout=subprocess.PIPE, 129 | stderr=subprocess.STDOUT) 130 | else: 131 | cloudflared = subprocess.Popen(cloudflared_command, stdout=subprocess.PIPE, 132 | stderr=subprocess.STDOUT) 133 | atexit.register(cloudflared.terminate) 134 | 135 | 136 | def get_cloudflare_url(metrics_port, tunnel_id=None, config_path=None): 137 | localhost_url = f"http://127.0.0.1:{metrics_port}/metrics" 138 | 139 | for i in range(10): 140 | try: 141 | metrics = requests.get(localhost_url).text 142 | if tunnel_id or config_path: 143 | # If tunnel_id or config_path is provided, we check for cloudflared_tunnel_ha_connections, as no tunnel URL is available in the metrics 144 | if re.search("cloudflared_tunnel_ha_connections\s\d", metrics): 145 | # No tunnel URL is available in the metrics, so we return a generic text 146 | tunnel_url = "preconfigured tunnel URL" 147 | break 148 | else: 149 | # If neither tunnel_id nor config_path is provided, we check for the tunnel URL in the metrics 150 | tunnel_url = (re.search("https?:\/\/(?P[^\s]+.trycloudflare.com)", metrics).group("url")) 151 | break 152 | except: 153 | click.secho(f"Waiting for cloudflared to generate the tunnel URL... 
{i * 3}s", fg="yellow") 154 | time.sleep(3) 155 | else: 156 | click.secho( 157 | f"Can't connect to cloudflared tunnel, check logs at {str(Path(tempfile.gettempdir(), 'cloudflared.log'))} and restart the server") 158 | sys.exit(0) 159 | 160 | return tunnel_url 161 | -------------------------------------------------------------------------------- /universql/catalog/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/universql/catalog/__init__.py -------------------------------------------------------------------------------- /universql/catalog/iceberg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import sqlglot 5 | from pyarrow import Table 6 | from pyiceberg.catalog import load_catalog 7 | from pyiceberg.exceptions import NoSuchTableError, OAuthError 8 | from pyiceberg.io import PY_IO_IMPL 9 | 10 | from universql.plugin import Locations, ICatalog 11 | from universql.protocol.session import UniverSQLSession 12 | from universql.lake.cloud import CACHE_DIRECTORY_KEY 13 | from universql.util import SnowflakeError 14 | 15 | logger = logging.getLogger("🧊") 16 | 17 | 18 | class PolarisCatalog(ICatalog): 19 | 20 | def register_locations(self, tables: Locations): 21 | raise SnowflakeError(self.session_id, "Polaris doesn't support direct execution") 22 | 23 | def __init__(self, session : UniverSQLSession, compute: dict): 24 | super().__init__(session) 25 | current_database = session.credentials.get('database') 26 | if current_database is None: 27 | raise SnowflakeError(session.session_id, "No database/catalog provided, unable to connect to Polaris catalog") 28 | iceberg_rest_credentials = { 29 | "uri": f"https://{session.context.get('account')}.snowflakecomputing.com/polaris/api/catalog", 30 | "credential": f"{session.credentials.get('user')}:{session.credentials.get('password')}", 31 | "warehouse": current_database, 32 | "scope": "PRINCIPAL_ROLE:ALL" 33 | } 34 | try: 35 | self.rest_catalog = load_catalog(None, **iceberg_rest_credentials) 36 | except OAuthError as e: 37 | raise SnowflakeError(self.session_id, e.args[0]) 38 | self.rest_catalog.properties[CACHE_DIRECTORY_KEY] = session.context.get('cache_directory') 39 | self.rest_catalog.properties[PY_IO_IMPL] = "universql.lake.cloud.iceberg" 40 | 41 | def get_table_paths(self, tables: List[sqlglot.exp.Table]): 42 | return {table: self._get_table(table) for table in tables} 43 | 44 | def _get_table(self, table: sqlglot.exp.Table) -> Table: 45 | table_ref = table.sql(dialect="snowflake") 46 | try: 47 | iceberg_table = self.rest_catalog.load_table(table_ref) 48 | except NoSuchTableError: 49 | error = f"Table {table_ref} doesn't exist in Polaris catalog `{self.credentials.get('database')}` or your role doesn't have access to the table." 
50 | logger.error(error) 51 | raise SnowflakeError(self.session_id, error) 52 | return iceberg_table.scan().to_arrow() -------------------------------------------------------------------------------- /universql/lake/cloud.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import aiobotocore 4 | import gcsfs 5 | import s3fs 6 | from fsspec.core import logger 7 | from fsspec.utils import setup_logging 8 | 9 | from universql.lake.fsspec_util import MonitoredSimpleCacheFileSystem 10 | 11 | in_lambda = os.environ.get('AWS_EXECUTION_ENV') is not None 12 | 13 | 14 | def s3(context: dict): 15 | cache_storage = context.get('cache_directory') 16 | session = aiobotocore.session.AioSession(profile=context.get("aws_profile")) 17 | s3_file_system = s3fs.S3FileSystem(session=session) 18 | if context.get("max_cache_size", "0") != "0": 19 | s3_file_system = MonitoredSimpleCacheFileSystem( 20 | fs=s3_file_system, 21 | cache_storage=cache_storage, 22 | ) 23 | 24 | return s3_file_system 25 | 26 | 27 | def gcs(context: dict): 28 | cache_storage = context.get('cache_directory') 29 | setup_logging(logger=logger, level="ERROR") 30 | gcs_file_system = gcsfs.GCSFileSystem(project=context.get('gcp_project')) 31 | if context.get("max_cache_size", "0") != "0": 32 | gcs_file_system = MonitoredSimpleCacheFileSystem( 33 | fs=gcs_file_system, 34 | cache_storage=cache_storage, 35 | ) 36 | return gcs_file_system 37 | 38 | 39 | CACHE_DIRECTORY_KEY = "universql.cache_directory" 40 | MAX_CACHE_SIZE = "universql.max_cache_size" 41 | 42 | 43 | def iceberg(context: dict): 44 | from pyiceberg.io.fsspec import FsspecFileIO 45 | io = FsspecFileIO(context) 46 | directory = context.get(CACHE_DIRECTORY_KEY) 47 | max_cache_size = context.get(MAX_CACHE_SIZE) 48 | get_fs = io.get_fs 49 | if max_cache_size is not None and max_cache_size != '0': 50 | io.get_fs = lambda name: MonitoredSimpleCacheFileSystem( 51 | fs=get_fs(name), 52 | cache_storage=directory, 53 | ) 54 | return io 55 | -------------------------------------------------------------------------------- /universql/lake/fsspec_util.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | import os 4 | import re 5 | import shutil 6 | from datetime import timedelta, datetime 7 | from functools import wraps 8 | 9 | import psutil 10 | from fsspec.implementations.cache_mapper import AbstractCacheMapper 11 | from fsspec.implementations.cached import SimpleCacheFileSystem 12 | 13 | from universql.util import get_total_directory_size 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger("data_lake") 17 | 18 | 19 | class FileNameCacheMapper(AbstractCacheMapper): 20 | def __init__(self, directory): 21 | self.directory = directory 22 | 23 | def __call__(self, path: str) -> str: 24 | os.makedirs(os.path.dirname(os.path.join(self.directory, path)), exist_ok=True) 25 | return path 26 | 27 | 28 | class throttle(object): 29 | """ 30 | Decorator that prevents a function from being called more than once every 31 | time period. 
32 | To create a function that cannot be called more than once a minute: 33 | @throttle(minutes=1) 34 | def my_fun(): 35 | pass 36 | """ 37 | 38 | def __init__(self, seconds=0, minutes=0, hours=0): 39 | self.throttle_period = timedelta( 40 | seconds=seconds, minutes=minutes, hours=hours 41 | ) 42 | self.time_of_last_call = datetime.min 43 | 44 | def __call__(self, fn): 45 | @wraps(fn) 46 | def wrapper(*args, **kwargs): 47 | now = datetime.now() 48 | time_since_last_call = now - self.time_of_last_call 49 | 50 | if time_since_last_call > self.throttle_period: 51 | self.time_of_last_call = now 52 | return fn(*args, **kwargs) 53 | 54 | return wrapper 55 | 56 | 57 | def sizeof_fmt(num, suffix="B"): 58 | for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): 59 | if abs(num) < 1024.0: 60 | return f"{num:3.1f}{unit}{suffix}" 61 | num /= 1024.0 62 | return f"{num:.1f}Yi{suffix}" 63 | 64 | 65 | last_free = None 66 | first_free = None 67 | 68 | 69 | def get_friendly_disk_usage(storage: str, debug=False) -> str: 70 | global last_free 71 | global first_free 72 | if not os.path.exists(storage): 73 | return '' 74 | usage = psutil.disk_usage(storage) 75 | if first_free is None: 76 | first_free = usage.free 77 | current_usage = get_total_directory_size(storage) 78 | message = f"{sizeof_fmt(current_usage)} used {sizeof_fmt(usage.free)} available" 79 | if last_free is not None: 80 | downloaded_recently = last_free - usage.free 81 | if downloaded_recently > 1_000_000 or debug: 82 | downloaded_since_start = first_free - usage.free 83 | message += f" downloaded since start: {sizeof_fmt(downloaded_since_start)}" 84 | 85 | last_free = usage.free 86 | return message 87 | 88 | # iceberg files don't ever change, we can cache them locally to speed up the queries. 89 | ICEBERG_FILE_REGEX = re.compile('(?i)^(?:data/|metadata/)?(?:(?=.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}).*\.parquet|.*\.avro|.*metadata\.json)$') 90 | 91 | class MonitoredSimpleCacheFileSystem(SimpleCacheFileSystem): 92 | 93 | def __init__(self, **kwargs): 94 | kwargs["cache_storage"] = os.path.join(kwargs.get("cache_storage"), kwargs.get('fs').protocol[0]) 95 | super().__init__(**kwargs) 96 | self._mapper = FileNameCacheMapper(kwargs.get('cache_storage')) 97 | 98 | def _check_file(self, path): 99 | # self._check_cache() 100 | cache_path = self._mapper(path) 101 | for storage in self.storage: 102 | fn = os.path.join(storage, cache_path) 103 | if os.path.exists(fn): 104 | return fn 105 | logger.info(f"Downloading {self.protocol[0]}://{path}") 106 | 107 | # def glob(self, path): 108 | # return [self._strip_protocol(path)] 109 | 110 | # def get_file(self, path, lpath, **kwargs): 111 | # """ 112 | # Overridden method to manage the local caching process manually. 113 | # Downloads the remote file to `lpath + '.tmp'` and then renames it to `lpath`. 
114 | # """ 115 | # 116 | # # If the final file already exists and we are not forcing re-download, skip 117 | # if os.path.exists(lpath): 118 | # return 119 | # 120 | # tmp_path = lpath + ".tmp" 121 | # 122 | # # In case a previous failed download left a stale tmp file 123 | # if os.path.exists(tmp_path): 124 | # os.remove(tmp_path) 125 | # 126 | # # Ensure the target directory for lpath exists 127 | # os.makedirs(os.path.dirname(lpath), exist_ok=True) 128 | # 129 | # # Open the remote file and download to the temporary local file 130 | # with self.fs.open(path, 'rb') as source, open(tmp_path, 'wb') as target: 131 | # shutil.copyfileobj(source, target) 132 | # 133 | # # Atomically move the temporary file to the final location 134 | # os.rename(tmp_path, lpath) 135 | 136 | def size(self, path): 137 | cached_file = self._check_file(self._strip_protocol(path)) 138 | if cached_file is None: 139 | return self.fs.size(path) 140 | else: 141 | return os.path.getsize(cached_file) 142 | 143 | def open(self, path, mode="rb", **kwargs): 144 | """ 145 | Open a file. If the file's path does not match the cache regex, bypass the 146 | caching and read directly from the underlying filesystem. 147 | """ 148 | # if not ICEBERG_FILE_REGEX.search(path): 149 | # # bypass caching. 150 | # return self.fs.open(path, mode=mode, **kwargs) 151 | 152 | return super().open(path, mode=mode, **kwargs) 153 | 154 | def __getattribute__(self, item): 155 | if item in { 156 | # new items 157 | "size", 158 | "glob", 159 | # previous 160 | "load_cache", 161 | "_open", 162 | "save_cache", 163 | "close_and_update", 164 | "__init__", 165 | "__getattribute__", 166 | "__reduce__", 167 | "_make_local_details", 168 | "open", 169 | "cat", 170 | "cat_file", 171 | "cat_ranges", 172 | "get", 173 | "read_block", 174 | "tail", 175 | "head", 176 | "info", 177 | "ls", 178 | "exists", 179 | "isfile", 180 | "isdir", 181 | "_check_file", 182 | "_check_cache", 183 | "_mkcache", 184 | "clear_cache", 185 | "clear_expired_cache", 186 | "pop_from_cache", 187 | "local_file", 188 | "_paths_from_path", 189 | "get_mapper", 190 | "open_many", 191 | "commit_many", 192 | "hash_name", 193 | "__hash__", 194 | "__eq__", 195 | "to_json", 196 | "to_dict", 197 | "cache_size", 198 | "pipe_file", 199 | "pipe", 200 | "start_transaction", 201 | "end_transaction", 202 | }: 203 | # all the methods defined in this class. 
Note `open` here, since 204 | # it calls `_open`, but is actually in superclass 205 | return lambda *args, **kw: getattr(type(self), item).__get__(self)( 206 | *args, **kw 207 | ) 208 | if item in ["__reduce_ex__"]: 209 | raise AttributeError 210 | if item in ["transaction"]: 211 | # property 212 | return type(self).transaction.__get__(self) 213 | if item in ["_cache", "transaction_type"]: 214 | # class attributes 215 | return getattr(type(self), item) 216 | if item == "__class__": 217 | return type(self) 218 | d = object.__getattribute__(self, "__dict__") 219 | fs = d.get("fs", None) # fs is not immediately defined 220 | if item in d: 221 | return d[item] 222 | elif fs is not None: 223 | if item in fs.__dict__: 224 | # attribute of instance 225 | return fs.__dict__[item] 226 | # attributed belonging to the target filesystem 227 | cls = type(fs) 228 | m = getattr(cls, item) 229 | if (inspect.isfunction(m) or inspect.isdatadescriptor(m)) and ( 230 | not hasattr(m, "__self__") or m.__self__ is None 231 | ): 232 | # instance method 233 | return m.__get__(fs, cls) 234 | return m # class method or attribute 235 | else: 236 | # attributes of the superclass, while target is being set up 237 | return super().__getattribute__(item) -------------------------------------------------------------------------------- /universql/main.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | import os 4 | import shutil 5 | import socket 6 | import sys 7 | import tempfile 8 | import typing 9 | from pathlib import Path 10 | 11 | import click 12 | import requests 13 | import uvicorn 14 | from requests import RequestException 15 | 16 | from universql.agent.cloudflared import start_cloudflared 17 | from universql.util import LOCALHOST_UNIVERSQL_COM_BYTES, Catalog, LOCALHOSTCOMPUTING_COM, \ 18 | DEFAULTS, initialize_context 19 | 20 | logger = logging.getLogger("🏠") 21 | 22 | 23 | @click.group(context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}) 24 | @click.version_option(version="0.1") 25 | def cli(): 26 | pass 27 | 28 | 29 | @cli.command( 30 | epilog='[BETA] Check out docs at https://github.com/buremba/universql and let me know if you have any cool use-case on Github!') 31 | @click.option('--account', 32 | help='The account to use. Supports both Snowflake and Polaris (ex: rt21601.europe-west2.gcp)', 33 | prompt='Account ID (example: rt21601.europe-west2.gcp)', envvar='SNOWFLAKE_ACCOUNT') 34 | @click.option('--port', help='Port for Snowflake proxy server (default: 8084)', default=8084, envvar='SERVER_PORT', 35 | type=int) 36 | @click.option('--host', help='Host for Snowflake proxy server (default: localhostcomputing.com)', 37 | default=LOCALHOSTCOMPUTING_COM, 38 | envvar='SERVER_HOST', 39 | type=str) 40 | @click.option('--metrics-port', help='Grafana metrics port, only available for cloudflared host (default: 5675)', 41 | default=5675, 42 | envvar='METRICS_PORT', 43 | type=str) 44 | @click.option('--catalog', type=click.Choice([e.value for e in Catalog]), 45 | help='Type of the Snowflake account. Automatically detected from the account if not provided.') 46 | @click.option('--universql-catalog', type=str, 47 | help='The external catalog that will be used for Iceberg tables. 
(default: duckdb:///:memory:)', 48 | envvar='UNIVERSQL_CATALOG') 49 | # @click.option('--snowflake-catalog-integration', type=str, 50 | # help='Snowflake catalog integration for CREATE TABLE queries', 51 | # envvar='SNOWFLAKE_CATALOG_INTEGRATION') 52 | # @click.option('--snowflake-external-volume', type=str, 53 | # help='Snowflake external volume for CREATE TABLE queries', 54 | # envvar='SNOWFLAKE_EXTERNAL_VOLUME') 55 | @click.option('--aws-profile', help='AWS profile to access S3 (default: `default`)', type=str) 56 | @click.option('--gcp-project', 57 | help='GCP project to access GCS and apply quota. (to see how to setup auth for GCP and use different accounts, visit https://cloud.google.com/docs/authentication/application-default-credentials)', 58 | type=str) 59 | # @click.option('--azure-tenant', help='Azure account to access Blob Storage (ex: default az credentials)', type=str) 60 | @click.option('--ssl_keyfile', 61 | help='SSL keyfile for the proxy server, optional. Use it if you don\'t want to use localhostcomputing.com', 62 | type=str) 63 | @click.option('--ssl_certfile', help='SSL certfile for the proxy server, optional. ', type=str) 64 | @click.option('--max-memory', type=str, default=DEFAULTS["max_memory"], 65 | help='DuckDB Max memory to use for the server (default: 80% of total memory)', 66 | envvar='MAX_MEMORY', ) 67 | @click.option('--cache-directory', 68 | help=f'Data lake cache directory (default: {Path.home() / ".universql" / "cache"})', 69 | default=Path.home() / ".universql" / "cache", 70 | envvar='CACHE_DIRECTORY', 71 | type=str) 72 | @click.option('--home-directory', 73 | help=f'Home directory for local operations (default: {Path.home()})', 74 | default=Path.home(), 75 | envvar='HOME', 76 | type=str) 77 | @click.option('--max-cache-size', type=str, default=DEFAULTS["max_cache_size"], 78 | help='DuckDB maximum cache used in local disk (default: 80% of total available disk)', 79 | envvar='CACHE_PERCENTAGE', ) 80 | @click.option('--database-path', type=click.Path(exists=False, writable=True), 81 | help='Optional DuckDB Path. (default: :memory:) For persistent storage, provide a path similar to `~/.universql/$session_id.duckdb`', 82 | envvar='DATABASE_PATH') 83 | @click.option('--tunnel', type=click.Choice(["cloudflared", "ngrok"]), 84 | help='Use tunnel for accessing server from public internet', envvar='TUNNEL') 85 | @click.option('--motherduck-token', type=str, 86 | help='Motherduck token to enable', envvar='MOTHERDUCK_TOKEN') 87 | def snowflake(host, port, ssl_keyfile, ssl_certfile, account, catalog, metrics_port, tunnel, **kwargs): 88 | context__params = click.get_current_context().params 89 | auto_catalog_mode = catalog is None 90 | if auto_catalog_mode: 91 | try: 92 | polaris_server_check = requests.get( 93 | f"https://{account}.snowflakecomputing.com/polaris/api/catalog/v1/oauth/tokens") 94 | is_polaris = polaris_server_check.status_code == 405 95 | except RequestException as e: 96 | error_message = ( 97 | f"Unable to find Snowflake account (https://{account}.snowflakecomputing.com), make sure if you have access to the Snowflake account. (maybe need VPN access?) \n" 98 | f"You can set `--catalog` property to avoid this error. 
\n {str(e.args)}") 99 | logger.error(error_message) 100 | sys.exit(1) 101 | 102 | context__params["catalog"] = Catalog.POLARIS.value if is_polaris else Catalog.SNOWFLAKE.value 103 | 104 | adjective = "apparently" if auto_catalog_mode else "" 105 | logger.info(f"UniverSQL is starting reverse proxy for {account}.snowflakecomputing.com, " 106 | f"it's {adjective} a {context__params['catalog']} server.") 107 | 108 | if host == LOCALHOSTCOMPUTING_COM: 109 | try: 110 | data = socket.gethostbyname_ex(LOCALHOSTCOMPUTING_COM) 111 | logger.info(f"Using the SSL keyfile and certfile for {LOCALHOSTCOMPUTING_COM}. DNS resolves to {data}") 112 | if "127.0.0.1" not in data[2]: 113 | logger.error( 114 | f"The DNS setting for {LOCALHOSTCOMPUTING_COM} doesn't point to localhost, refusing to start. Please update UniverSQL.") 115 | sys.exit(1) 116 | except socket.gaierror as e: 117 | logger.warning(f"Unable to resolve DNS for {LOCALHOSTCOMPUTING_COM}, you're not connected to the internet") 118 | 119 | if tunnel == 'cloudflared': 120 | start_cloudflared(port, metrics_port) 121 | if os.getenv('USE_LOCALCOMPUTING_COM') == '1': 122 | host = "0.0.0.0" 123 | else: 124 | host = '127.0.0.1' 125 | elif tunnel == 'ngrok': 126 | logger.error("Ngrok is not supported yet. Please use cloudflared.") 127 | sys.exit(1) 128 | 129 | initialize_context() 130 | if host == LOCALHOSTCOMPUTING_COM or ssl_certfile is not None or os.getenv('USE_LOCALCOMPUTING_COM') == '1': 131 | with tempfile.NamedTemporaryFile(suffix='cert.pem', delete=True) as cert_file: 132 | cert_file.write(base64.b64decode(LOCALHOST_UNIVERSQL_COM_BYTES['cert'])) 133 | cert_file.flush() 134 | with tempfile.NamedTemporaryFile(suffix='key.pem', delete=True) as key_file: 135 | key_file.write(base64.b64decode(LOCALHOST_UNIVERSQL_COM_BYTES['key'])) 136 | key_file.flush() 137 | 138 | try: 139 | uvicorn.run("universql.protocol.snowflake:app", 140 | host=host, port=port, 141 | ssl_keyfile=ssl_keyfile or key_file.name, 142 | ssl_certfile=ssl_certfile or cert_file.name, 143 | reload=False, 144 | use_colors=True) 145 | except Exception as e: 146 | logger.critical("Unable to start the server", exc_info=e) 147 | raise e 148 | else: 149 | uvicorn.run("universql.protocol.snowflake:app", 150 | host=host, port=port, 151 | reload=False, 152 | use_colors=True) 153 | 154 | 155 | class EndpointFilter(logging.Filter): 156 | def __init__( 157 | self, 158 | path: str, 159 | *args: typing.Any, 160 | **kwargs: typing.Any, 161 | ): 162 | super().__init__(*args, **kwargs) 163 | self._path = path 164 | 165 | def filter(self, record: logging.LogRecord) -> bool: 166 | return record.getMessage().find(self._path) == -1 167 | 168 | 169 | 170 | uvicorn_logger = logging.getLogger("uvicorn.access") 171 | uvicorn_logger.addFilter(EndpointFilter(path="/telemetry/send")) 172 | uvicorn_logger.addFilter(EndpointFilter(path="/queries/v1/query-request")) 173 | uvicorn_logger.addFilter(EndpointFilter(path="/session")) 174 | 175 | if __name__ == '__main__': 176 | cli(prog_name="universql") 177 | -------------------------------------------------------------------------------- /universql/plugin.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | 4 | import pyarrow 5 | from fastapi import FastAPI 6 | from sqlglot import Expression 7 | 8 | import typing 9 | from abc import ABC, abstractmethod 10 | from typing import List 11 | 12 | import pyiceberg.table 13 | import sqlglot 14 | 15 | Locations = typing.Dict[sqlglot.exp.Table, 
sqlglot.exp.Expression | None] 16 | Tables = typing.Dict[sqlglot.exp.Table, pyiceberg.table.Table | None] 17 | 18 | 19 | class ICatalog(ABC): 20 | def __init__(self, session: "universql.protocol.session.UniverSQLSession"): 21 | self.context = session.context 22 | self.session_id = session.session_id 23 | self.credentials = session.credentials 24 | self.iceberg_catalog = session.iceberg_catalog 25 | 26 | @abstractmethod 27 | def get_table_paths(self, tables: List[sqlglot.exp.Table]) -> Tables: 28 | pass 29 | 30 | @abstractmethod 31 | def register_locations(self, tables: Locations): 32 | pass 33 | 34 | @abstractmethod 35 | def executor(self) -> "Executor": 36 | pass 37 | 38 | 39 | T = typing.TypeVar('T', bound=ICatalog) 40 | 41 | 42 | def _track_call(method): 43 | @functools.wraps(method) 44 | def wrapper(self, *args, **kwargs): 45 | # Mark that the method was called on this instance. 46 | self._warm = True 47 | return method(self, *args, **kwargs) 48 | 49 | return wrapper 50 | 51 | 52 | class Executor(typing.Protocol[T]): 53 | 54 | def __init__(self, catalog: T): 55 | self.catalog = catalog 56 | 57 | def __init_subclass__(cls, **kwargs): 58 | super().__init_subclass__(**kwargs) 59 | cls.execute = _track_call(cls.execute) 60 | cls.execute_raw = _track_call(cls.execute_raw) 61 | 62 | def is_warm(self): 63 | return getattr(self, '_warm', False) 64 | 65 | def insert(self, table: sqlglot.exp.Table, data: pyarrow.Table): 66 | raise NotImplementedError() 67 | 68 | @abstractmethod 69 | def execute(self, ast: sqlglot.exp.Expression, catalog_executor: "Executor", locations: Tables) -> \ 70 | typing.Optional[Locations]: 71 | pass 72 | 73 | def test(self): 74 | self.execute_raw("select 1", None) 75 | 76 | @abstractmethod 77 | def execute_raw(self, raw_query: str, catalog_executor: typing.Optional["Executor"]) -> None: 78 | pass 79 | 80 | @abstractmethod 81 | def get_as_table(self) -> pyarrow.Table: 82 | pass 83 | 84 | @abstractmethod 85 | def get_query_log(self, total_duration) -> str: 86 | pass 87 | 88 | @abstractmethod 89 | def close(self): 90 | pass 91 | 92 | 93 | class UQuery: 94 | def __init__(self, session: "universql.protocol.session.UniverSQLSession", ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query: str): 95 | self.session = session 96 | self.ast = ast 97 | self.raw_query = raw_query 98 | 99 | def transform_ast(self, ast: sqlglot.exp.Expression, target_executor: Executor) -> Expression: 100 | return ast 101 | 102 | def post_execute(self, locations: typing.Optional[Locations], target_executor: Executor): 103 | pass 104 | 105 | def end(self, table : pyarrow.Table): 106 | pass 107 | 108 | 109 | class UniversqlPlugin(ABC): 110 | def __init__(self, 111 | session: "universql.protocol.session.UniverSQLSession" 112 | ): 113 | self.session = session 114 | 115 | def start_query(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query : str) -> UQuery: 116 | return UQuery(self.session, ast, raw_query) 117 | 118 | 119 | # {"duckdb": DuckdbCatalog ..} 120 | COMPUTES = {} 121 | # [method] 122 | PLUGINS = [] 123 | # apps to be installed 124 | APPS = [] 125 | 126 | 127 | def register(name: typing.Optional[str] = None): 128 | """ 129 | Decorator to register a Compute subclass with an optional name. 
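    Illustrative usage (a sketch; the class and function names below are hypothetical):

        @register(name="mywarehouse")
        class MyWarehouseCatalog(ICatalog): ...      # registered into COMPUTES["mywarehouse"]

        @register()
        class MyPlugin(UniversqlPlugin): ...         # appended to PLUGINS

        @register()
        def mount_app(app: FastAPI): ...             # appended to APPS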
130 | :param name: Unique of the catalog 131 | :param executor: The optional executor class for the catalog 132 | """ 133 | 134 | def decorator(cls): 135 | if inspect.isclass(cls): 136 | if issubclass(cls, ICatalog) and cls is not ICatalog: 137 | if name is None: 138 | raise SystemError("name is required for catalogs") 139 | COMPUTES[name] = cls 140 | elif issubclass(cls, UniversqlPlugin) and cls is not UniversqlPlugin: 141 | PLUGINS.append(cls) 142 | elif inspect.isfunction(cls): 143 | signature = inspect.signature(cls) 144 | if len(signature.parameters) == 1 and signature.parameters.values().__iter__().__next__().annotation is FastAPI: 145 | APPS.append(cls) 146 | else: 147 | raise SystemError(f"Unknown type {cls}") 148 | return cls 149 | 150 | return decorator 151 | -------------------------------------------------------------------------------- /universql/plugins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/universql/plugins/__init__.py -------------------------------------------------------------------------------- /universql/plugins/ui.py: -------------------------------------------------------------------------------- 1 | from contextlib import suppress 2 | from copy import deepcopy 3 | from traceback import print_exc 4 | 5 | from fastapi import FastAPI 6 | 7 | from universql.plugin import register 8 | 9 | 10 | # @register() 11 | def run_marimo(app: FastAPI): 12 | with suppress(ImportError): 13 | import marimo 14 | app.mount("/", marimo.create_asgi_app(include_code=False) 15 | .with_dynamic_directory(path="/", directory="/Users/bkabak/.universql/marimos") 16 | .build()) 17 | 18 | def configure(components, config): 19 | # TODO: generalize to arbitrary nested dictionaries, not just one level 20 | _components = deepcopy(components) 21 | for k1, v1 in config.items(): 22 | for k2, v2 in v1.items(): 23 | _components[k1][k2] = v2 24 | return _components 25 | 26 | 27 | # @register() 28 | async def run_jupyter(app: FastAPI): 29 | try: 30 | from asphalt.core import Context 31 | from jupyverse_api.main import JupyverseComponent 32 | from jupyverse_api.app import App 33 | 34 | components = configure({"app": {"type": "app"}}, {"app": {"mount_path": '/notebook'}}) 35 | async with Context() as ctx: 36 | component = JupyverseComponent(components=components, app=app, debug=True) 37 | component.start(ctx) 38 | jupyter_app = await ctx.request_resource(App) 39 | app.mount("/notebook", jupyter_app) 40 | except Exception as e: 41 | print_exc(10) 42 | print(e) 43 | 44 | -------------------------------------------------------------------------------- /universql/protocol/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/universql/protocol/__init__.py -------------------------------------------------------------------------------- /universql/protocol/lambda.py: -------------------------------------------------------------------------------- 1 | from mangum import Mangum 2 | 3 | from universql.protocol.snowflake import app as snowflake_app 4 | 5 | snowflake = Mangum(snowflake_app, lifespan="off") 6 | -------------------------------------------------------------------------------- /universql/protocol/session.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from 
universql.lake.fsspec_util import get_friendly_disk_usage 3 | import tempfile 4 | import time 5 | from string import Template 6 | from traceback import print_exc 7 | from typing import List 8 | from urllib.parse import urlparse, parse_qs 9 | import os 10 | import signal 11 | import threading 12 | import pyarrow 13 | import pyiceberg.table 14 | import sentry_sdk 15 | import sqlglot 16 | from pyiceberg.catalog import PY_CATALOG_IMPL, load_catalog, TYPE 17 | from pyiceberg.exceptions import TableAlreadyExistsError, NoSuchNamespaceError 18 | from pyiceberg.io import PY_IO_IMPL 19 | from sqlglot import ParseError 20 | from sqlglot.expressions import Create, Identifier, DDL, Query, Use, Semicolon, Copy 21 | 22 | from universql.lake.cloud import CACHE_DIRECTORY_KEY, MAX_CACHE_SIZE 23 | from universql.util import get_friendly_time_since, \ 24 | prepend_to_lines, QueryError, full_qualifier, current_context 25 | from universql.plugin import Executor, Tables, ICatalog, COMPUTES, PLUGINS, UniversqlPlugin, UQuery 26 | 27 | logger = logging.getLogger("💡") 28 | sessions = {} 29 | 30 | 31 | class UniverSQLSession: 32 | def __init__(self, context, session_id, credentials: dict, session_parameters: dict): 33 | self.context = context 34 | self.credentials = credentials 35 | self.session_parameters = [{"name": item[0], "value": item[1]} for item in session_parameters.items()] 36 | self.session_id = session_id 37 | self.iceberg_catalog = self._get_iceberg_catalog() 38 | self.catalog = COMPUTES["snowflake"](self) 39 | self.target_compute = COMPUTES["duckdb"](self) 40 | self.catalog_executor = self.catalog.executor() 41 | self.processing = False 42 | self.metadata_db = None 43 | self.plugins: List[UniversqlPlugin] = [plugin(self) for plugin in PLUGINS] 44 | self.query_history = [] 45 | 46 | def _get_iceberg_catalog(self): 47 | iceberg_catalog = self.context.get('universql_catalog') 48 | 49 | catalog_props = { 50 | PY_IO_IMPL: "universql.lake.cloud.iceberg", 51 | # WAREHOUSE: "gs://my-iceberg-data/custom-events/customer_iceberg_pyiceberg", 52 | CACHE_DIRECTORY_KEY: self.context.get('cache_directory'), 53 | MAX_CACHE_SIZE: self.context.get('max_cache_size'), 54 | } 55 | 56 | if iceberg_catalog is not None: 57 | parsed = urlparse(iceberg_catalog) 58 | catalog_name = parsed.scheme 59 | 60 | query_params = parse_qs(parsed.query) 61 | catalog_props |= {k: v[0] if len(v) == 1 else v for k, v in query_params.items()} 62 | catalog_props['namespace'] = parsed.hostname 63 | catalog_props[TYPE] = "glue" 64 | catalog = load_catalog(catalog_name, **catalog_props) 65 | else: 66 | database_path = Template( 67 | self.context.get('database_path') or f"_{self.session_id}_universql_session").substitute( 68 | {"session_id": self.session_id}) + ".sqlite" 69 | catalog_name = "duckdb" 70 | self.metadata_db = tempfile.NamedTemporaryFile(delete=False, suffix=database_path) 71 | 72 | catalog_props |= { 73 | # pass duck conn 74 | PY_CATALOG_IMPL: "pyiceberg.catalog.sql.SqlCatalog", 75 | "uri": f"sqlite:///{self.metadata_db.name}", 76 | "namespace": "main", 77 | } 78 | catalog = load_catalog(catalog_name, **catalog_props) 79 | catalog.create_namespace_if_not_exists("main") 80 | return catalog 81 | 82 | def _must_run_on_catalog(self, tables, ast): 83 | queries_that_doesnt_need_warehouse = ["show"] 84 | return ast.name in queries_that_doesnt_need_warehouse \ 85 | or ast.key in queries_that_doesnt_need_warehouse 86 | 87 | def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table: 88 | with sentry_sdk.start_span(op="sqlglot", 
name="Parsing query"): 89 | try: 90 | queries = sqlglot.parse(raw_query, read="snowflake") 91 | except ParseError as e: 92 | queries = None 93 | raise QueryError(f"Unable to parse query with SQLGlot: {e.args}") 94 | 95 | last_executor = None 96 | 97 | plugin_hooks = [] 98 | for plugin in self.plugins: 99 | try: 100 | plugin_hooks.append(plugin.start_query(queries, raw_query)) 101 | except Exception as e: 102 | print_exc(10) 103 | message = f"Unable to call start_query on plugin {plugin.__class__}" 104 | logger.error(message, exc_info=e) 105 | raise QueryError(f"{message}: {str(e)}") 106 | 107 | if queries is None: 108 | last_executor = self.perform_query(self.catalog_executor, raw_query, plugin_hooks) 109 | else: 110 | last_error = None 111 | 112 | for ast in queries: 113 | if isinstance(ast, Semicolon) and ast.this is None: 114 | continue 115 | last_executor = self.target_compute.executor() 116 | last_executor = self.perform_query(last_executor, raw_query, plugin_hooks, ast=ast) 117 | 118 | performance_counter = time.perf_counter() 119 | query_duration = performance_counter - start_time 120 | 121 | if last_error is not None: 122 | raise last_error 123 | 124 | logger.info( 125 | f"[{self.session_id}] {last_executor.get_query_log(query_duration)} 🚀 " 126 | f"({get_friendly_time_since(start_time, performance_counter)})") 127 | 128 | table = last_executor.get_as_table() 129 | for hook in plugin_hooks: 130 | try: 131 | hook.end(table) 132 | except Exception as e: 133 | print_exc(10) 134 | message = f"Unable to end query execution on plugin {hook.__class__}" 135 | logger.error(message, exc_info=e) 136 | raise QueryError(f"{message}: {str(e)}") 137 | return table 138 | 139 | def _find_tables(self, ast: sqlglot.exp.Expression, cte_aliases=None): 140 | if cte_aliases is None: 141 | cte_aliases = set() 142 | for expression in ast.walk(bfs=True): 143 | if isinstance(expression, Query) or isinstance(expression, DDL): 144 | if expression.ctes is not None and len(expression.ctes) > 0: 145 | for cte in expression.ctes: 146 | cte_aliases.add(cte.alias) 147 | if isinstance(expression, sqlglot.exp.Table) and isinstance(expression.this, Identifier): 148 | if expression.catalog or expression.db or str(expression.this.this) not in cte_aliases: 149 | yield full_qualifier(expression, self.credentials), cte_aliases 150 | 151 | def perform_query(self, alternative_executor: Executor, raw_query, plugin_hooks: List[UQuery], 152 | ast=None) -> Executor: 153 | if ast is not None and alternative_executor != self.catalog_executor: 154 | must_run_on_catalog = False 155 | if isinstance(ast, Create): 156 | if ast.kind in ('TABLE', 'VIEW'): 157 | tables = self._find_tables(ast.expression) if ast.expression is not None else [] 158 | else: 159 | tables = [] 160 | must_run_on_catalog = True 161 | elif isinstance(ast, Use): 162 | tables = [] 163 | else: 164 | tables = self._find_tables(ast) 165 | tables_list = [table[0] for table in tables] 166 | must_run_on_catalog = must_run_on_catalog or self._must_run_on_catalog(tables_list, ast) 167 | if not must_run_on_catalog: 168 | op_name = alternative_executor.__class__.__name__ 169 | with sentry_sdk.start_span(op=op_name, name="Get table paths"): 170 | locations = self.get_table_paths_from_catalog(alternative_executor.catalog, tables_list) 171 | with sentry_sdk.start_span(op=op_name, name="Execute query"): 172 | current_ast = ast 173 | for plugin_hook in plugin_hooks: 174 | try: 175 | current_ast = plugin_hook.transform_ast(ast, alternative_executor) 176 | except Exception as e: 177 
| print_exc(10) 178 | message = f"Unable to tranform_ast on plugin {plugin_hook.__class__}" 179 | logger.error(message, exc_info=e) 180 | raise QueryError(f"{message}: {str(e)}") 181 | new_locations = alternative_executor.execute(current_ast, self.catalog_executor, locations) 182 | for plugin_hook in plugin_hooks: 183 | try: 184 | plugin_hook.post_execute(new_locations, alternative_executor) 185 | except Exception as e: 186 | print_exc(10) 187 | message = f"Unable to post_execute on plugin {plugin_hook.__class__}" 188 | logger.error(message, exc_info=e) 189 | raise QueryError(f"{message}: {str(e)}") 190 | if new_locations is not None: 191 | with sentry_sdk.start_span(op=op_name, name="Register new locations"): 192 | self.catalog.register_locations(new_locations) 193 | return alternative_executor 194 | 195 | with sentry_sdk.start_span(name="Execute query on Snowflake"): 196 | last_executor = self.catalog_executor 197 | if ast is None: 198 | last_executor.execute_raw(raw_query, self.catalog_executor) 199 | else: 200 | last_executor.execute(ast, self.catalog_executor, {}) 201 | return last_executor 202 | 203 | def do_query(self, raw_query: str) -> pyarrow.Table: 204 | start_time = time.perf_counter() 205 | logger.info(f"[{self.session_id}] Transpiling query \n{prepend_to_lines(raw_query)}") 206 | self.processing = True 207 | try: 208 | return self._do_query(start_time, raw_query) 209 | finally: 210 | self.processing = False 211 | 212 | def close(self): 213 | self.catalog_executor.close() 214 | if self.metadata_db is not None: 215 | self.metadata_db.close() 216 | if os.path.exists(self.metadata_db.name): 217 | os.remove(self.metadata_db.name) 218 | 219 | def get_table_paths_from_catalog(self, alternative_catalog: ICatalog, tables: list[sqlglot.exp.Table]) -> Tables: 220 | not_existed = [] 221 | cached_tables = {} 222 | 223 | for table in tables: 224 | full_qualifier_ = full_qualifier(table, self.credentials) 225 | table_path = alternative_catalog.get_table_paths([full_qualifier_]).get(full_qualifier_, False) 226 | if table_path is None or isinstance(table_path, pyarrow.Table): 227 | continue 228 | # try: 229 | # namespace = self.iceberg_catalog.properties.get('namespace') 230 | # table_ref = table.sql(dialect="snowflake") 231 | # logger.info(f"Looking up table {table_ref} in namespace {namespace}") 232 | # iceberg_table = self.iceberg_catalog.load_table((namespace, table_ref)) 233 | # cached_tables[table] = iceberg_table 234 | # except NoSuchTableError: 235 | not_existed.append(table) 236 | 237 | locations = self.catalog.get_table_paths(not_existed) 238 | for table_ast, table in locations.items(): 239 | if isinstance(table, pyiceberg.table.Table): 240 | namespace = self.iceberg_catalog.properties.get('namespace') 241 | table_name = table_ast.sql(dialect="snowflake") 242 | try: 243 | iceberg_table = self.iceberg_catalog.register_table((namespace, table_name), 244 | table.metadata_location) 245 | except TableAlreadyExistsError: 246 | iceberg_table = self.iceberg_catalog.load_table((namespace, table_name)) 247 | except NoSuchNamespaceError: 248 | self.iceberg_catalog.create_namespace((namespace,)) 249 | iceberg_table = self.iceberg_catalog.register_table((namespace, table_name), table.metadata) 250 | locations[table_ast] = iceberg_table 251 | else: 252 | raise Exception(f"Unknown table type {table}") 253 | return locations | cached_tables 254 | 255 | 256 | ENABLE_DEBUG_WATCH_TOWER = False 257 | WATCH_TOWER_SCHEDULE_SECONDS = 2 258 | kill_event = threading.Event() 259 | 260 | 261 | def 
watch_tower(cache_directory, **kwargs): 262 | while True: 263 | kill_event.wait(timeout=WATCH_TOWER_SCHEDULE_SECONDS) 264 | if kill_event.is_set(): 265 | break 266 | processing_sessions = sum(session.processing for token, session in sessions.items()) 267 | if ENABLE_DEBUG_WATCH_TOWER or processing_sessions > 0: 268 | try: 269 | import psutil 270 | process = psutil.Process() 271 | cpu_percent = f"[CPU: {'%.1f' % psutil.cpu_percent()}%]" 272 | memory_percent = f"[Memory: {'%.1f' % process.memory_percent()}%]" 273 | except: 274 | memory_percent = "" 275 | cpu_percent = "" 276 | 277 | disk_info = get_friendly_disk_usage(cache_directory, debug=ENABLE_DEBUG_WATCH_TOWER) 278 | logger.info(f"Currently {len(sessions)} sessions running {processing_sessions} queries " 279 | f"| System: {cpu_percent} {memory_percent} [Disk: {disk_info}] ") 280 | 281 | 282 | thread = threading.Thread(target=watch_tower, kwargs=current_context) 283 | thread.daemon = True 284 | thread.start() 285 | 286 | 287 | # If the user intends to kill the server, not wait for DuckDB to gracefully shutdown. 288 | # It's nice to treat the Duck better giving it time but user's time is more valuable than the duck's. 289 | def harakiri(sig, frame): 290 | print("Killing the server, bye!") 291 | kill_event.set() 292 | os.kill(os.getpid(), signal.SIGKILL) 293 | 294 | # last_intent_to_kill = time.time() 295 | # def graceful_shutdown(_, frame): 296 | # global last_intent_to_kill 297 | # processing_sessions = sum(session.processing for token, session in sessions.items()) 298 | # if processing_sessions == 0 or (time.time() - last_intent_to_kill) < 5000: 299 | # harakiri(signal, frame) 300 | # else: 301 | # print(f'Repeat sigint to confirm killing {processing_sessions} running queries.') 302 | # last_intent_to_kill = time.time() 303 | -------------------------------------------------------------------------------- /universql/protocol/snowflake.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import glob 3 | import json 4 | import logging 5 | import base64 6 | import os 7 | import signal 8 | import threading 9 | from traceback import print_exc 10 | 11 | from typing import Any 12 | from uuid import uuid4 13 | 14 | import click 15 | import pyarrow as pa 16 | import sentry_sdk 17 | import yaml 18 | 19 | from fastapi import FastAPI 20 | from pyarrow import Schema 21 | from snowflake.connector import DatabaseError 22 | from starlette.exceptions import HTTPException 23 | from starlette.requests import Request 24 | from starlette.responses import JSONResponse, Response, HTMLResponse 25 | 26 | from universql.agent.cloudflared import get_cloudflare_url 27 | from universql.plugin import APPS 28 | from universql.util import unpack_request_body, session_from_request, parameters, \ 29 | print_dict_as_markdown_table, QueryError, LOCALHOSTCOMPUTING_COM, current_context 30 | from fastapi.encoders import jsonable_encoder 31 | from starlette.concurrency import run_in_threadpool 32 | from universql.protocol.session import UniverSQLSession, sessions, harakiri, kill_event 33 | 34 | logger = logging.getLogger("🧵") 35 | 36 | app = FastAPI() 37 | query_results = {} 38 | 39 | # register all warehouses and plugins 40 | for module in [ 41 | "universql.warehouse.duckdb", 42 | "universql.warehouse.bigquery", 43 | "universql.warehouse.snowflake", 44 | "universql.warehouse.redshift", 45 | "universql.plugins.snow", 46 | "universql.plugins.ui", 47 | ]: 48 | __import__(module) 49 | 50 | 51 | 
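# Illustrative sketch (not part of the original module): any Snowflake client pointed at this
# proxy calls the /session/v1/login-request handler below first. Host, port and credential
# values are placeholders; the defaults come from the `universql snowflake` CLI options.
#
#   import snowflake.connector
#   conn = snowflake.connector.connect(
#       host="localhostcomputing.com", port=8084, account="<account>",
#       user="<user>", password="<password>", database="<db>", schema="PUBLIC",
#   )
#   conn.cursor().execute("SELECT 1")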
@app.post("/session/v1/login-request") 52 | async def login_request(request: Request) -> JSONResponse: 53 | body = await unpack_request_body(request) 54 | 55 | login_data = body.get('data') 56 | client_environment = login_data.get('CLIENT_ENVIRONMENT') 57 | credentials = {key: client_environment[key] for key in ["schema", "warehouse", "role", "user", "database"] if 58 | key in client_environment} 59 | if login_data.get('PASSWORD') is not None: 60 | credentials['password'] = login_data.get('PASSWORD') 61 | if "user" not in credentials and login_data.get('LOGIN_NAME') is not None: 62 | credentials["user"] = login_data.get("LOGIN_NAME") 63 | 64 | params = request.query_params 65 | if "database" not in credentials: 66 | credentials["database"] = params.get('databaseName') 67 | if "warehouse" not in credentials: 68 | credentials["warehouse"] = params.get('warehouse') 69 | if "role" not in credentials: 70 | credentials["role"] = params.get('roleName') 71 | if "schema" not in credentials: 72 | # TODO: support different default schemas stored in `SHOW PARAMETERS LIKE 'search_path'` 73 | credentials["schema"] = params.get('schemaName') or "PUBLIC" 74 | 75 | token = str(uuid4()) 76 | message = None 77 | try: 78 | session = UniverSQLSession(current_context, token, credentials, login_data.get("SESSION_PARAMETERS")) 79 | sessions[session.session_id] = session 80 | except QueryError as e: 81 | message = e.message 82 | 83 | client = f"{request.client.host}:{request.client.port}" 84 | 85 | if message is None: 86 | logger.info( 87 | f"[{token}] Created local session for user {credentials.get('user')} from {client}") 88 | else: 89 | logger.error( 90 | f"Rejected login request from {client} for user {credentials.get('user')}. Reason: {message}") 91 | 92 | return JSONResponse( 93 | { 94 | "data": 95 | { 96 | "token": token, 97 | # TODO: figure out how to generate safer token 98 | "masterToken": token, 99 | "parameters": parameters, 100 | "sessionInfo": {f'{k}Name': v for k, v in credentials.items()}, 101 | "idToken": None, 102 | "idTokenValidityInSeconds": 0, 103 | "responseData": None, 104 | "mfaToken": None, 105 | "mfaTokenValidityInSeconds": 0 106 | }, 107 | "message": message, 108 | "success": message is None, 109 | "code": None, 110 | "validityInSeconds": 3600, 111 | "masterValidityInSeconds": 14400, 112 | "displayUserName": "", 113 | "serverVersion": "duck", 114 | "firstLogin": False, 115 | "remMeToken": None, 116 | "remMeValidityInSeconds": 0, 117 | "healthCheckInterval": 45, 118 | }) 119 | 120 | 121 | @app.post("/session") 122 | async def delete_session(request: Request): 123 | if request.query_params.get("delete") == "true": 124 | try: 125 | session = session_from_request(sessions, request) 126 | except HTTPException as e: 127 | if e.status_code == 401: 128 | # most likely the server has started 129 | return JSONResponse({"success": True}) 130 | 131 | del sessions[session.session_id] 132 | logger.info(f"[{session.session_id}] Session closed, cleaning up resources.") 133 | session.close() 134 | return JSONResponse({"success": True}) 135 | return Response(status_code=404) 136 | 137 | 138 | @app.post("/telemetry/send") 139 | async def telemetry_request(request: Request) -> JSONResponse: 140 | request = await unpack_request_body(request) 141 | logs = request.get('logs') 142 | return JSONResponse(content={"success": True}, status_code=200) 143 | 144 | 145 | @app.post("/session/heartbeat") 146 | async def session_heartbeat(request: Request) -> JSONResponse: 147 | try: 148 | session_from_request(sessions, 
request) 149 | except HTTPException as e: 150 | if e.status_code == 401: 151 | # most likely the server has started 152 | return JSONResponse({"success": False}) 153 | 154 | return JSONResponse( 155 | {"data": {}, "success": True}) 156 | 157 | 158 | def get_columns_for_sf_compat(schema: Schema) -> list[dict[str, Any]]: 159 | columns = [] 160 | for idx, col in enumerate(schema): 161 | # args = sqlglot.expressions.DataType.build(col[1], dialect=dialect).args 162 | # precision = args.get('expressions')[0].this 163 | # scale = args.get('expressions')[1].this 164 | columns.append({ 165 | "name": col.name, 166 | "database": "", 167 | "schema": "", 168 | "table": "", 169 | "nullable": True, 170 | "type": col.metadata.get(b'logicalType').decode(), 171 | "length": None, 172 | "scale": None, 173 | "precision": None, 174 | # "scale": int(scale.name), 175 | # "precision": int(precision.name), 176 | "byteLength": None, 177 | "collation": None 178 | }) 179 | return columns 180 | 181 | 182 | @app.post("/queries/v1/abort-request") 183 | async def abort_request(request: Request) -> JSONResponse: 184 | return JSONResponse( 185 | {"data": {}, "success": True}) 186 | 187 | 188 | @app.post("/queries/v1/query-request") 189 | async def query_request(request: Request) -> JSONResponse: 190 | query_id = str(uuid4()) 191 | query = None 192 | try: 193 | session = session_from_request(sessions, request) 194 | body = await unpack_request_body(request) 195 | query = body["sqlText"] 196 | queryResultFormat = "arrow" 197 | transaction = sentry_sdk.get_current_scope().transaction 198 | transaction.set_tag("query", query) 199 | result = await run_in_threadpool(session.do_query, query) 200 | columns = get_columns_for_sf_compat(result.schema) 201 | if result is None: 202 | return JSONResponse({"success": False, "message": "no query provided"}) 203 | data: dict = { 204 | "finalDatabaseName": session.credentials.get("database"), 205 | "finalSchemaName": session.credentials.get("schema"), 206 | "finalWarehouseName": session.credentials.get("warehouse"), 207 | "finalRoleName": session.credentials.get("role"), 208 | "rowtype": columns, 209 | "queryResultFormat": queryResultFormat, 210 | "databaseProvider": None, 211 | "numberOfBinds": 0, 212 | "arrayBindSupported": False, 213 | "queryId": query_id, 214 | "parameters": parameters 215 | } 216 | if body.get("asyncExec", False): 217 | # we should store the result in the server for when it gets retrieved 218 | query_results[query_id] = result 219 | 220 | format = data.get('queryResultFormat') 221 | if format == "json": 222 | data["rowset"] = jsonable_encoder(result) 223 | elif format == "arrow": 224 | number_of_rows = len(result) 225 | data["returned"] = number_of_rows 226 | if number_of_rows > 0: 227 | sink = pa.BufferOutputStream() 228 | with pa.ipc.new_stream(sink, result.schema) as writer: 229 | batches = result.to_batches() 230 | if len(batches) == 0: 231 | empty_batch = pa.RecordBatch.from_pylist([], schema=result.schema) 232 | writer.write_batch(empty_batch) 233 | else: 234 | for batch in batches: 235 | writer.write_batch(batch) 236 | buf = sink.getvalue() 237 | # one-copy 238 | pybytes = buf.to_pybytes() 239 | b_encode = base64.b64encode(pybytes) 240 | encode = b_encode.decode('utf-8') 241 | else: 242 | encode = "" 243 | data["rowsetBase64"] = encode 244 | else: 245 | raise Exception(f"Format {format} is not supported") 246 | return JSONResponse({"data": data, "success": True}) 247 | except QueryError as e: 248 | # print_exc(limit=1) 249 | return JSONResponse({"id": query_id, 
"success": False, "message": e.message, "data": {"sqlState": e.sql_state}}) 250 | except DatabaseError as e: 251 | print_exc(limit=10) 252 | return JSONResponse({"id": query_id, "success": False, 253 | "message": f"Error running query on Snowflake: {e.raw_msg}", 254 | "data": {"sqlState": e.sqlstate}}) 255 | except Exception as e: 256 | if not isinstance(e, HTTPException): 257 | print_exc(limit=10) 258 | if query is not None: 259 | logger.exception(f"Error processing query: {query}") 260 | else: 261 | logger.exception(f"Error processing query request", e) 262 | return JSONResponse({"id": query_id, "success": False, 263 | "message": "Unable to run the query due to a system error. Please create issue on https://github.com/buremba/universql/issues", 264 | "data": {"sqlState": "0000"}}) 265 | 266 | 267 | @app.get("/jupyterlite/new") 268 | async def jupyter(request: Request) -> JSONResponse: 269 | return HTMLResponse(" " + f""" 270 | 276 | """) 277 | 278 | 279 | @app.get("/") 280 | async def home(request: Request) -> JSONResponse: 281 | return JSONResponse( 282 | {"status": "X-Duck is ducking 🐥", "new": {"jupyter": "/jupyterlite/new", "streamlit": "/streamlit/new"}}) 283 | 284 | def get_files(path): 285 | files_dict = {} 286 | prefix = f"app/{path}/" 287 | for filepath in glob.glob(f"{prefix}**/*", recursive=True): 288 | if os.path.isfile(filepath): 289 | with open(filepath, 'r', encoding='utf-8') as file: 290 | # Create relative path as key 291 | relative_path = os.path.relpath(filepath, prefix) 292 | files_dict[relative_path] = file.read() 293 | return files_dict 294 | 295 | def get_requirements(path): 296 | requirements_file = os.path.abspath(f"app/{path}/requirements.txt") 297 | if os.path.exists(requirements_file): 298 | with open(requirements_file, 'r', encoding='utf-8') as file: 299 | return file.read().splitlines() 300 | else: 301 | return [] 302 | 303 | 304 | @app.get("/{rest_of_path:path}") 305 | async def streamlit_renderer(request: Request, rest_of_path: str) -> JSONResponse: 306 | settings = {"requirements": get_requirements(rest_of_path), "entrypoint": "page.py", "streamlitConfig": {}, 307 | "files": get_files(rest_of_path)} 308 | return HTMLResponse(f""" 309 | 310 | 311 | 312 | 313 | 314 | 318 | Stlite App 319 | 323 | 324 | 325 |
326 | 327 | 332 | 333 | 334 | """) 335 | 336 | 337 | @app.get("/monitoring/queries/{query_id:str}") 338 | async def query_monitoring_query(request: Request) -> JSONResponse: 339 | query_id = request.path_params["query_id"] 340 | if query_id not in query_results: 341 | return JSONResponse({"success": False, "message": "query not found"}) 342 | # note that we always execute synchronously, so we can just return a static success 343 | return JSONResponse({"data": {"queries": [{"status": "SUCCESS"}]}, "success": True}) 344 | 345 | 346 | @app.on_event("shutdown") 347 | async def shutdown_event(): 348 | kill_event.set() 349 | for token, session in sessions.items(): 350 | session.close() 351 | 352 | 353 | @app.on_event("startup") 354 | async def startup_event(): 355 | params = {k: v for k, v in current_context.items() if 356 | v is not None and k not in ["host", "port"]} 357 | click.secho(yaml.dump(params).strip()) 358 | if threading.current_thread() is threading.main_thread(): 359 | try: 360 | signal.signal(signal.SIGINT, harakiri) 361 | except Exception as e: 362 | logger.warning("Failed to set signal handler for SIGINT: %s", str(e)) 363 | host = current_context.get('host') 364 | tunnel = current_context.get('tunnel') 365 | port = current_context.get('port') 366 | if tunnel == "cloudflared": 367 | metrics_port = current_context.get('metrics_port') 368 | click.secho(f" * Traffic stats available on http://127.0.0.1:{metrics_port}/metrics") 369 | host_port = host = get_cloudflare_url(metrics_port) 370 | port = 443 371 | click.secho(f" * Tunnel available at https://{host_port}") 372 | elif os.getenv('USE_LOCALCOMPUTING_COM') == '1': 373 | host_port = f"{LOCALHOSTCOMPUTING_COM}:{port}" 374 | else: 375 | host_port = f"{host}:{port}" 376 | connections = { 377 | "Node.js": f"snowflake.createConnection({{accessUrl: '{host_port}'}})", 378 | "JDBC": f"jdbc:snowflake://{host_port}/dbname", 379 | "Python": f"snowflake.connector.connect(host='{host}', port='{port}')", 380 | "PHP": f"new PDO('snowflake:host={host_port}', '', '')", 381 | "Go": f"sql.Open('snowflake', 'user:pass@{host_port}/dbname')", 382 | ".NET": f"host={host_port};db=testdb", 383 | "ODBC": f"Server={host}; Database=dbname; Port={port}", 384 | } 385 | click.secho(print_dict_as_markdown_table(connections, footer_message=( 386 | "You can connect to UniverSQL with any Snowflake client using your Snowflake credentials.", 387 | "For application support, see https://github.com/buremba/universql",))) 388 | 389 | 390 | loop = asyncio.get_event_loop() 391 | 392 | for app_to_be_installed in APPS: 393 | app_to_be_installed(app) 394 | # loop.run_until_complete(app_to_be_installed(app)) 395 | # loop.close() -------------------------------------------------------------------------------- /universql/protocol/utils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import duckdb 4 | import pyarrow as pa 5 | from pyarrow import ChunkedArray, Table 6 | from snowflake.connector.constants import FIELD_TYPES, FIELD_NAME_TO_ID 7 | from snowflake.connector.cursor import ResultMetadataV2 8 | 9 | 10 | 11 | class DuckDBFunctions: 12 | @staticmethod 13 | def register(db: duckdb.DuckDBPyConnection): 14 | db.create_function("CURRENT_WAREHOUSE", DuckDBFunctions.current_warehouse, [], duckdb.typing.VARCHAR) 15 | 16 | @staticmethod 17 | def current_warehouse() -> str: 18 | return "x-duck" 19 | 20 | 21 | def get_field_for_snowflake(column: ResultMetadataV2, value: typing.Optional[pa.Array] = None) -> \ 22 | 
typing.Tuple[ 23 | pa.Field, pa.Array]: 24 | arrow_field = FIELD_TYPES[column.type_code] 25 | 26 | metadata = { 27 | "logicalType": arrow_field.name, 28 | "charLength": "8388608", 29 | "byteLength": "8388608", 30 | } 31 | 32 | if arrow_field.name == "GEOGRAPHY": 33 | metadata["logicalType"] = "OBJECT" 34 | 35 | precision = column.precision 36 | scale = column.scale 37 | 38 | if arrow_field.name == 'FIXED': 39 | ### Remove this line 40 | if value is not None and pa.types.is_integer(value.type): 41 | pa_type = value.type 42 | else: 43 | ### 44 | pa_type = pa.decimal128(column.precision 45 | # or value.type.precision 46 | or 38, 47 | column.scale 48 | # or value.type.precision 49 | or 0) 50 | precision = precision or 1 51 | scale = scale or 0 52 | if value is not None: 53 | try: 54 | value = value.cast(pa_type, safe=False) 55 | except Exception as e: 56 | raise e 57 | elif arrow_field.name == 'DATE': 58 | pa_type = pa.date32() 59 | if value is not None: 60 | value = value.cast(pa_type) 61 | elif arrow_field.name == 'TIME': 62 | pa_type = pa.int64() 63 | if value is not None: 64 | value = value.cast(pa_type) 65 | elif arrow_field.name == 'TIMESTAMP_LTZ' or arrow_field.name == 'TIMESTAMP_NTZ' or arrow_field.name == 'TIMESTAMP': 66 | metadata["final_type"] = "T" 67 | timestamp_fields = [ 68 | pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), 69 | pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), 70 | ] 71 | pa_type = pa.struct(timestamp_fields) 72 | if value is not None: 73 | epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() 74 | value = pa.StructArray.from_arrays(arrays=[epoch, pa.nulls(len(value), type=pa.int32())], 75 | fields=timestamp_fields) 76 | elif arrow_field.name == 'TIMESTAMP_TZ': 77 | metadata["final_type"] = "T" 78 | timestamp_fields = [ 79 | pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), 80 | pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), 81 | pa.field("timezone", nullable=False, type=pa.int32(), metadata=metadata), 82 | ] 83 | pa_type = pa.struct(timestamp_fields) 84 | if value is not None: 85 | epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() 86 | 87 | value = pa.StructArray.from_arrays( 88 | arrays=[epoch, 89 | # TODO: modulo 1_000_000_000 to get the fraction of a second, pyarrow doesn't support the operator yet 90 | pa.nulls(len(value), type=pa.int32()), 91 | # TODO: reverse engineer the timezone conversion 92 | pa.nulls(len(value), type=pa.int32()), 93 | ], 94 | fields=timestamp_fields) 95 | else: 96 | pa_type = arrow_field.pa_type(column) 97 | 98 | if precision is not None: 99 | metadata["precision"] = str(precision) 100 | if scale is not None: 101 | metadata["scale"] = str(scale) 102 | 103 | field = pa.field(column.name, type=pa_type, nullable=column.is_nullable, metadata=metadata) 104 | return (field, value) 105 | 106 | 107 | def arrow_to_snowflake_type_id(column_type: pa.DataType): 108 | if pa.types.is_decimal(column_type) or pa.types.is_integer(column_type): 109 | type_code = FIELD_NAME_TO_ID["FIXED"] 110 | elif pa.types.is_date(column_type): 111 | type_code = FIELD_NAME_TO_ID["DATE"] 112 | elif pa.types.is_floating(column_type): 113 | type_code = FIELD_NAME_TO_ID["REAL"] 114 | elif pa.types.is_timestamp(column_type): 115 | # Check if it should be TIMESTAMP_TZ 116 | type_code = FIELD_NAME_TO_ID["TIMESTAMP"] 117 | elif pa.types.is_boolean(column_type): 118 | type_code = FIELD_NAME_TO_ID["BOOLEAN"] 119 | elif
pa.types.is_string(column_type): 120 | type_code = FIELD_NAME_TO_ID["TEXT"] 121 | elif pa.types.is_struct(column_type): 122 | # Check if it should be MAP 123 | type_code = FIELD_NAME_TO_ID["VARIANT"] 124 | elif pa.types.is_list(column_type): 125 | type_code = FIELD_NAME_TO_ID["ARRAY"] 126 | elif pa.types.is_binary(column_type): 127 | type_code = FIELD_NAME_TO_ID["BINARY"] 128 | elif pa.types.is_time(column_type): 129 | type_code = FIELD_NAME_TO_ID["TIME"] 130 | else: 131 | # Unsupported types: VECTOR, GEOMETRY, GEOGRAPHY 132 | raise ValueError(f"Unsupported type: {column_type}") 133 | 134 | return type_code 135 | 136 | 137 | def get_field_from_duckdb(column: list[str], arrow_table: Table, idx: int) -> typing.Tuple[ 138 | typing.Optional[ChunkedArray], pa.Field]: 139 | (field_name, field_type) = column[0], column[1] 140 | pa_type = arrow_table.schema[idx].type 141 | 142 | metadata = {} 143 | value = arrow_table[idx] 144 | 145 | if field_type == 'NUMBER': 146 | 147 | if ( # no harm for int types 148 | pa_type != pa.int64() and 149 | pa_type != pa.int32() and 150 | pa_type != pa.int16() and 151 | pa_type != pa.int8()): 152 | pa_type = pa.decimal128(getattr(value.type, 'precision', 38), getattr(value.type, 'scale', 0)) 153 | try: 154 | value = value.cast(pa_type) 155 | except Exception: 156 | pass 157 | metadata["logicalType"] = "FIXED" 158 | metadata["precision"] = "1" 159 | metadata["scale"] = "0" 160 | metadata["physicalType"] = "SB1" 161 | metadata["final_type"] = "T" 162 | elif field_type == 'Date': 163 | pa_type = pa.date32() 164 | value = value.cast(pa_type) 165 | metadata["logicalType"] = "DATE" 166 | elif field_type == 'Time': 167 | pa_type = pa.int64() 168 | value = value.cast(pa_type) 169 | metadata["logicalType"] = "TIME" 170 | elif field_type == "BINARY": 171 | pa_type = pa.binary() 172 | metadata["logicalType"] = "BINARY" 173 | elif field_type == "TIMESTAMP" or field_type == "DATETIME" or field_type == "TIMESTAMP_LTZ": 174 | metadata["logicalType"] = "TIMESTAMP_LTZ" 175 | metadata["precision"] = "0" 176 | metadata["scale"] = "9" 177 | metadata["physicalType"] = "SB16" 178 | metadata["final_type"] = "T" 179 | timestamp_fields = [ 180 | pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), 181 | pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), 182 | ] 183 | pa_type = pa.struct(timestamp_fields) 184 | epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() 185 | value = pa.StructArray.from_arrays(arrays=[epoch, pa.nulls(len(value), type=pa.int32())], 186 | fields=timestamp_fields) 187 | elif field_type == "TIMESTAMP_NTZ": 188 | metadata["logicalType"] = "TIMESTAMP_NTZ" 189 | metadata["precision"] = "0" 190 | metadata["scale"] = "9" 191 | metadata["physicalType"] = "SB16" 192 | timestamp_fields = [ 193 | pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), 194 | pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), 195 | ] 196 | pa_type = pa.struct(timestamp_fields) 197 | epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() 198 | value = pa.StructArray.from_arrays(arrays=[epoch, pa.nulls(len(value), type=pa.int32())], 199 | fields=timestamp_fields) 200 | elif field_type == "TIMESTAMP_TZ": 201 | timestamp_fields = [ 202 | pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), 203 | pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), 204 | pa.field("timezone", nullable=False, type=pa.int32(), metadata=metadata), 205 
| ] 206 | pa_type = pa.struct(timestamp_fields) 207 | epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() 208 | 209 | value = pa.StructArray.from_arrays( 210 | arrays=[epoch, 211 | # TODO: modulos 1_000_000_000 to get the fraction of a second, pyarrow doesn't support the operator yet 212 | pa.nulls(len(value), type=pa.int32()), 213 | # TODO: reverse engineer the timezone conversion 214 | pa.nulls(len(value), type=pa.int32()), 215 | ], 216 | fields=timestamp_fields) 217 | metadata["logicalType"] = "TIMESTAMP_TZ" 218 | metadata["precision"] = "0" 219 | metadata["scale"] = "9" 220 | metadata["physicalType"] = "SB16" 221 | elif field_type == "JSON": 222 | pa_type = pa.utf8() 223 | metadata["logicalType"] = "OBJECT" 224 | metadata["charLength"] = "16777216" 225 | metadata["byteLength"] = "16777216" 226 | metadata["scale"] = "0" 227 | metadata["precision"] = "38" 228 | metadata["finalType"] = "T" 229 | elif pa_type == pa.bool_(): 230 | metadata["logicalType"] = "BOOLEAN" 231 | elif field_type == 'list': 232 | pa_type = pa.utf8() 233 | arrow_to_project = duckdb.from_arrow(arrow_table.select([field_name])) 234 | metadata["logicalType"] = "ARRAY" 235 | metadata["charLength"] = "16777216" 236 | metadata["byteLength"] = "16777216" 237 | metadata["scale"] = "0" 238 | metadata["precision"] = "38" 239 | metadata["finalType"] = "T" 240 | value = (arrow_to_project.project(f"to_json({field_name})").arrow())[0] 241 | elif pa_type == pa.string(): 242 | metadata["logicalType"] = "TEXT" 243 | metadata["charLength"] = "16777216" 244 | metadata["byteLength"] = "16777216" 245 | else: 246 | raise Exception() 247 | 248 | field = pa.field(field_name, type=pa_type, nullable=True, metadata=metadata) 249 | return value, field 250 | 251 | 252 | -------------------------------------------------------------------------------- /universql/streamlit/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="dark" 3 | primaryColor="#c3d1ea" 4 | backgroundColor="#15233c" 5 | secondaryBackgroundColor="#08110f" 6 | textColor="#bbd2b2" -------------------------------------------------------------------------------- /universql/streamlit/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | st.title("Stlite Sharing: Serverless Streamlit app platform") 4 | 5 | st.markdown(""" 6 | ### Stlite 7 | **Stlite** is a port of _Streamlit_ to Wasm, powered by Pyodide, 8 | that runs completely on web browsers. 9 | 10 | The official repository is [🔗 here](https://github.com/whitphx/stlite). 11 | 12 | If you are new to Streamlit, read the Getting Started tutorial [🔗 here](https://docs.streamlit.io/library/get-started) first 13 | (don't worry, it only takes a few minutes 👍), 14 | but **you can skip the "Installation" section** because you are here 😎. 15 | You can start writing code right out of the box on this online editor 👈! 16 | (If there is not an editor on the left, you are seeing the shared app. 17 | Navigate to the editor mode: https://edit.share.stlite.net/) 18 | 19 | ### Stlite Sharing 20 | This page is built on **Stlite Sharing**, an online code editor & sharing platform for _Stlite_. \\ 21 | If you see the editor and preview panes side by side, you are in the editor mode, https://edit.share.stlite.net/. \\ 22 | If you see only this Streamlit app, you are in the sharing mode, https://share.stlite.net/. 
23 | (If you want to edit the app, please go to the [editor mode](https://edit.share.stlite.net/)!) 24 | 25 | The app code and data are encoded into the URL as a hash like `https://share.stlite.net/#!ChBz...`, 26 | so you can save, share and restore the app with just this URL. 27 | If you are on the editor page, click the "Open App" link on the top right toolbar to see the standalone app! 28 | 29 | You can switch the editor and sharing modes by replacing the host name in the URL, 30 | `edit.share.stlite.net` and `share.stlite.net`. 31 | 32 | ### Tell your story! 33 | When you create some apps with _Stlite_, please share them! 34 | All you need to do is copy and paste the URL 👍 35 | 36 | * **Stlite** GitHub Discussions [🔗 here](https://github.com/whitphx/stlite/discussions/categories/show-and-tell) 37 | * Streamlit community forum [🔗 here](https://discuss.streamlit.io/) 38 | """) 39 | 40 | st.header("Streamlit Component Samples") 41 | st.markdown(""" 42 | All these features are working in your browser! 43 | """) 44 | 45 | name = st.text_input("Your name?") 46 | st.write("Hello,", name or "world", "!") 47 | 48 | value = st.slider("Value?") 49 | st.write("The slider value is", value) 50 | 51 | import numpy as np 52 | import pandas as pd 53 | 54 | st.subheader("Chart sample") 55 | chart_data = pd.DataFrame( 56 | np.random.randn(20, 3), 57 | columns=['a', 'b', 'c']) 58 | 59 | tab1, tab2, tab3 = st.tabs(["Line chart", "Area chart", "Bar chart"]) 60 | with tab1: 61 | st.line_chart(chart_data) 62 | with tab2: 63 | st.area_chart(chart_data) 64 | with tab3: 65 | st.bar_chart(chart_data) 66 | 67 | st.subheader("DataFrame sample") 68 | df = pd.DataFrame( 69 | np.random.randn(50, 20), 70 | columns=('col %d' % i for i in range(20))) 71 | 72 | st.dataframe(df) 73 | 74 | st.subheader("Camera input") 75 | st.info("Don't worry!
The photo data is processed on your browser and never uploaded to any remote servers.") 76 | enable_camera_input = st.checkbox("Use the camera input") 77 | if enable_camera_input: 78 | picture = st.camera_input("Take a picture") 79 | 80 | if picture: 81 | st.image(picture) 82 | -------------------------------------------------------------------------------- /universql/streamlit/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | duckdb -------------------------------------------------------------------------------- /universql/warehouse/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buremba/universql/4618227732018ce9d0aecdebf0029d155954976b/universql/warehouse/__init__.py -------------------------------------------------------------------------------- /universql/warehouse/bigquery.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import List 3 | 4 | import google.api_core.exceptions 5 | import sqlglot 6 | from duckdb.experimental.spark.errors import UnsupportedOperationException 7 | from google.cloud import bigquery 8 | from google.cloud.bigquery import ExternalConfig 9 | from snowflake.connector.cursor import ResultMetadataV2 10 | from snowflake.connector.options import pyarrow 11 | from universql.protocol.session import UniverSQLSession 12 | from universql.plugin import Executor, ICatalog, Locations, register 13 | from universql.util import sizeof_fmt, pprint_secs, QueryError 14 | from universql.protocol.utils import get_field_for_snowflake, arrow_to_snowflake_type_id 15 | 16 | 17 | class BigQueryIcebergExecutor(Executor): 18 | 19 | def __init__(self, catalog: "BigQueryCatalog"): 20 | super().__init__(catalog) 21 | self.query = self.result = None 22 | self.client = bigquery.Client() 23 | 24 | @staticmethod 25 | def replace_full_reference_as_table(expression: sqlglot.exp.Expression) -> sqlglot.exp.Expression: 26 | if isinstance(expression, sqlglot.exp.Table): 27 | bq_identifier = '___'.join([part.sql() for part in expression.parts]) 28 | return sqlglot.exp.parse_identifier(bq_identifier, dialect="bigquery") 29 | return expression 30 | 31 | def execute_raw(self, raw_query: str, catalog_executor, config: typing.Optional[bigquery.QueryJobConfig] = None) -> None: 32 | self.query = self.client.query(raw_query, location="europe-west2", project="jinjat-demo", 33 | job_config=config) 34 | 35 | def execute(self, ast: sqlglot.exp.Expression, catalog_executor: Executor, 36 | locations: typing.Dict[sqlglot.exp.Table, str]) -> None: 37 | sql = ast.transform(self.replace_full_reference_as_table).sql(dialect="bigquery") 38 | 39 | definitions = {'___'.join([part.sql() for part in table.parts]): 40 | BigQueryIcebergExecutor._get_config(location) for table, location in locations.items()} 41 | self.execute_raw(sql, catalog_executor) 42 | 43 | @staticmethod 44 | def _get_config(location: str) -> ExternalConfig: 45 | config = ExternalConfig("ICEBERG") 46 | config.source_uris = location.replace('gcs://', 'gs://') 47 | return config 48 | 49 | def get_as_table(self) -> pyarrow.Table: 50 | try: 51 | self.result = self.query.result(timeout=None) 52 | except google.api_core.exceptions.GoogleAPIError as e: 53 | raise QueryError(f"Unable to run BigQuery: {e.args[0]}") 54 | arrow_all = self.result.to_arrow() 55 | for idx, column in enumerate(self.result.schema): 56 | column_type = arrow_all.schema.types[idx] 57 
| try: 58 | type_code = arrow_to_snowflake_type_id(column_type) 59 | except ValueError as e: 60 | raise QueryError(e.args[0]) 61 | 62 | (field, value) = get_field_for_snowflake(ResultMetadataV2(name=column.name, 63 | type_code=type_code, 64 | is_nullable=column.is_nullable, 65 | precision=column.precision, 66 | scale=column.scale, 67 | vector_dimension=None, 68 | fields=None, 69 | ), arrow_all[idx]) 70 | arrow_all = arrow_all.set_column(idx, field, value) 71 | 72 | return arrow_all 73 | 74 | def get_query_log(self, query_duration) -> str: 75 | return f"Run on BigQuery bytes billed {sizeof_fmt(self.query.total_bytes_billed)}, slot milliseconds {pprint_secs(self.query.slot_millis)}" 76 | 77 | def close(self): 78 | pass 79 | 80 | 81 | @register(name="bigquery") 82 | class BigQueryCatalog(ICatalog): 83 | 84 | def __init__(self, session : UniverSQLSession): 85 | super().__init__(session) 86 | self.tables = None 87 | 88 | def register_locations(self, tables: Locations): 89 | self.tables = tables 90 | 91 | def get_table_paths(self, tables: List[sqlglot.exp.Table]): 92 | raise UnsupportedOperationException("BigQuery catalog does not support resolving Iceberg table locations") 93 | def executor(self) -> Executor: 94 | return BigQueryIcebergExecutor(self) -------------------------------------------------------------------------------- /universql/warehouse/redshift.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import List 3 | 4 | import sqlglot 5 | from duckdb.experimental.spark.errors import UnsupportedOperationException 6 | from snowflake.connector.options import pyarrow 7 | 8 | from universql.protocol.session import UniverSQLSession 9 | from universql.plugin import ICatalog, Executor, Locations, Tables, register 10 | 11 | 12 | @register(name="redshift") 13 | class RedshiftCatalog(ICatalog): 14 | def __init__(self, session: UniverSQLSession): 15 | super().__init__(session) 16 | 17 | def get_table_paths(self, tables: List[sqlglot.exp.Table]) -> Tables: 18 | raise UnsupportedOperationException("Redshift catalog does not support resolving Iceberg table locations") 19 | 20 | def register_locations(self, tables: Locations): 21 | pass 22 | 23 | def executor(self) -> Executor: 24 | return RedshiftExecutor(self) 25 | 26 | class RedshiftExecutor(Executor): 27 | def __init__(self, catalog: RedshiftCatalog): 28 | super().__init__(catalog) 29 | 30 | def execute(self, ast: sqlglot.exp.Expression, catalog_executor: Executor, locations: Tables) -> typing.Optional[Locations]: 31 | return None 32 | 33 | def execute_raw(self, raw_query: str, catalog_executor: Executor) -> None: 34 | pass 35 | 36 | def get_as_table(self) -> pyarrow.Table: 37 | pass 38 | 39 | def get_query_log(self, total_duration) -> str: 40 | pass 41 | 42 | def close(self): 43 | pass 44 | -------------------------------------------------------------------------------- /universql/warehouse/snowflake.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | import typing 5 | from typing import List 6 | from uuid import uuid4 7 | 8 | import pyarrow as pa 9 | import pyiceberg.table 10 | import sentry_sdk 11 | import snowflake.connector 12 | import sqlglot 13 | from pyarrow import ArrowInvalid 14 | from pyiceberg.table import StaticTable 15 | from pyiceberg.table.snapshots import Summary, Operation 16 | from pyiceberg.typedef import IcebergBaseModel 17 | from snowflake.connector import NotSupportedError, DatabaseError, Error 18 | from
snowflake.connector.constants import FieldType 19 | from universql.protocol.session import UniverSQLSession 20 | from universql.protocol.utils import get_field_for_snowflake 21 | from universql.util import SNOWFLAKE_HOST, QueryError, prepend_to_lines, get_friendly_time_since 22 | from universql.plugin import ICatalog, Executor, Locations, Tables, register 23 | 24 | MAX_LIMIT = 10000 25 | 26 | logger = logging.getLogger("❄️") 27 | 28 | logging.getLogger('snowflake.connector').setLevel(logging.WARNING) 29 | 30 | 31 | # temporary workaround until pyiceberg bug is resolved 32 | def summary_init(summary, **kwargs): 33 | operation = kwargs.get('operation', Operation.APPEND) 34 | if "operation" in kwargs: 35 | del kwargs["operation"] 36 | super(IcebergBaseModel, summary).__init__(operation=operation, **kwargs) 37 | summary._additional_properties = kwargs 38 | 39 | 40 | Summary.__init__ = summary_init 41 | 42 | @register(name="snowflake") 43 | class SnowflakeCatalog(ICatalog): 44 | 45 | def __init__(self, session : UniverSQLSession): 46 | super().__init__(session) 47 | if session.context.get('account') is not None: 48 | session.credentials["account"] = session.context.get('account') 49 | if SNOWFLAKE_HOST is not None: 50 | session.credentials["host"] = SNOWFLAKE_HOST 51 | 52 | self.databases = {} 53 | # lazily create 54 | self._cursor = None 55 | 56 | def clear_cache(self): 57 | self._cursor = None 58 | 59 | def executor(self) -> Executor: 60 | return SnowflakeExecutor(self) 61 | 62 | def cursor(self): 63 | if self._cursor is not None: 64 | return self._cursor 65 | with sentry_sdk.start_span(op="snowflake", name="Initialize Snowflake Connection"): 66 | try: 67 | self._cursor = snowflake.connector.connect(**self.credentials).cursor() 68 | except DatabaseError as e: 69 | raise QueryError(e.msg, e.sqlstate) 70 | 71 | return self._cursor 72 | 73 | def register_locations(self, tables: Locations): 74 | start_time = time.perf_counter() 75 | queries = [] 76 | for location in tables.values(): 77 | queries.append(location.sql(dialect='snowflake')) 78 | final_query = '\n'.join(queries) 79 | if final_query: 80 | logger.info(f"[{self.session_id}] Syncing Snowflake catalog \n{prepend_to_lines(final_query)}") 81 | try: 82 | self.cursor().execute(final_query) 83 | performance_counter = time.perf_counter() 84 | logger.info( 85 | f"[{self.session_id}] Synced catalog with Snowflake ❄️ " 86 | f"({get_friendly_time_since(start_time, performance_counter)})") 87 | except snowflake.connector.Error as e: 88 | raise QueryError(e.msg, e.sqlstate) 89 | 90 | def _get_ref(self, table_information) -> pyiceberg.table.Table: 91 | location = table_information.get('metadataLocation') 92 | try: 93 | return StaticTable.from_metadata(location, self.iceberg_catalog.properties) 94 | except PermissionError as e: 95 | raise QueryError(f"Unable to access Iceberg metadata {location}. Cause: \n" + str(e)) 96 | 97 | def get_table_paths(self, tables: List[sqlglot.exp.Table]) -> Tables: 98 | if len(tables) == 0: 99 | return {} 100 | cursor = self.cursor() 101 | sqls = ["SYSTEM$GET_ICEBERG_TABLE_INFORMATION(%s)" for _ in tables] 102 | values = [table.sql(comments=False, dialect="snowflake") for table in tables] 103 | final_query = f"SELECT {(', '.join(sqls))}" 104 | try: 105 | cursor.execute(final_query, values) 106 | result = cursor.fetchall() 107 | return {table: self._get_ref(json.loads(result[0][idx])) for idx, table in 108 | enumerate(tables)} 109 | except DatabaseError as e: 110 | err_message = f"Unable to find location of Iceberg tables. 
See: https://github.com/buremba/universql#cant-query-native-snowflake-tables. Cause: \n {e.msg} \n{final_query}" 111 | raise QueryError(err_message, e.sqlstate) 112 | 113 | def get_volume_lake_path(self, volume: str) -> str: 114 | cursor = self.cursor() 115 | cursor.execute("DESC EXTERNAL VOLUME identifier(%s)", [volume]) 116 | volume_location = cursor.fetchall() 117 | 118 | # Find the active storage location name 119 | active_storage_name = next( 120 | (item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None) 121 | 122 | # Extract the STORAGE_BASE_URL from the corresponding storage location 123 | storage_base_url = None 124 | if active_storage_name: 125 | for item in volume_location: 126 | if item[1].startswith('STORAGE_LOCATION_'): 127 | storage_data = json.loads(item[3]) 128 | if storage_data.get('NAME') == active_storage_name: 129 | storage_base_url = storage_data.get('STORAGE_BASE_URL') 130 | break 131 | 132 | if storage_base_url is None: 133 | raise QueryError(f"Unable to find storage location for volume {volume}.") 134 | 135 | return storage_base_url 136 | # def find_table_location(self, database: str, schema: str, table_name: str, lazy_check: bool = True) -> str: 137 | # table_location = self.databases.get(database, {}).get(schema, {}).get(table_name) 138 | # if table_location is None: 139 | # if lazy_check: 140 | # self.load_database_schema(database, schema) 141 | # return self.find_table_location(database, schema, table_name, lazy_check=False) 142 | # else: 143 | # raise Exception(f"Table {table_name} not found in {database}.{schema}") 144 | # return table_location 145 | # def load_external_volumes_for_tables(self, tables: pd.DataFrame) -> pd.DataFrame: 146 | # volumes = tables["external_volume_name"].unique() 147 | # 148 | # volume_mapping = {} 149 | # for volume in volumes: 150 | # volume_location = pd.read_sql("DESC EXTERNAL VOLUME identifier(%s)", self.connection, params=[volume]) 151 | # active_storage = duckdb.sql("""select property_value from volume_location 152 | # where parent_property = 'STORAGE_LOCATIONS' and property = 'ACTIVE' 153 | # """).fetchall()[0][0] 154 | # all_properties = duckdb.execute("""select property_value from volume_location 155 | # where parent_property = 'STORAGE_LOCATIONS' and property like 'STORAGE_LOCATION_%'""").fetchall() 156 | # for properties in all_properties: 157 | # loads = json.loads(properties[0]) 158 | # if loads.get('NAME') == active_storage: 159 | # volume_mapping[volume] = loads 160 | # break 161 | # return volume_mapping 162 | 163 | # def load_database_schema(self, database: str, schema: str): 164 | # tables = self.load_iceberg_tables(database, schema) 165 | # external_volumes = self.load_external_volumes_for_tables(tables) 166 | # 167 | # tables["external_location"] = tables.apply( 168 | # lambda x: (external_volumes[x["external_volume_name"]].get('STORAGE_BASE_URL') 169 | # + x["base_location"]), axis=1) 170 | # if database not in self.databases: 171 | # self.databases[database] = {} 172 | # 173 | # self.databases[database][schema] = dict(zip(tables.name, tables.external_location)) 174 | 175 | # def load_iceberg_tables(self, database: str, schema: str, after: Optional[str] = None) -> pd.DataFrame: 176 | # query = "SHOW ICEBERG TABLES IN SCHEMA IDENTIFIER(%s) LIMIT %s", [database + '.' 
+ schema, MAX_LIMIT] 177 | # if after is not None: 178 | # query[0] += " AFTER %s" 179 | # query[1].append(after) 180 | # tables = pd.read_sql(query[0], self.connection, params=query[1]) 181 | # if len(tables.index) >= MAX_LIMIT: 182 | # after = tables.iloc[-1, :]["name"] 183 | # return tables + self.load_iceberg_tables(database, schema, after=after) 184 | # else: 185 | # return tables 186 | 187 | 188 | class SnowflakeExecutor(Executor): 189 | 190 | def __init__(self, catalog: SnowflakeCatalog): 191 | super().__init__(catalog) 192 | 193 | def _convert_snowflake_to_iceberg_type(self, snowflake_type: FieldType) -> str: 194 | if snowflake_type.name == 'TIMESTAMP_LTZ': 195 | return 'TIMESTAMP' 196 | if snowflake_type.name == 'VARIANT': 197 | # No support for semi-structured data. Maybe we should try OBJECT([SCHEMA])? 198 | return 'TEXT' 199 | if snowflake_type.name == 'ARRAY': 200 | # Relies on TO_VARIANT transformation 201 | return 'TEXT' 202 | if snowflake_type.name == 'OBJECT': 203 | # The schema is not available 204 | return 'TEXT' 205 | return snowflake_type.name 206 | 207 | def test(self): 208 | return self.catalog.cursor() 209 | 210 | def execute_raw(self, raw_query: str, catalog_executor: Executor, run_on_warehouse=None) -> None: 211 | try: 212 | emoji = "☁️(user cloud services)" if not run_on_warehouse else "💰(used warehouse)" 213 | logger.info(f"[{self.catalog.session_id}] Running on Snowflake.. {emoji} \n {raw_query}") 214 | self.catalog.cursor().execute(raw_query) 215 | except Error as e: 216 | message = f"{e.sfqid}: {e.msg} \n{raw_query}" 217 | raise QueryError(message, e.sqlstate) 218 | 219 | def execute(self, ast: sqlglot.exp.Expression, catalog_executor: Executor, locations: Tables) -> \ 220 | typing.Optional[typing.Dict[sqlglot.exp.Table, str]]: 221 | compiled_sql = (ast 222 | # .transform(self.default_create_table_as_iceberg) 223 | .sql(dialect="snowflake", pretty=True)) 224 | self.execute_raw(compiled_sql, catalog_executor) 225 | return None 226 | 227 | def get_query_log(self, total_duration) -> str: 228 | return "Run on Snowflake" 229 | 230 | def close(self): 231 | cursor = self.catalog._cursor 232 | if cursor is not None: 233 | cursor.close() 234 | 235 | def get_as_table(self) -> pa.Table: 236 | try: 237 | arrow_all = self.catalog.cursor().fetch_arrow_all(force_return_table=True) 238 | for idx, column in enumerate(self.catalog.cursor()._description): 239 | (field, value) = get_field_for_snowflake(column, arrow_all[idx]) 240 | arrow_all = arrow_all.set_column(idx, field, value) 241 | return arrow_all 242 | # return from snowflake is not using arrow 243 | except NotSupportedError: 244 | row = self.catalog.cursor().fetchone() 245 | values = [[] for _ in range(len(self.catalog.cursor()._description))] 246 | 247 | while row is not None: 248 | for idx, column in enumerate(row): 249 | values[idx].append(column) 250 | row = self.catalog.cursor().fetchone() 251 | 252 | fields = [] 253 | for idx, column in enumerate(self.catalog.cursor()._description): 254 | (field, _) = get_field_for_snowflake(column) 255 | fields.append(field) 256 | schema = pa.schema(fields) 257 | 258 | result_data = pa.Table.from_arrays([pa.array(value) for value in values], names=schema.names) 259 | 260 | for idx, column in enumerate(self.catalog.cursor()._description): 261 | (field, value) = get_field_for_snowflake(column, result_data[idx]) 262 | try: 263 | result_data = result_data.set_column(idx, field, value) 264 | except ArrowInvalid as e: 265 | # TODO: find a better approach (maybe casting?) 
266 | if any(value is not None for value in values): 267 | result_data = result_data.set_column(idx, field, pa.nulls(len(result_data), field.type)) 268 | else: 269 | raise QueryError(f"Unable to transform response: {e}") 270 | 271 | return result_data --------------------------------------------------------------------------------
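For reference, the connection snippets printed by `startup_event` (Node.js, JDBC, Python, and so on) are the intended entry point for clients. Below is a minimal, hypothetical sketch of the Python variant: it assumes the proxy is reachable at `localhost:8084` (use whatever host and port you actually run UniverSQL on) and that you sign in with your regular Snowflake credentials, which `login_request` forwards to Snowflake. All names below are placeholders, and depending on how TLS is served you may need additional connector options.

```python
# Hypothetical usage sketch (not part of the repository): point the standard
# Snowflake Python connector at a running UniverSQL proxy. Host, port and all
# credentials below are placeholders.
import snowflake.connector

conn = snowflake.connector.connect(
    host="localhost",         # assumed proxy host
    port=8084,                # assumed proxy port
    account="XXXXX-XXXXX",    # your Snowflake account identifier
    user="MY_USER",           # regular Snowflake credentials; UniverSQL
    password="MY_PASSWORD",   # forwards the login request to Snowflake
    warehouse="MY_WAREHOUSE",
    database="MY_DATABASE",
    schema="PUBLIC",          # login_request falls back to PUBLIC when unset
)

with conn.cursor() as cur:
    cur.execute("SELECT 1")
    print(cur.fetchall())
```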